diff options
256 files changed, 9727 insertions, 3255 deletions
diff --git a/configure.py b/configure.py index ad585fa52e..5243e09b24 100644 --- a/configure.py +++ b/configure.py @@ -1134,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp): nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt') + if os.path.exists(nccl_lib_path) and os.path.exists( + nccl_hdr_path) and os.path.exists(nccl_license_path): # Set NCCL_INSTALL_PATH environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f362900387..67749ec04e 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -603,3 +603,13 @@ py_library( visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], ) + +cc_library( + name = "grpc", + deps = ["@grpc"], +) + +cc_library( + name = "grpc++", + deps = ["@grpc//:grpc++"], +) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a8ad8e4b94..5c218d3f25 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2068,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2098,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } diff 
--git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d976f8296c..c2245b8eae 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -176,9 +176,11 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 11e45d2823..a605335a94 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,9 +23,11 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" @@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .HostMemory("input") \ .HostMemory("output"), \ - LoopCondOp); + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + 
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c1f65416b4..366822f0b7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -372,6 +372,20 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 9a93b32164..d775850a80 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from 
tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import adagrad -class AdagradOptimizerTest(XLATestCase): +class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 3215dc36e5..03554d6933 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -48,7 +48,7 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(XLATestCase): +class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index afef36d9d2..9cb3d04546 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class BinaryOpsTest(XLATestCase): +class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): diff --git a/tensorflow/compiler/tests/bucketize_op_test.py 
b/tensorflow/compiler/tests/bucketize_op_test.py index fde9759a1c..ef4d5f6322 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class BucketizationOpTest(XLATestCase): +class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 035cdea178..a4e7f75081 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -22,7 +22,7 @@ import collections import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest # TODO(srvasude): Merge this with # third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. 
-class CategoricalTest(XLATestCase): +class CategoricalTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def output_dtypes(self): diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 1a8989d7c2..d2867278af 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CholeskyOpTest(XLATestCase): +class CholeskyOpTest(xla_test.XLATestCase): # Cholesky defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 574f82fc71..e42ebf8f9e 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" -class ClusteringTest(XLATestCase): +class ClusteringTest(xla_test.XLATestCase): def testAdd(self): val1 = np.array([4, 3, 2, 1], dtype=np.float32) diff 
--git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index f10973e19f..d9ad428147 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ConcatTest(XLATestCase): +class ConcatTest(xla_test.XLATestCase): def testHStack(self): with self.test_session(): @@ -292,7 +292,7 @@ class ConcatTest(XLATestCase): array_ops.concat([scalar, scalar, scalar], dim) -class ConcatOffsetTest(XLATestCase): +class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: @@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) -class PackTest(XLATestCase): +class PackTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index d12e1ff1e8..98d41ba7ed 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -26,7 +26,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops @@ -42,7 +42,7 @@ DATA_FORMATS = ( ) -class Conv2DTest(XLATestCase, parameterized.TestCase): +class 
Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -236,7 +236,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase): expected=np.reshape([108, 128], [1, 1, 1, 2])) -class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase): +class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -534,7 +534,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase): expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) -class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase): +class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 3bebf46511..31ee41f04f 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py -class Conv3DBackpropFilterV2GradTest(XLATestCase): +class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): with self.test_session(), self.test_scope(): @@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase): # Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py -class Conv3DTransposeTest(XLATestCase): +class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): with self.test_session(), 
self.test_scope(): diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 03d96a2cd8..98dc73e189 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -114,7 +114,7 @@ def CheckGradConfigsToTest(): yield i, f, o, s, p -class DepthwiseConv2DTest(XLATestCase): +class DepthwiseConv2DTest(xla_test.XLATestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native # produce the same results. It also tests that NCHW and NWHC diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 6a46d2ec3e..154e36b10e 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DynamicUpdateSliceOpsTest(XLATestCase): +class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index c109c27abe..edd78153b5 100644 --- 
a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import googletest -class DynamicStitchTest(XLATestCase): +class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index e438832a23..3524666499 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import adam -class EagerTest(XLATestCase): +class EagerTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -286,7 +286,7 @@ class EagerTest(XLATestCase): [2.0, 2.0]], embedding_matrix.numpy()) -class EagerFunctionTest(XLATestCase): +class EagerFunctionTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -419,7 +419,7 @@ class EagerFunctionTest(XLATestCase): self.assertAllEqual((2, 3, 4), dz.shape.as_list()) -class ExcessivePaddingTest(XLATestCase): +class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. 
Tensors that would normally be excessively padded when written diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 0361702e7a..5529fdbb09 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExtractImagePatches(XLATestCase): +class ExtractImagePatches(xla_test.XLATestCase): """Functional tests for ExtractImagePatches op.""" def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index dfe9400ef0..c48ab178bf 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -17,14 +17,14 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import googletest -class FakeQuantWithMinMaxArgsTest(XLATestCase): +class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgs operation.""" # 8 bits, wide range. 
@@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): +class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" # 8 bits, wide range. @@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsTest(XLATestCase): +class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVars operation.""" # 8 bits, wide range. @@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): +class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" # 8 bits, wide range. diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index b2360dd009..c64ea249ec 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -23,7 +23,7 @@ import itertools import numpy as np import scipy.signal as sps -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -58,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) -class FFTTest(XLATestCase): +class FFTTest(xla_test.XLATestCase): def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, tf_method): diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 0000000000..0f64cc87cd --- 
/dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + 
self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. 
+ for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. + results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. 
+ time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 8e6407dffd..1da97fd512 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from 
tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent -class FtrlOptimizerTest(XLATestCase): +class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 8a3f4b0bdc..04fba44446 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class FunctionTest(XLATestCase): +class FunctionTest(xla_test.XLATestCase): def testFunction(self): """Executes a simple TensorFlow function.""" diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 5782e76734..132e42ac7a 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -22,7 +22,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -30,7 +30,7 @@ from tensorflow.python.ops import nn from 
tensorflow.python.platform import test -class FusedBatchNormTest(XLATestCase, parameterized.TestCase): +class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _reference_training(self, x, scale, offset, epsilon, data_format): if data_format != "NHWC": diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 9378b1db72..23b0aed34f 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GatherNdTest(XLATestCase): +class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): with self.test_session(): diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 1a8c451911..e9c8ef7c91 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual( [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + def testGatherPrecision(self): + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], + [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) + indices = np.array([1, 2, 3, 1]) + dtype = dtypes.float32 + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + class 
GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 7cf953ef25..8b01ef96db 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -25,7 +25,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape): return np.random.randint(0, 256, shape) / 256. -class RGBToHSVTest(XLATestCase): +class RGBToHSVTest(xla_test.XLATestCase): def testBatch(self): # Build an arbitrary RGB image @@ -104,7 +104,7 @@ class RGBToHSVTest(XLATestCase): self.assertAllCloseAccordingToType(hsv_tf, hsv_np) -class AdjustContrastTest(XLATestCase): +class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): with self.test_session(): @@ -168,7 +168,7 @@ class AdjustContrastTest(XLATestCase): self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) -class AdjustHueTest(XLATestCase): +class AdjustHueTest(xla_test.XLATestCase): def testAdjustNegativeHue(self): x_shape = [2, 2, 3] @@ -303,7 +303,7 @@ class AdjustHueTest(XLATestCase): self._adjustHueTf(x_np, delta_h) -class AdjustSaturationTest(XLATestCase): +class AdjustSaturationTest(xla_test.XLATestCase): def _adjust_saturation(self, image, saturation_factor): image = ops.convert_to_tensor(image, name="image") @@ -403,7 +403,7 @@ class AdjustSaturationTest(XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) -class ResizeBilinearTest(XLATestCase): +class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, diff --git 
a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 69bd8f7230..253b45902f 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -22,7 +22,7 @@ import copy import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" # Local response normalization tests. The forward tests are copied from # tensorflow/python/kernel_tests/lrn_op_test.py -class LRNTest(XLATestCase): +class LRNTest(xla_test.XLATestCase): def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 29394f9ea5..0d9f99f8a6 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -19,14 +19,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): with self.test_session(): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 5819b2bf2b..2bb8a97bda 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -22,7 +22,7 @@ import itertools 
import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -35,7 +35,7 @@ def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) -class MatrixTriangularSolveOpTest(XLATestCase): +class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index af9394e7d7..c2592c54cf 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib -class MomentumOptimizerTest(XLATestCase): +class MomentumOptimizerTest(xla_test.XLATestCase): def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): var += accum * lr * momentum diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e4843b169b..da08225e9f 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -22,14 +22,14 @@ import unittest import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops 
import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class NAryOpsTest(XLATestCase): +class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 6f588d8ab5..2f9122645d 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import googletest -class NullaryOpsTest(XLATestCase): +class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 5e6d1313bd..a75d99189b 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class PlaceholderTest(XLATestCase): +class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): with self.test_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py 
index d9285186ba..17f860db61 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): padding=padding) -class Pooling3DTest(XLATestCase): +class Pooling3DTest(xla_test.XLATestCase): def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, expected): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index fe270af3d6..9fc94752ea 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -69,7 +69,7 @@ def GetTestConfigs(): return test_configs -class PoolingTest(XLATestCase): +class PoolingTest(xla_test.XLATestCase): def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, data_format, expected): @@ -288,7 +288,7 @@ class PoolingTest(XLATestCase): expected=expected_output) -class PoolGradTest(XLATestCase): +class PoolGradTest(xla_test.XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 2e71b00ba6..b880b2a3fe 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,7 
+22,7 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 7420724bdb..cea2ec816f 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -22,7 +22,7 @@ import functools import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(XLATestCase): +class ReduceOpsTest(xla_test.XLATestCase): def _testReduction(self, tf_reduce_fn, @@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) -class ReduceOpPrecisionTest(XLATestCase): +class ReduceOpPrecisionTest(xla_test.XLATestCase): def _testReduceSum(self, expected_result, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index e78a63465b..c69b6837b0 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from 
tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReduceWindowTest(XLATestCase): +class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 18fabca28c..d01c676e7c 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReverseOpsTest(XLATestCase): +class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 1a5d05094e..ccfa630016 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ReverseSequenceTest(XLATestCase): +class 
ReverseSequenceTest(xla_test.XLATestCase): def _testReverseSequence(self, x, diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ecdce4f052..9489fded32 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class RmspropTest(XLATestCase): +class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 3260e63b23..4292352e76 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse): return x -class CumsumTest(XLATestCase): +class CumsumTest(xla_test.XLATestCase): valid_dtypes = [np.float32] @@ -147,7 +147,7 @@ class CumsumTest(XLATestCase): math_ops.cumsum(input_tensor, [0]).eval() -class CumprodTest(XLATestCase): +class CumprodTest(xla_test.XLATestCase): valid_dtypes = [np.float32] diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 638946e234..f606f88545 100644 --- 
a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -22,7 +22,7 @@ import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape): return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) -class ScatterNdTest(XLATestCase): +class ScatterNdTest(xla_test.XLATestCase): def _VariableRankTest(self, np_scatter, diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 305ca0c6b7..6c4890565d 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class SliceTest(XLATestCase): +class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: @@ -110,7 +110,7 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) -class StridedSliceTest(XLATestCase): +class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index f37c34156f..c685bc548f 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function 
import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings): return permuted_reshaped_padded.reshape(output_shape) -class SpaceToBatchTest(XLATestCase): +class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): @@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase): self._testOne(x_np, block_size, x_out) -class SpaceToBatchNDTest(XLATestCase): +class SpaceToBatchNDTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" def _testPad(self, inputs, block_shape, paddings, outputs): diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 94342f9567..b7dd787fef 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.platform import test -class StackOpTest(XLATestCase): +class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index abce190d83..d162675ef8 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ 
b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -22,7 +22,7 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test -class StatelessRandomOpsTest(XLATestCase): +class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ef047005b6..effa5a59fe 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class TernaryOpsTest(XLATestCase): +class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a24abd7547..6a7011aea6 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework 
import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops @@ -44,7 +44,7 @@ def nhwc_to_format(x, data_format): raise ValueError("Unknown format {}".format(data_format)) -class UnaryOpsTest(XLATestCase): +class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" def _assertOpOutputMatchesExpected(self, diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index bd616f2a20..dd2c252d38 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -37,7 +37,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training.gradient_descent import GradientDescentOptimizer -class VariableOpsTest(XLATestCase): +class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" def testOneWriteOneOutput(self): @@ -435,7 +435,7 @@ class StridedSliceAssignChecker(object): self.test.assertAllEqual(val, valnp) -class SliceAssignTest(XLATestCase): +class SliceAssignTest(xla_test.XLATestCase): def testSliceAssign(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index f79eb27435..b637cf31cf 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework 
import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class WhileTest(XLATestCase): +class WhileTest(xla_test.XLATestCase): def testSingletonLoopHandrolled(self): # Define a function for the loop body diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f0b010fa67..06d977b93c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(xla_test.XLATestCase): def testCopies(self): """Tests that copies onto and off XLA devices work.""" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a7b9cc6c81..aa9c0596d1 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 45657bb150..e6cbf2349d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -121,6 +121,7 @@ tf_kernel_library( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index b0ba25b998..4cfe946b2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + auto result = BatchDot(ctx->Input(0), ctx->Input(1), /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index fe6651793d..9fcbc86adc 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - auto result = Cholesky(ctx->builder(), ctx->Input(0)); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 5d41fc708a..48ac4867ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" @@ -96,14 +97,9 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, - &input_feature_iota)); - xla::XlaOp expanded_feature_iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - input_feature * depthwise_multiplier, - &expanded_feature_iota)); + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 17bf0c069c..378b62c0d6 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -39,9 +40,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index b2451236de..65d42a302f 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" @@ -111,9 +112,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. 
// iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - kernel_size * depth, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); auto lhs = xla::Reshape(iota, lhs_shape); auto filter = xla::ConvertElementType( diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index de971ce4ac..d6bf92fb3d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -128,10 +129,7 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice<int64> kernel_size, int64 channels) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); auto diag = xla::ConvertElementType( xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, @@ -149,10 +147,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, gtl::ArraySlice<int64> kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). 
- TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); auto diag = xla::ConvertElementType( xla::Eq( diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 9d3575e331..e06c87db7a 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -51,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); + xla::PrimitiveType index_xla_type = context->input_xla_type(1); TensorShape batch_shape = input_shape; batch_shape.RemoveLastDims(2); @@ -59,11 +61,8 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. 
- xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); - - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); + xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); + xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, /*broadcast_dimensions=*/{0}); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index 7bf1894ea0..e2ab4b83cf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -62,10 +63,8 @@ class MatrixSetDiagOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. 
- xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); + xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); + xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), /*broadcast_dimensions=*/{0}); indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index eaed931464..f4def11d08 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = TriangularSolve( - ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 51f2cdc9f4..d5b645d70a 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -84,8 +85,7 @@ class RandomShuffleOp : public XlaOpKernel { xla::ConstantR0<int32>(builder, n), swaps_shape); // Generate range(n) as the initial value for the indices to be swapped. - xla::XlaOp indices; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. auto swap_body_fn = [&](xla::XlaOp i, diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 16491002b4..c810456f94 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -165,9 +166,8 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. 
- xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector<int64> dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 3b19f8d872..50a455b520 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -127,7 +128,7 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, // Fill the generator inputs with unique counter values. ThreeFry2x32State inputs; - TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); + inputs[0] = xla::Iota(builder, xla::S32, half_size); inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size)); ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index beb7cf263d..8a1377fc38 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -62,8 +63,7 @@ class TopKOp : public XlaOpKernel { k = input_shape.dim_size(0); } const xla::XlaOp input_bf16 = context->Input(0); - xla::XlaOp iota_s32; - OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32)); + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n); // TODO(b/73891930): add a key-value sort to HLO, rather than using // bit-packing tricks here. diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 3823f5c087..e996916461 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -118,7 +118,7 @@ XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, xla::Log(x)); -XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); +XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); @@ -172,7 +172,7 @@ XLAJIT_MAKE_UNARY(Sinh, // max(x, 0) + log1p(exp(-abs(x))) XLAJIT_MAKE_UNARY(Softplus, xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); + xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index dd29bafcd9..f9f3a8c8cf 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -26,92 +26,94 @@ limitations under 
the License. namespace tensorflow { -xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, - bool conjugate_y) { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector<int64> batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, bool conjugate_y) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(x_shape), " vs. 
", xla::ShapeUtil::HumanString(y_shape)); } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector<int64> dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + const int ndims = xla::ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector<int64> batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. 
", xla::ShapeUtil::HumanString(y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector<int64> dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::Literal::Zero(x_shape.element_type())), + dimensions); + } + + if (x_shape.element_type() == xla::C64 && conjugate_x) { + x = xla::Conj(x); + } + if (y_shape.element_type() == xla::C64 && conjugate_y) { + y = xla::Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose + // HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? 
(ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::Literal::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); - } - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - return xla::DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 1acc72033b..d07a9486f1 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,10 +43,9 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool 
conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 397f0e3a72..a90178c7d9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -48,173 +48,163 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder, - const xla::XlaOp& a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); - - xla::XlaOp l = Zeros(builder, a_shape); - - // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr<std::vector<xla::XlaOp>> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector<int32> mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector); - auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - - // row = l[..., i, :i] - // 
select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0<int32>(body_builder, 0); - TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, - {i, zero}, {1, n})); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); - // a[..., i, i] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, - {i, i}, {1, 1})); - // np.dot(row, np.swapaxes(row, -1, -2)) - xla::XlaOp diag_dot; - TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(xla::Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); - - // a[..., i+1:, i] - // select the whole i-th column, then mask out all rows above i+1 +xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = Zeros(builder, a_shape); + + // Construct the for loop body to iterate over rows. 
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); + } + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = Zeros(body_builder, col_shape); + + std::vector<int32> mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = xla::ConstantR0<int32>(body_builder, 0); + auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + // a[..., i, i] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + // np.dot(row, np.swapaxes(row, -1, -2)) + auto diag_dot = BatchDot(row, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = + xla::Pow(a_ii - diag_dot, + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); + auto a_ip1i = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + + // l[..., i+1:, i] = 
(a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + auto dot = BatchDot(body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + + body_l = + DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); + + return std::vector<xla::XlaOp>{body_a, body_l}; + }; + TF_ASSIGN_OR_RETURN( - auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); - - auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii); - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, col_update, {i})); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. 
- TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, l_ii, {i, i})); - - return std::vector<xla::XlaOp>{body_a, body_l}; - }; - - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); - - return cholesky_while[1]; + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; + }); } } // namespace -xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); - } - - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); - } - - // Blocked left-looking Cholesky factorization. - // Algorithm 1 from - // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only - // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = Zeros(builder, a_shape); - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - if (i > 0) { - // TODO(phawkins): consider implementing SYRK for the diagonal part of - // the panel. 
- // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) - TF_ASSIGN_OR_RETURN(auto lhs, - SliceInMinorDims(builder, l, {i, 0}, {n, i})); - TF_ASSIGN_OR_RETURN(auto rhs, - SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); - TF_ASSIGN_OR_RETURN(auto delta, - BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto before, - SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(a, UpdateSliceInMinorDims( - builder, a, xla::Sub(before, delta), {i, i})); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); } - // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - TF_ASSIGN_OR_RETURN(auto x, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, factorized, {i, i})); - - if (i + k < n) { - // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - TF_ASSIGN_OR_RETURN(auto panel, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto update, - TriangularSolve(builder, factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, 
- /*conjugate_a=*/false, - /*block_size=*/block_size)); - TF_ASSIGN_OR_RETURN( - l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for + // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::XlaOp l = Zeros(builder, a_shape); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); + auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); + auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true); + auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); + a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto factorized = CholeskyUnblocked(x); + l = UpdateSliceInMinorDims(l, factorized, {i, i}); + + if (i + k < n) { + // l[i+k:, i:i+k] = + // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + auto update = TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, + /*block_size=*/block_size); + l = UpdateSliceInMinorDims(l, update, {i + k, i}); + } } - } - return l; + return l; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 20fca7969e..0f6e0e9d15 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,8 +30,7 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of 
silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b9f695ac4b..0d3ce129c7 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -30,621 +30,564 @@ limitations under the License. namespace tensorflow { -xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. 
- std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", xla::ShapeUtil::HumanString(b_shape)); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? 
m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map<int, xla::XlaComputation> base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr<xla::XlaComputation*> { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = xla::Parameter( - sub.get(), 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array<int64, 2> b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = xla::Parameter( - sub.get(), 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. 
- if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector<int64> batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } + batch_dimensions.push_back(a_size); + } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. 
- // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); 
- TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); - } + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + std::map<int, xla::XlaComputation> base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr<xla::XlaComputation*> { + xla::XlaComputation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + 
auto a_param = xla::Parameter( + sub.get(), 0, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, {k, k})), + "a"); + + std::array<int64, 2> b_lastd; + if (left_side) { + b_lastd = {k, n}; + } else { + b_lastd = {m, k}; + } + auto b_param = xla::Parameter( + sub.get(), 1, + xla::ShapeUtil::MakeShape(b_shape.element_type(), + ConcatVectors(batch_dimensions, b_lastd)), + "b"); + + // We use a left-looking or right-looking subroutine on the block + // diagonal in the lower=true cases, while falling back to a recursive + // call in others. The left-looking and right-looking subroutines are + // written with a While loop and so yields much faster compile times. + // Moreover, they can give higher performance on smaller (sub)problems. + if (left_side && lower) { + TriangularSolveLeftLooking(a_param, b_param, transpose_a, + conjugate_a); + } else if (!left_side && lower) { + TriangularSolveRightLooking(a_param, b_param, transpose_a, + conjugate_a); } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + TriangularSolve(a_param, b_param, left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); - } - } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., 
block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); + TF_ASSIGN_OR_RETURN(computation, sub->Build()); } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + return &computation; + }; + + xla::XlaOp output = Zeros(builder, b_shape); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance + // implementation of the level-3 BLAS." ACM Transactions on Mathematical + // Software (TOMS) 35.1 (2008): 4. + + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. 
+ + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, n}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, i + k}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, i + k}); + } + } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } else if (left_side && lower 
!= transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i + k, i}, {m, i + k}); + } else { + a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, m}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {i + k, 0}, {m, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {i + k, 0}); + } } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - 
SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); + } else { + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {0, i}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {m, i}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); + } } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, 
update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = + xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n}); + xla::XlaOp update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, + get_base_triangular_solve(k)); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + update = b_slice / a_slice_conj; + } + output = UpdateSliceInMinorDims(output, update, {i, 0}); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::XlaOp a_slice_2; + if (lower) { + a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i}); + } else { + a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k}); + } + + auto b_update = BatchDot(a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {i, n}); + b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0}); 
} - - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); } } - } - return output; + return output; + }); } -xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? 
m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = xla::Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector<xla::Shape> tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 
m - 2 : 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr<xla::XlaBuilder> condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); - } else { - xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m)); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector<int64> batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. 
- std::unique_ptr<xla::XlaBuilder> bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); - - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + + // The main computation is performed in a While loop. + + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::XlaOp output = Zeros(builder, b_shape); + { + auto i = transpose_a ? 
m - 1 : 0; + auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1}); + auto b_slice = SliceInMinorDims(b, {i, 0}, {i + 1, n}); + auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a); + auto update = b_slice / a_slice_conj; + output = UpdateSliceInMinorDims(output, update, {i, 0}); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = xla::Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = xla::Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); + // Construct the initial loop carry tuple, // if transpose_a: - // return (i - 1, body_out, a, b) + // init = (m-2, output, a, b) // else: - // return (i + 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (1, output, a, b) + std::vector<xla::Shape> tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. 
+ b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr<xla::XlaBuilder> condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); + } else { + xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. 
+ std::unique_ptr<xla::XlaBuilder> bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). 
+ xla::XlaOp a_row; + if (transpose_a) { + a_row = DynamicSliceInMinorDims(body_a, {zero, i}, {m, 1}); + } else { + a_row = DynamicSliceInMinorDims(body_a, {i, zero}, {1, m}); + } + auto b_update = BatchDot(a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false); + auto result_row_slice = + DynamicSliceInMinorDims(body_b, {i, zero}, {1, n}); + auto result_row = result_row_slice - b_update; + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_elt_conj = MaybeConjugate(a_elt, conjugate_a); + auto div_result = xla::Div(result_row, a_elt_conj); + body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } -xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector<int64> batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - 
// The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector<xla::Shape> tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr<xla::XlaBuilder> condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n)); - } else { - xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector<int64> batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = 
a_shape.dimensions(i); + batch_dimensions.push_back(a_size); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr<xla::XlaBuilder> bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). 
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = xla::Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = xla::Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, // if transpose_a: - // return (i + 1, body_out, a, b) + // init = (0, output, a, b) // else: - // return (i - 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + // init = (n-1, output, a, b) + std::vector<xla::Shape> tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. 
+ b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr<xla::XlaBuilder> condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n)); + } else { + xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop + // counter. 
+ std::unique_ptr<xla::XlaBuilder> bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = xla::Parameter( + bodyb.get(), 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the + // output to all zeros and use that as zero-padding (doing unnecessary + // FLOPs). + auto b_update = BatchDot(body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a); + // result = b - np.matmul(output, a) + auto result = body_b - b_update; + // result_row = result[..., :, i:i+1] + auto result_row = DynamicSliceInMinorDims(result, {zero, i}, {m, 1}); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + auto a_ii_conj = MaybeConjugate(a_ii, conjugate_a); + auto div_result = xla::Div(result_row, a_ii_conj); + body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {zero, i}); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = xla::Add( + i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 
1 : -1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 540c26b247..80c2bc4c9c 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,23 +57,15 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size = 256); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); -xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b, + bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 87ea4763f7..d5ffc1498e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ 
b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {-0.16414141, -0.06902357, 
-0.07070707, 0.36363636}, @@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -196,11 +191,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {0.5, 1.0, 1.5}, @@ -219,11 +213,10 @@ XLA_TEST_F(TriangularSolveTest, 
SimpleLeftUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {0.5, 1.0, 1.5}, @@ -242,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<float> expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -267,11 +259,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); xla::Array2D<complex64> expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -295,11 +286,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter<complex64>(AValsUpperComplex(), 0, 
"a", &builder, &a); auto b_data = CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D<complex64> expected({ {0.5, 1., 1.5}, @@ -323,10 +313,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D<float> expected({ {0.5, 1.0, 1.5}, @@ -345,10 +334,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { xla::XlaOp a, b; auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); + TriangularSolveLeftLooking(a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); xla::Array2D<float> expected({ {0.5, 1.0, 1.5}, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 11774dde08..6694729495 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -111,130 +111,137 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - 
gtl::ArraySlice<int64> start, - gtl::ArraySlice<int64> end) { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector<int64> padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector<int64> padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector<int64> strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start, + gtl::ArraySlice<int64> end) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector<int64> padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. 
+ std::vector<int64> padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector<int64> strides(n_dims, 1); + return xla::Slice(x, padded_start, padded_end, strides); + }); } -std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice<int64>& major_dims, - const gtl::ArraySlice<int64>& indices) { - std::vector<int64> output(indices.size() + major_dims.size()); - std::copy(major_dims.begin(), major_dims.end(), output.begin()); - std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); +std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs, + gtl::ArraySlice<int64> ys) { + std::vector<int64> output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); return output; } -xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector<xla::XlaOp>& starts, - const gtl::ArraySlice<int64>& sizes) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice<xla::XlaOp> starts, + gtl::ArraySlice<int64> sizes) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, 
builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); + }); } -xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice<int64> start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector<int32> start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<int64> start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
+ std::vector<int32> start_as_int32(start.begin(), start.end()); + auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return xla::DynamicUpdateSlice(x, update, start_constant); + }); } -xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice<int64> start) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector<int64> padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(builder, x, update, padded_start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<int64> start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector<int64> padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); } -xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector<xla::XlaOp>& starts) { - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); +xla::XlaOp 
DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<xla::XlaOp> starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector<xla::XlaOp>& starts) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1}); - std::vector<xla::XlaOp> padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice<xla::XlaOp> starts) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1}); + std::vector<xla::XlaOp> padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); + } + return xla::ConcatInDim(builder, padded_starts, 0); + }); } -xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector<int64> permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> 
xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector<int64> permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return xla::Transpose(x, permutation); + }); } -xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? xla::Conj(x) : x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3c120a2548..ac5d2940ff 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -33,7 +33,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, gtl::ArraySlice<xla::XlaOp> starts); // Returns a integer scalar constant of 'type' with 'value'. 
@@ -41,54 +41,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last two values being +// Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. -xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector<xla::XlaOp>& starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice<xla::XlaOp> starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice<int64> start, - gtl::ArraySlice<int64> end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start, + gtl::ArraySlice<int64> end); -// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice<int64>& major_dims, - const gtl::ArraySlice<int64>& indices); +// Returns the concatenation of `xs` and `ys`. +std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs, + gtl::ArraySlice<int64> ys); // Performs a dynamic slice in the minor dimensions of a Tensor. 
-xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice<xla::XlaOp> starts, + gtl::ArraySlice<int64> sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice<int64> start); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<int64> start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice<int64> start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<int64> start); -xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector<xla::XlaOp>& starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice<xla::XlaOp> starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. 
-xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate); +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 2a332c933f..7d0f2222a9 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y); - auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); - TF_ASSERT_OK(result.status()); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); ComputeAndCompareR2<float>(&builder, {{10}}, {a_data.get(), x_data.get(), y_data.get()}, @@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index); - TF_ASSERT_OK( - DynamicSliceInMinorDims( - &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, 4}) - .status()); + DynamicSliceInMinorDims(a, {index, xla::ConstantR0<int32>(&builder, 0)}, + {1, 4}); ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x); auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y); - auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); - TF_ASSERT_OK(result.status()); + DynamicUpdateSliceInMinorDims(a, b, {x, y}); xla::Array2D<float> expected( {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); @@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of 
BatchedAValsFull(). auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims( - &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n})); - TF_ASSERT_OK(BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true) - .status()); + auto l_index = DynamicSliceInMinorDims( + a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n}); + BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 917ef4037d..81bdf139f5 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" @@ -72,10 +73,9 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. 
- xla::XlaOp iota; const int64 axis_size = input_shape.dim_size(axis); - TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla_output_type, axis_size); xla::XlaOp product = xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); @@ -230,31 +230,6 @@ Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, axis, /*is_min=*/true, argmin); } -Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota) { - TensorShape linspace_shape({size}); - Tensor linspace; - switch (dtype) { - case DT_UINT8: - linspace = MakeLinspaceTensor<uint8>(linspace_shape, size); - break; - case DT_INT32: - linspace = MakeLinspaceTensor<int32>(linspace_shape, size); - break; - case DT_INT64: - linspace = MakeLinspaceTensor<int64>(linspace_shape, size); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(dtype)); - } - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - - *iota = xla::ConstantLiteral(builder, linspace_literal); - return Status::OK(); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c320016998..495bd2b8b6 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -89,10 +89,6 @@ class XlaHelpers { DataType input_type, DataType output_type, int axis, xla::XlaOp* argmin); - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. 
`axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c2298b97e1..0eabfb3a52 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -68,6 +69,20 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } +DataType XlaOpKernelContext::input_type(int index) const { + return context_->input(index).dtype(); +} + +xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(input_type(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca..2bde2c983d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -67,7 +68,12 @@ class XlaOpKernelContext { int num_inputs() const { return context_->num_inputs(); } // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + DataType input_type(int index) const; + + // Returns the type of input 'index' as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); // Returns the shape of input 'index'. TensorShape InputShape(int index); diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d49d959a6c..273fa17371 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -13,6 +13,12 @@ filegroup( ]), ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") + +# Generate test_suites for all backends, named "${backend}_tests". 
+generate_backend_suites() + cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], @@ -29,6 +35,32 @@ cc_library( ) cc_library( + name = "numeric", + srcs = ["numeric.cc"], + hdrs = ["numeric.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "numeric_test", + srcs = ["numeric_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc new file mode 100644 index 0000000000..cbe9e7fdd1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +#include <numeric> +#include <vector> + +namespace xla { + +namespace { + +template <typename T> +XlaOp MakeIota(XlaBuilder* builder, int64 size) { + std::vector<T> values(size); + for (int64 i = 0; i < size; ++i) { + values[i] = static_cast<T>(i); + } + return xla::ConstantR1<T>(builder, values); +} + +} // namespace + +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + switch (type) { + case S8: + return MakeIota<int8>(builder, size); + case S16: + return MakeIota<int16>(builder, size); + case S32: + return MakeIota<int32>(builder, size); + case S64: + return MakeIota<int64>(builder, size); + case U8: + return MakeIota<uint8>(builder, size); + case U16: + return MakeIota<uint16>(builder, size); + case U32: + return MakeIota<uint32>(builder, size); + case U64: + return MakeIota<uint64>(builder, size); + case BF16: + return MakeIota<bfloat16>(builder, size); + case F16: + return MakeIota<half>(builder, size); + case F32: + return MakeIota<float>(builder, size); + case F64: + return MakeIota<double>(builder, size); + case C64: + return MakeIota<complex64>(builder, size); + default: + return builder->ReportError( + InvalidArgument("Unimplemented type for Iota: %s.", + PrimitiveType_Name(type).c_str())); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h new file mode 100644 index 0000000000..2a409ae311 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc new file mode 100644 index 0000000000..bc8a73e9d7 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using NumericTest = ClientLibraryTestBase; + +XLA_TEST_F(NumericTest, Iota) { + XlaBuilder builder(TestName()); + Iota(&builder, S32, 10); + + ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index b0f41ac1d3..ee00a9eada 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -1,7 +1,5 @@ # Description: # The new XLA client libraries. -# -# This is NOT YET ready to use. 
licenses(["notice"]) # Apache 2.0 @@ -41,6 +39,7 @@ cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], deps = [ ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 0145f60483..4f683a4115 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -60,36 +60,18 @@ bool CanBeRoot(HloOpcode opcode) { } // namespace -XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); } -XlaOp operator+(const XlaOp& x, const XlaOp& y) { - return x.builder()->Add(x, y); -} -XlaOp operator-(const XlaOp& x, const XlaOp& y) { - return x.builder()->Sub(x, y); -} -XlaOp operator*(const XlaOp& x, const XlaOp& y) { - return x.builder()->Mul(x, y); -} -XlaOp operator/(const XlaOp& x, const XlaOp& y) { - return x.builder()->Div(x, y); -} -XlaOp operator%(const XlaOp& x, const XlaOp& y) { - return x.builder()->Rem(x, y); -} - -XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); } -XlaOp operator&(const XlaOp& x, const XlaOp& y) { - return x.builder()->And(x, y); -} -XlaOp operator|(const XlaOp& x, const XlaOp& y) { - return x.builder()->Or(x, y); -} -XlaOp operator^(const XlaOp& x, const XlaOp& y) { - return x.builder()->Xor(x, y); -} -XlaOp operator<<(const XlaOp& x, const XlaOp& y) { - return x.builder()->ShiftLeft(x, y); -} +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } + +XlaOp operator~(const XlaOp& x) { return 
Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } XlaOp operator>>(const XlaOp& x, const XlaOp& y) { XlaBuilder* builder = x.builder(); @@ -101,9 +83,9 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { ShapeUtil::HumanString(shape).c_str()); } if (ShapeUtil::ElementIsSigned(shape)) { - return builder->ShiftRightArithmetic(x, y); + return ShiftRightArithmetic(x, y); } else { - return builder->ShiftRightLogical(x, y); + return ShiftRightLogical(x, y); } }); } @@ -1366,8 +1348,25 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + std::vector<const Shape*> operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + return values.has_value() + ? 
AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { @@ -2538,7 +2537,9 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); } +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) { + return keys.builder()->Sort(keys, std::move(values)); +} XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index fe31774b86..ac6ad87349 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -158,6 +158,93 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. 
+ static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr<XlaComputation> Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. 
+ StatusOr<Shape> GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr<ProgramShape> GetProgramShape() const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr<XlaOp> into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op); + + // A helper function that runs a function that returns a StatusOr<XlaOp> and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr<bool> IsConstant(const XlaOp& operand) const; + + private: // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -378,26 +465,6 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); - // Default dimension numbers used for a 2D convolution. 
- static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -717,7 +784,18 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<int64> dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); + // If only keys are provided: + // * The keys must be a rank-1 tensor (i.e. an array). + // * The result is a sorted array of keys. + // + // If both keys and values are provided: + // * The keys and the values must be rank-1 tensors with the same dimensions. + // The element types of the tensors may be different. + // * The result is a tuple that consists of a sorted array of keys as the + // first element, and an array with their corresponding values as the second + // element. + XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values = + tensorflow::gtl::nullopt); // Enqueues a clamp instruction onto the computation. 
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -764,14 +842,6 @@ class XlaBuilder { // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr<bool> IsConstant(const XlaOp& operand) const; - // Normalizes operand across spatial and batch dimensions for each feature. // // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` @@ -810,65 +880,6 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr<XlaComputation> Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. 
If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr<Shape> GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr<ProgramShape> GetProgramShape() const; - - // Reports an error to the builder, by - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to - // Build()/GetShape()/...) - // * dying if die_immediately_on_error_ is true. - // Returns an XlaOp with an invalid handle but a valid builder. This value can - // be returned in place of a value in APIs that return an XlaOp. - XlaOp ReportError(const Status& error); - - // A helper function that converts a StatusOr<XlaOp> into an XlaOp. - // If the Status was an error, reports the error to builder and returns an - // invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op); - - // A helper function that runs a function that returns a StatusOr<XlaOp> and - // returns an XlaOp. 
- XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator); - - private: StatusOr<XlaOp> AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice<XlaOp> operands = {}); @@ -971,6 +982,284 @@ class XlaBuilder { bool die_immediately_on_error_ = false; XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template <typename NativeT> + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template <typename NativeT> + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<NativeT> values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template <typename NativeT> + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list<std::initializer_list<NativeT>> values); + template <typename NativeT> + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D<NativeT>& values); + template <typename NativeT> + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D<NativeT>& values, + const Layout& layout); + template 
<typename NativeT> + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D<NativeT>& values); + + template <typename NativeT> + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend 
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& 
shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice<XlaOp> operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp 
ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const tensorflow::gtl::optional<ChannelHandle>& channel_id); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const XlaOp& source, 
const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp SqrtF32(const XlaOp& operand); + friend XlaOp SquareF32(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp ReciprocalF32(const XlaOp& operand); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<XlaOp> static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& 
shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1548,8 +1837,16 @@ XlaOp Transpose(const XlaOp& operand, // is moved to index dimension_size - 1 - i). XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions); -// Enqueues a sort (as increasing order) instruction onto the computation. -XlaOp Sort(const XlaOp& operand); +// * The result is a sorted array of keys. +// +// If both keys and values are provided: +// * The keys and the values must be rank-1 tensors with the same dimensions. +// The element types of the tensors may be different. 
+// * The result is a tuple that consists of a sorted array of keys as the +// first element, and an array with their corresponding values as the second +// element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 48fd07371d..1ddeb27e40 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1252,9 +1252,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, switch (instruction->opcode()) { case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSort: case HloOpcode::kTranspose: return true; + case HloOpcode::kSort: + return (!ShapeUtil::IsTuple(instruction->shape())); default: return false; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b0ad433d8d..ab3d846403 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -1093,8 +1093,7 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module, + const HloOrdering& ordering, HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& fusion_can_share_buffer) { MaybeDumpModule("after adding copies to resolve interference", *module); @@ -1108,7 +1107,6 @@ Status RemoveUnnecessaryCopies( for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && - 
!ContainsKey(copies_to_exclude, instruction->unique_id()) && instruction->CopyElisionAllowed()) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } @@ -1152,16 +1150,13 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet<int> existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + int64 num_existing_copies = 0; + if (VLOG_IS_ON(1)) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } } } } @@ -1181,8 +1176,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1203,7 +1197,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h 
b/tensorflow/compiler/xla/service/copy_insertion.h index 6d25706089..e1973db928 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -79,11 +78,10 @@ class CopyInsertion : public HloPassInterface { }; // Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. +// live range interference. Only copy instructions that are eligible for +// copy elision are considered for removal. Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module, + const HloOrdering& ordering, HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e7539759ce..7ae8799b61 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) { } TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before + // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}}))); + auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); + Layout reversed_layout = + LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); + Shape copy_shape = constant->shape(); + *copy_shape.mutable_layout() = reversed_layout; + HloInstruction* copy_1 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + builder.AddInstruction( + HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); module->AddEntryComputation(builder.Build()); @@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + EXPECT_EQ(module->entry_computation()->root_instruction(), add); + 
EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 7bb8df6581..5343497c03 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -55,33 +55,28 @@ Status GpuTransferManager::TransferLiteralToInfeed( return TransferBufferToInfeed(executor, size, literal.untyped_data()); } - if (ShapeUtil::IsNestedTuple(shape)) { - return Unimplemented( - "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); - } - // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. std::vector<gpu::InfeedBuffer*> buffers; - buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } }); - for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); - int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); - TF_ASSIGN_OR_RETURN( - gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, tuple_element_size, - literal.untyped_data({i}))); - buffers.push_back(buffer); - } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(literal_subshape)) { + int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + literal.untyped_data(index))); + 
buffers.push_back(buffer); + } + return Status::OK(); + })); cleanup.release(); return EnqueueBuffersToInfeed(executor, buffers); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 652b5c7687..ea661b3c2c 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -113,10 +113,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. return IsInputFusibleReduction(instr) || (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop && - // TODO(b/110202584): bitcasts make nested fusions, GPU has no support - // for nested fusions. - instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index deb7f28d84..e65e1af20c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1068,6 +1068,19 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } +Status HloEvaluator::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return DefaultAction(sort); + } + // The key-value version of Sort is a special snowflake, since the output + // shape is a tuple, so its element type is not meaningful. + // + // TODO(mkuper): Do something sane here, so that we can support different key + // and value types. 
+ return sort->Visit( + typed_visitors_.at(sort->operand(0)->shape().element_type()).get()); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 2ad56080d8..b330c30eeb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -176,6 +176,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleAfterAll(HloInstruction* token) override; + Status HandleSort(HloInstruction* sort) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8b08756c64..1136178e90 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1025,83 +1025,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector<int64> lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. 
+ tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector<int64> rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet<int64> batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); auto result = MakeUnique<Literal>(dot->shape()); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> result_index) { ElementwiseT result_val = static_cast<ElementwiseT>(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. 
More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. 
- for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1402,24 +1366,68 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !is_complex_t<NativeT>::value && !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1) + auto keys = sort->operand(0); + TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1) << "Sort is only supported for R1 shapes"; - auto arg = sort->operand(0); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString(); - const auto& arg_data = arg_literal.data<ReturnT>(); + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data<ReturnT>(); + + if (sort->operand_count() == 1) { + std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess<ReturnT>(a, b); + }); + auto result_literal = MakeUnique<Literal>(sort->shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice<ReturnT>(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + parent_->evaluated_[sort] = 
std::move(result_literal); + } else { + CHECK_EQ(sort->operand_count(), 2); + auto values = sort->operand(1); + if (values->shape().element_type() != + primitive_util::NativeToPrimitiveType<ReturnT>()) { + return InvalidArgument( + "Evaluator requires value and key types for Sort to match"); + } - std::vector<ReturnT> return_data(arg_data.begin(), arg_data.end()); - std::sort(return_data.begin(), return_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - auto result_literal = MakeUnique<Literal>(sort->shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice<ReturnT>(return_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - parent_->evaluated_[sort] = std::move(result_literal); + // We need to sort and array of keys and an array of values, where the + // sorted order of the values is determined by the keys. The simplest(?) + // way to do this is to go to an array-of-pairs representation, sort the + // array using the keys, and then go back to pair-of-arrays. 
+ const Literal& values_literal = parent_->GetEvaluatedLiteralFor(values); + VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); + const auto& values_data = values_literal.data<ReturnT>(); + using kv_pair = std::pair<ReturnT, ReturnT>; + std::vector<kv_pair> key_value_vector; + CHECK_EQ(keys_data.size(), values_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess<ReturnT>(a.first, b.first); + }); + std::vector<ReturnT> result_keys, result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + auto result_keys_literal = MakeUnique<Literal>(keys->shape()); + result_keys_literal->PopulateR1( + tensorflow::gtl::ArraySlice<ReturnT>(result_keys)); + auto result_values_literal = MakeUnique<Literal>(values->shape()); + result_values_literal->PopulateR1( + tensorflow::gtl::ArraySlice<ReturnT>(result_values)); + auto result_tuple = Literal::MakeTuple( + {result_keys_literal.get(), result_values_literal.get()}); + VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + parent_->evaluated_[sort] = std::move(result_tuple); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5aaeec802f..e0e3d301be 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -489,7 +489,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -908,6 +907,16 @@ HloInstruction::CreateBroadcastSequence( return MakeUnique<HloTransposeInstruction>(shape, operand, 
dimensions); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort( + const Shape& shape, HloInstruction* keys, HloInstruction* values) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSort, shape)); + instruction->AppendOperand(keys); + if (values) { + instruction->AppendOperand(values); + } + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root); @@ -1122,7 +1131,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1215,6 +1223,14 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kAfterAll: clone = CreateAfterAll(new_operands); break; + case HloOpcode::kSort: + CHECK(new_operands.size() == 1 || new_operands.size() == 2) + << "Too many operands for sort: " << new_operands.size(); + HloInstruction* keys = new_operands[0]; + HloInstruction* values = + new_operands.size() == 2 ? new_operands[1] : nullptr; + clone = CreateSort(shape, keys, values); + break; } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); @@ -1491,6 +1507,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: + case HloOpcode::kSort: case HloOpcode::kSin: case HloOpcode::kSubtract: case HloOpcode::kTanh: @@ -1520,10 +1537,6 @@ bool HloInstruction::IdenticalSlowPath( return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. 
- case HloOpcode::kSort: - return false; - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 59a383218c..0459072127 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -611,6 +611,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. + static std::unique_ptr<HloInstruction> CreateSort( + const Shape& shape, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 57d17064c1..6ffed62a09 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -509,7 +509,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -625,6 +624,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); break; } + case HloOpcode::kSort: { + auto loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + switch (operands.size()) { + case 1: + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, 
/*keys=*/operands[0])); + break; + case 2: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, + /*keys=*/operands[0], /*values=*/operands[1])); + break; + default: + return Error(loc, StrCat("expects either 1 or 2 operands, but has ", + operands.size(), " operands")); + } + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index da1a34ae3c..504ea3fe7a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -832,6 +832,31 @@ ENTRY ReducePrecision { )" }, +// Sort (Key) +{ +"SortKey", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x) +} + +)" +}, +// Sort (Key, Value) +{ +"SortKeyValue", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values) +} + +)" +}, // Conditional { "Conditional", diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 62c07d7fac..59a8800a7d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1244,7 +1244,7 @@ StatusOr<bool> HloRematerialization::Run( // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. 
SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); } // Compute peak memory usage of all computations in the module called in a diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index fb39c6f085..27c9529b11 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -167,7 +167,16 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - return CheckUnaryShape(sort); + if (sort->operand_count() == 2 && + !ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(1)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys and " + "the values. Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + } + return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 79b5a442aa..4166ef5baf 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - // - // An invariant is that no bitcast nodes will show up in the middle of a - // fusion node. This invariant must hold in order for us to lower it. 
Given - // that, we require that during multi-output fusion, a fusion node ending with - // bitcast to preserve its structure as a nested fusion instead being - // merged and flattened. - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { - if (remaining->opcode() == HloOpcode::kFusion && - remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { - auto parent_computation = remaining->parent(); - // Create a nested fusion node. - auto remaining_nested_fused = - parent_computation->AddInstruction(HloInstruction::CreateFusion( - remaining->shape(), HloInstruction::FusionKind::kLoop, - remaining)); - TF_CHECK_OK(parent_computation->ReplaceInstruction( - remaining, remaining_nested_fused)); - remaining = remaining_nested_fused; - } remaining->FuseInstructionIntoMultiOutput(fused); } - return remaining; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d23822e33e..0019cd7254 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -78,6 +78,10 @@ class MultiOutputFusion : public HloPassInterface { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. 
+ virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + // Recompute reachability for the current computation. void RecomputeReachability(); @@ -101,10 +105,6 @@ class MultiOutputFusion : public HloPassInterface { virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. - // The other instruction is removed from its parent computation. - HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); - // Update the internal data structures after instr1 and instr2 are fused into // one fusion instruction. void Update(HloInstruction* instr1, HloInstruction* instr2); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 096bbde922..d05e995a95 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -239,7 +239,6 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kNegate: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSign: - case HloOpcode::kSort: return shape; case HloOpcode::kNot: @@ -962,6 +961,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5a45e2e610..20b2885e90 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2040,6 +2040,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 37862fa9cb..5361ae6783 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -373,6 +373,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<Literal> arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,13 +397,6 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice<Literal> arguments); - // Whether to run tests with all float-type input/output converted to // bfloat16. 
bool use_bfloat16_ = false; diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1c..1a396b090c 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0<int32>(42); + auto computation = ConstantR0<int32>(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<int32>(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f)); + Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0<float>(&b, 1.1f), ConstantR0<float>(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value 
= ComputeConstantScalar<float>(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0<float>(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0<float>(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar<float>(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,7 +201,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4})); + Add(ConstantR1<int32>(&b, {1, 2}), ConstantR1<int32>(&b, {3, 4})); 
EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -216,7 +216,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3)); + auto computation = Div(ConstantR0<int32>(&b, 15), ConstantR0<int32>(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, @@ -237,8 +237,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}), - b.ConstantR2<int32>({{10, 20}, {30, 40}})), + Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}), + ConstantR2<int32>(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr<Literal> expected_literal = diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 33d79aebb1..cf2e645d47 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -853,10 +853,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -883,10 +882,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
-XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -913,10 +909,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -948,10 +941,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr<Array2D<float>> constant_lhs_array( new Array2D<float>({{1.0, 2.0}, {3.0, 4.0}, @@ -983,10 +973,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
-XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1010,10 +997,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr<Array2D<float>> constant_rhs_array( @@ -1036,5 +1020,28 @@ XLA_TEST_F(DotOperationTest, Array2D<float> expected({{168.0}, {168.0}}); ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D<float> expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc 
b/tensorflow/compiler/xla/tests/half_test.cc index 76bf47845c..fd85118849 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase { static const int kNumElements = 4; }; -using UnaryBuildFuncTy = - std::function<void(xla::XlaBuilder*, const xla::XlaOp& src)>; +using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>; struct UnaryOpTestParam { std::function<half(half)> compute_func; @@ -62,7 +61,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_); } @@ -79,18 +78,17 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, ::testing::Values( - UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &XlaBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} + UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, + 
UnaryOpTestParam{[](half x) { return log(x); }, &Log}, + UnaryOpTestParam{[](half x) { return -x; }, &Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} )); @@ -118,19 +116,18 @@ XLA_TEST_P(UnaryPredTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()}); } INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ - [](half x) { return isfinite(x); }, - &XlaBuilder::IsFinite})); + [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = std::function<void( - xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y, - tensorflow::gtl::ArraySlice<int64>)>; +using BinaryBuildFuncTy = + std::function<void(const xla::XlaOp& x, const xla::XlaOp& y, + tensorflow::gtl::ArraySlice<int64>)>; struct BinaryOpTestParam { std::function<half(half, half)> compute_func; @@ -159,7 +156,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()}, error_spec_); @@ -173,22 +170,15 @@ half atan2_imp(half x, half y) { INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( - BinaryOpTestParam{[](half x, half y) { return x + y; }, - &XlaBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &XlaBuilder::Atan2}, - BinaryOpTestParam{[](half x, half y) { return x / y; }, - &XlaBuilder::Div}, - BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &XlaBuilder::Max}, - BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &XlaBuilder::Min}, - BinaryOpTestParam{[](half x, half y) { return x * y; 
}, - &XlaBuilder::Mul}, - BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &XlaBuilder::Pow}, - BinaryOpTestParam{[](half x, half y) { return x - y; }, - &XlaBuilder::Sub} + &Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} )); @@ -221,27 +211,22 @@ XLA_TEST_P(BinaryPredTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()}); } INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, - ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &XlaBuilder::Eq}, - BinaryPredTestParam{[](half x, half y) { return x != y; }, - &XlaBuilder::Ne}, - BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &XlaBuilder::Ge}, - BinaryPredTestParam{[](half x, half y) { return x > y; }, - &XlaBuilder::Gt}, - BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &XlaBuilder::Le}, - BinaryPredTestParam{[](half x, half y) { return x < y; }, - &XlaBuilder::Lt} - - )); + ::testing::Values( + BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} + + )); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 
6154ce671c..5c351b2d11 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -29,14 +29,14 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare( - bool lhs, bool rhs, bool expected, - XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, - tensorflow::gtl::ArraySlice<int64>)) { + void TestCompare(bool lhs, bool rhs, bool expected, + std::function<XlaOp(const xla::XlaOp&, const xla::XlaOp&, + tensorflow::gtl::ArraySlice<int64>)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0<bool>(&builder, lhs); XlaOp rhs_op = ConstantR0<bool>(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0<bool>(&builder, expected, {}); } }; @@ -54,27 +54,27 @@ TEST_F(PredTest, ConstantR0PredFalse) { } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &XlaBuilder::Eq); + TestCompare(true, false, false, &Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &XlaBuilder::Ne); + TestCompare(true, false, true, &Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &XlaBuilder::Le); + TestCompare(true, false, false, &Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &XlaBuilder::Lt); + TestCompare(true, false, false, &Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &XlaBuilder::Ge); + TestCompare(true, false, true, &Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &XlaBuilder::Gt); + TestCompare(true, false, true, &Gt); } TEST_F(PredTest, ConstantR1Pred) { diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index d0ebb108ae..bc994315c3 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ 
b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -44,25 +44,26 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template <typename NativeT> - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice<int64>)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function<XlaOp(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice<int64>)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs); XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0<bool>(&builder, expected, {}); } template <typename NativeT> void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice<int64>)) { + std::function<XlaOp(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice<int64>)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs); XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0<NativeT>(&builder, expected, {}); } }; @@ -583,117 +584,116 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. 
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare<int32>(2, 1, false, &XlaBuilder::Eq); + TestCompare<int32>(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare<int32>(3, 3, true, &XlaBuilder::Eq); + TestCompare<int32>(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare<int32>(2, 1, true, &XlaBuilder::Ne); + TestCompare<int32>(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare<int32>(2, 1, true, &XlaBuilder::Ge); + TestCompare<int32>(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare<int32>(1, 5, false, &XlaBuilder::Gt); + TestCompare<int32>(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare<int32>(2, 1, false, &XlaBuilder::Le); + TestCompare<int32>(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare<int32>(9, 7, false, &XlaBuilder::Lt); + TestCompare<int32>(9, 7, false, &Lt); TestCompare<int32>(std::numeric_limits<int32>::min(), - std::numeric_limits<int32>::max(), true, &XlaBuilder::Lt); + std::numeric_limits<int32>::max(), true, &Lt); } // U32 comparisons. 
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare<uint32>(2, 1, false, &XlaBuilder::Eq); + TestCompare<uint32>(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare<uint32>(2, 1, true, &XlaBuilder::Ne); + TestCompare<uint32>(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare<uint32>(2, 1, true, &XlaBuilder::Ge); + TestCompare<uint32>(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare<uint32>(3, 3, true, &XlaBuilder::Ge); + TestCompare<uint32>(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare<uint32>(1, 5, false, &XlaBuilder::Gt); - TestCompare<uint32>(5, 5, false, &XlaBuilder::Gt); - TestCompare<uint32>(5, 1, true, &XlaBuilder::Gt); + TestCompare<uint32>(1, 5, false, &Gt); + TestCompare<uint32>(5, 5, false, &Gt); + TestCompare<uint32>(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare<uint32>(2, 1, false, &XlaBuilder::Le); + TestCompare<uint32>(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare<uint32>(9, 7, false, &XlaBuilder::Lt); - TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, - &XlaBuilder::Lt); + TestCompare<uint32>(9, 7, false, &Lt); + TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt); } // F32 comparisons. 
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare<float>(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare<float>(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare<float>(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare<float>(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare<float>(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare<float>(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare<float>(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare<float>(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare<float>(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare<float>(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare<float>(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare<float>(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare<float>(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare<float>(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare<float>(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare<float>(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. 
- TestCompare<float>(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare<float>(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare<float>(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare<float>(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare<float>(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare<float>(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare<float>(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare<float>(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare<float>(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare<float>(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { @@ -813,65 +813,65 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax<int32>(10, 3, 3, &XlaBuilder::Min); + TestMinMax<int32>(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax<int32>(-100, 3, -100, &XlaBuilder::Min); + TestMinMax<int32>(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax<int32>(10, 3, 10, &XlaBuilder::Max); + TestMinMax<int32>(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax<int32>(-100, 3, 3, &XlaBuilder::Max); + TestMinMax<int32>(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits<int32>::max(); - TestMinMax<uint32>(large, 3, 3, &XlaBuilder::Min); + TestMinMax<uint32>(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax<uint32>(0, 5, 0, &XlaBuilder::Min); + TestMinMax<uint32>(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits<int32>::max(); - 
TestMinMax<uint32>(large, 3, large, &XlaBuilder::Max); + TestMinMax<uint32>(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax<uint32>(0, 5, 5, &XlaBuilder::Max); + TestMinMax<uint32>(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax<float>(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax<float>(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax<float>(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax<float>(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax<float>(NAN, 3.1f, NAN, &Min); + TestMinMax<float>(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax<float>(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax<float>(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax<float>(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax<float>(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax<float>(NAN, 3.1f, NAN, &Max); + TestMinMax<float>(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 000535a982..20c7c30878 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -161,6 +161,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( })); break; } + // Token requires no data. 
+ case TOKEN: + break; default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f2fb44d8..8f424ae81f 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { TF_ASSERT_OK(MakeFakeArguments(&module).status()); } +XLA_TEST_F(TestUtilsTest, Token) { + auto module = ParseHloString( + R"(HloModule outfeed_module + + ENTRY InfeedToOutfeed { + token = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) + })") + .ValueOrDie(); + TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 2d7916c8b1..229b0c481f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -25,6 +25,7 @@ py_library( "//tensorflow/contrib/all_reduce", "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", + "//tensorflow/contrib/bigtable", 
"//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", "//tensorflow/contrib/cloud:cloud_py", diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 8f09689fe9..a49a4ed05c 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cfg.py", "compiler.py", "inspect_utils.py", "parser.py", @@ -64,6 +65,17 @@ py_test( ) py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( name = "compiler_test", srcs = ["compiler_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py new file mode 100644 index 0000000000..666328781f --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -0,0 +1,733 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph (CFG) structure for Python AST representation. + +The CFG is a digraph with edges representing valid control flow. Each +node is associated with exactly one AST node, but not all AST nodes may have +a corresponding CFG counterpart. 
+ +Once built, the CFG itself is immutable, but the values it holds need not be; +they are usually annotated with information extracted by walking the graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.pyct import compiler + + +class Node(object): + """A node in the CFG. + + Although new instances of this class are mutable, the objects that a user + finds in the CFG are typically not. + + The nodes represent edges in the CFG graph, and maintain pointers to allow + efficient walking in both forward and reverse order. The following property + holds for all nodes: "child in node.next" iff "node in child.prev". + + Attributes: + next: FrozenSet[Node, ...], the nodes that follow this node, in control + flow order + prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse + control flow order + ast_node: ast.AST, the AST node corresponding to this CFG node + """ + + def __init__(self, next_, prev, ast_node): + self.next = next_ + self.prev = prev + self.ast_node = ast_node + + def freeze(self): + self.next = frozenset(self.next) + self.prev = frozenset(self.prev) + + def __repr__(self): + return compiler.ast_to_source(self.ast_node).strip() + + +class Graph( + collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])): + """A Control Flow Graph. + + The CFG maintains an index to allow looking up a CFG node by the AST node to + which it is associated. The index can also be enumerated in top-down, depth + first order. + + Walking the graph in forward or reverse order is supported by double + parent-child links. 
+ + Note: the error nodes are not wired to their corresponding finally guards, + because these are shared, and wiring them would create a reverse path from + normal control flow into the error nodes, which we want to avoid. + + Attributes: + entry: Node, the entry node + exit: FrozenSet[Node, ...], the exit nodes + error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised + error (errors propagated from function calls are not accounted) + index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG + node + """ + + def __repr__(self): + result = 'digraph CFG {\n' + for node in self.index.values(): + result += ' %s [label="%s"];\n' % (id(node), node) + for node in self.index.values(): + if node.next: + result += ' %s -> {%s};\n' % (id(node), ', '.join( + repr(id(n)) for n in node.next)) + result += '}' + return result + + +class _WalkMode(Enum): + FORWARD = 1 + REVERSE = 2 + + +class GraphVisitor(object): + """Base class for a CFG visitors. + + This implementation is not thread safe. + + The visitor has some facilities to simplify dataflow analyses. In particular, + it allows revisiting the nodes at the decision of the subclass. This can be + used to visit the graph until the state reaches a fixed point. + + For more details on dataflow analysis, see + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Note: the literature generally suggests visiting successor nodes only when the + state of the current node changed, regardless of whether that successor has + ever been visited. This implementation visits every successor at least once. 
+ + Attributes: + graph: Graph + in_: Dict[Node, Any], stores node-keyed state during a visit + out: Dict[Node, Any], stores node-keyed state during a visit + """ + + def reset(self): + self.in_ = { + node: self.init_state(node) for node in self.graph.index.values() + } + self.out = { + node: self.init_state(node) for node in self.graph.index.values() + } + + def init_state(self, node): + """State initialization function. Optional to overload. + + An in/out state slot will be created for each node in the graph. Subclasses + may overload this to control what that is initialized to. + + Args: + node: Node + """ + del node + return None + + def visit_node(self, node): + """Visitor function. + + Args: + node: Node + Returns: + bool, whether the node should be revisited; subclasses can visit every + reachable node exactly once by always returning False + """ + raise NotImplementedError('Subclasses must implement this.') + + def _visit_internal(self, mode): + """Visits the CFG, depth-first.""" + assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE) + if mode == _WalkMode.FORWARD: + open_ = [self.graph.entry] + elif mode == _WalkMode.REVERSE: + open_ = list(self.graph.exit) + closed = set() + self.reset() + + while open_: + node = open_.pop(0) + closed.add(node) + + should_revisit = self.visit_node(node) + + if mode == _WalkMode.FORWARD: + children = node.next + elif mode == _WalkMode.REVERSE: + children = node.prev + + for next_ in children: + if should_revisit or next_ not in closed: + open_.append(next_) + + def visit_forward(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.FORWARD) + + def visit_reverse(self, graph): + self.graph = graph + self._visit_internal(_WalkMode.REVERSE) + + +class GraphBuilder(object): + """Builder that constructs a CFG from a given AST. + + This GraphBuilder facilitates constructing the DAG that forms the CFG when + nodes + are supplied in lexical order (i.e., top-down, depth first). 
Under these + conditions, it supports building patterns found in typical structured + programs. + + This builder ignores the flow generated by exceptions, which are assumed to + always be catastrophic and present purely for diagnostic purposes (e.g. to + print debug information). Statements like raise and try/catch sections are + allowed and will generate control flow edges, but ordinaty statements are + assumed not to raise exceptions. + + Finally sections are also correctly interleaved between break/continue/return + nodes and their subsequent statements. + + Important concepts: + * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly + * leaf set - since the graph is constructed gradually, a leaf set maintains + the CFG nodes that will precede the node that the builder expects to + receive next; when an ordinary node is added, it is connected to the + existing leaves and it in turn becomes the new leaf + * jump nodes - nodes that should generate edges other than what + ordinary nodes would; these correspond to break, continue and return + statements + * sections - logical delimiters for subgraphs that require special + edges; there are various types of nodes, each admitting various + types of jump nodes; sections are identified by their corresponding AST + node + """ + + # TODO(mdan): Perhaps detail this in a markdown doc. + # TODO(mdan): Add exception support. + + def __init__(self, parent_ast_node): + self.reset() + self.parent = parent_ast_node + + def reset(self): + """Resets the state of this factory.""" + self.head = None + self.errors = set() + self.node_index = collections.OrderedDict() + + # TODO(mdan): Too many primitives. Use classes. + self.leaves = set() + + self.finally_sections = {} + self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes] + # Whether the guard section can be reached from the statement that precedes + # it. 
+ self.finally_section_has_direct_flow = {} + # Finally sections that await their first node. + self.pending_finally_sections = set() + + # Exit jumps keyed by the section they affect. + self.exits = {} + + # The entry of loop sections, keyed by the section. + self.section_entry = {} + # Continue jumps keyed by the section they affect. + self.continues = {} + + # The entry of conditional sections, keyed by the section. + self.cond_entry = {} + # Lists of leaf nodes corresponding to each branch in the section. + self.cond_leaves = {} + + def _connect_nodes(self, first, second): + """Connects nodes to signify that control flows from first to second. + + Args: + first: Union[Set[Node, ...], Node] + second: Node + """ + if isinstance(first, Node): + first.next.add(second) + second.prev.add(first) + else: + for node in first: + self._connect_nodes(node, second) + + def _add_new_node(self, ast_node): + """Grows the graph by adding a CFG node following the current leaves.""" + if ast_node is self.node_index: + raise ValueError('%s added twice' % ast_node) + node = Node(next_=set(), prev=set(), ast_node=ast_node) + self.node_index[ast_node] = node + + if self.head is None: + self.head = node + + for leaf in self.leaves: + self._connect_nodes(leaf, node) + + # If any finally section awaits its first node, populate it. + for section_id in self.pending_finally_sections: + self.finally_section_subgraphs[section_id][0] = node + self.pending_finally_sections = set() + + return node + + def add_ordinary_node(self, ast_node): + """Grows the graph by adding an ordinary CFG node. + + Ordinary nodes are followed by the next node, in lexical order, that is, + they become the new leaf set. + + Args: + ast_node: ast.AST + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set((node,)) + return node + + def _add_jump_node(self, ast_node, guards): + """Grows the graph by adding a jump node. 
+ + Jump nodes are added to the current leaf set, and the leaf set becomes + empty. If the jump node is the last in a cond section, then it may be added + back to the leaf set by a separate mechanism. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections active for this node + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set() + # The guards themselves may not yet be complete, and will be wired later. + self.finally_sections[node] = guards + return node + + def _connect_jump_to_finally_sections(self, node): + """Connects a jump node to the finally sections protecting it.""" + cursor = set((node,)) + for guard_section_id in self.finally_sections[node]: + guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] + self._connect_nodes(cursor, guard_begin) + cursor = guard_ends + del self.finally_sections[node] + # TODO(mdan): Should garbage-collect finally_section_subgraphs. + return cursor + + def add_exit_node(self, ast_node, section_id, guards): + """Grows the graph by adding an exit node. + + This node becomes an exit for the current section. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.exits[section_id].add(node) + + def add_continue_node(self, ast_node, section_id, guards): + """Grows the graph by adding a reentry node. + + This node causes control flow to go back to the loop section's entry. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.continues[section_id].add(node) + + def add_error_node(self, ast_node, guards): + """Grows the graph by adding an error node. 
+ + This node becomes an exit for the entire graph. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.errors.add(node) + self.leaves = set() + + def enter_section(self, section_id): + """Enters a regular section. + + Regular sections admit exit jumps, which end the section. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_exit_node + """ + assert section_id not in self.exits + self.exits[section_id] = set() + + def exit_section(self, section_id): + """Exits a regular section.""" + + # Exits are jump nodes, which may be protected. + for exit_ in self.exits[section_id]: + self.leaves |= self._connect_jump_to_finally_sections(exit_) + + del self.exits[section_id] + + def enter_loop_section(self, section_id, entry_node): + """Enters a loop section. + + Loop sections define an entry node. The end of the section always flows back + to the entry node. These admit continue jump nodes which also flow to the + entry node. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_continue_node + entry_node: ast.AST, the entry node into the loop (e.g. the test node + for while loops) + """ + assert section_id not in self.section_entry + assert section_id not in self.continues + self.continues[section_id] = set() + node = self.add_ordinary_node(entry_node) + self.section_entry[section_id] = node + + def exit_loop_section(self, section_id): + """Exits a loop section.""" + self._connect_nodes(self.leaves, self.section_entry[section_id]) + + # continues are jump nodes, which may be protected. + for reentry in self.continues[section_id]: + guard_ends = self._connect_jump_to_finally_sections(reentry) + self._connect_nodes(guard_ends, self.section_entry[section_id]) + + # Loop nodes always loop back. 
+ self.leaves = set((self.section_entry[section_id],)) + + del self.continues[section_id] + del self.section_entry[section_id] + + def enter_cond_section(self, section_id): + """Enters a conditional section. + + Conditional sections define an entry node, and one or more branches. + + Args: + section_id: Hashable, the same node that will be used in calls to the + section_id arg passed to new_cond_branch + """ + + assert section_id not in self.cond_entry + assert section_id not in self.cond_leaves + self.cond_leaves[section_id] = [] + + def new_cond_branch(self, section_id): + """Begins a new branch in a cond section.""" + assert section_id in self.cond_leaves + + if section_id in self.cond_entry: + # Subsequent splits move back to the split point, and memorize the + # current leaves. + self.cond_leaves[section_id].append(self.leaves) + self.leaves = self.cond_entry[section_id] + else: + # If this is the first time we split a section, just remember the split + # point. + self.cond_entry[section_id] = self.leaves + + def exit_cond_section(self, section_id): + """Exits a conditional section.""" + for split in self.cond_leaves[section_id]: + self.leaves |= split + del self.cond_entry[section_id] + del self.cond_leaves[section_id] + + def enter_finally_section(self, section_id): + """Enters a finally section.""" + # TODO(mdan): This, not the caller, should track the active sections. + self.finally_section_subgraphs[section_id] = [None, None] + if self.leaves: + self.finally_section_has_direct_flow[section_id] = True + else: + self.finally_section_has_direct_flow[section_id] = False + self.pending_finally_sections.add(section_id) + + def exit_finally_section(self, section_id): + """Exits a finally section.""" + assert section_id not in self.pending_finally_sections, 'Empty finally?' + self.finally_section_subgraphs[section_id][1] = self.leaves + # If the guard can only be reached by a jump, then it will not flow + # into the statement that follows it. 
+ if not self.finally_section_has_direct_flow[section_id]: + self.leaves = set() + del self.finally_section_has_direct_flow[section_id] + + def build(self): + """Returns the CFG accumulated so far and resets the builder. + + Returns: + Graph + """ + # Freeze the nodes. + for node in self.node_index.values(): + node.freeze() + + result = Graph( + entry=self.head, + exit=self.leaves, + error=self.errors, + index=self.node_index) + + # Reset the state. + self.reset() + + return result + + +class AstToCfg(gast.NodeVisitor): + """Converts an AST to CFGs. + + A separate CFG will be constructed for each function. + """ + + # TODO(mdan): Figure out how to deal with closures. + + def __init__(self): + super(AstToCfg, self).__init__() + + self.builder_stack = [] + self.builder = None + self.cfgs = {} + + self.lexical_scopes = [] + + def _enter_lexical_scope(self, node): + self.lexical_scopes.append(node) + + def _exit_lexical_scope(self, node): + leaving_node = self.lexical_scopes.pop() + assert node == leaving_node + + def _get_enclosing_scopes(self, include, stop_at): + included = [] + for node in reversed(self.lexical_scopes): + if isinstance(node, include): + included.append(node) + if isinstance(node, stop_at): + return node, included + return None, included + + def _process_basic_statement(self, node): + self.generic_visit(node) + self.builder.add_ordinary_node(node) + + def _process_exit_statement(self, node, *exits_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(exits_nodes_of_type), + ) + if try_node is None: + raise ValueError( + '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) + self.builder.add_exit_node(node, try_node, guards) + + def _process_continue_statement(self, node, *loops_to_nodes_of_type): + # Note: this is safe because we process functions separately. 
+ try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(loops_to_nodes_of_type), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any of %s' % + (node, loops_to_nodes_of_type)) + self.builder.add_continue_node(node, try_node, guards) + + def visit_FunctionDef(self, node): + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + + self._enter_lexical_scope(node) + self.builder.enter_section(node) + + self._process_basic_statement(node.args) + for stmt in node.body: + self.visit(stmt) + + self.builder.exit_section(node) + self._exit_lexical_scope(node) + + self.cfgs[node] = self.builder.build() + self.builder = self.builder_stack.pop() + + def visit_Lambda(self, node): + # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. + raise NotImplementedError() + + def visit_Return(self, node): + self._process_exit_statement(node, gast.FunctionDef) + + def visit_Expr(self, node): + self._process_basic_statement(node) + + def visit_Assign(self, node): + self._process_basic_statement(node) + + def visit_AnnAssign(self, node): + self._process_basic_statement(node) + + def visit_AugAssign(self, node): + self._process_basic_statement(node) + + def visit_Print(self, node): + self._process_basic_statement(node) + + def visit_Raise(self, node): + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=(gast.FunctionDef,), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any FunctionDef' % node) + self.builder.add_error_node(node, try_node, guards) + + def visit_Assert(self, node): + # Ignoring the effect of exceptions. + self._process_basic_statement(node) + + def visit_Delete(self, node): + self._process_basic_statement(node) + + def visit_If(self, node): + # No need to track ifs as lexical scopes, for now. + # Lexical scopes are generally tracked in order to be able to resolve the + # targets of jump statements like break/continue/etc. 
Since there is no + # statement that can interrupt a conditional, we don't need to track their + # lexical scope. That may change in the future. + + self.builder.enter_cond_section(node) + self._process_basic_statement(node.test) + + self.builder.new_cond_branch(node) + for stmt in node.body: + self.visit(stmt) + + self.builder.new_cond_branch(node) + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_cond_section(node) + + def visit_While(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + self.builder.enter_loop_section(node, node.test) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # the statements inside it don't affect the loop itself. For example, a + # break in the loop's orelse will not affect the loop itself. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_For(self, node): + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + # TODO(mdan): Strictly speaking, this should be node.target + node.iter. + # A blind dataflow analysis would have to process both node.target and + # node.iter to properly process read and write access. + self.builder.enter_loop_section(node, node.iter) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # they don't count as loop bodies. For example, a break in the loop's + # orelse will affect the parent loop, not the current one. 
+ self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + + def visit_Break(self, node): + self._process_exit_statement(node, gast.While, gast.For) + + def visit_Continue(self, node): + self._process_continue_statement(node, gast.While, gast.For) + + def visit_Try(self, node): + self._enter_lexical_scope(node) + + for stmt in node.body: + self.visit(stmt) + # Unlike loops, the orelse is a simple continuation of the body. + for stmt in node.orelse: + self.visit(stmt) + + if node.handlers: + # TODO(mdan): Should we still support bare try/except? Might be confusing. + raise NotImplementedError('exceptions are not yet supported') + + self._exit_lexical_scope(node) + + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) + + def visit_With(self, node): + # TODO(mdan): Mark the context manager's exit call as exit guard. + self._process_basic_statement(node.items) + for stmt in node.body: + self.visit(stmt) + + +def build(node): + builder = AstToCfg() + builder.visit(node) + return builder.cfgs diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py new file mode 100644 index 0000000000..00afadd521 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg_test.py @@ -0,0 +1,790 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class CountingVisitor(cfg.GraphVisitor): + + def __init__(self): + self.counts = {} + + def visit_node(self, node): + self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1 + return False # visit only once + + +class GraphVisitorTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs, node + + def test_basic_coverage_forward(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_forward(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + # The return node should be unreachable in forward direction. 
+ self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + def test_basic_coverage_reverse(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor() + visitor.visit_reverse(graph) + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + +class AstToCfgTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs + + def _repr_set(self, node_set): + return set(repr(n) for n in node_set) + + def _as_set(self, elements): + if elements is None: + return frozenset() + elif isinstance(elements, str): + return frozenset((elements,)) + else: + return frozenset(elements) + + def assertGraphMatches(self, graph, edges): + """Tests whether the CFG contains the specified edges.""" + for prev, node_repr, next_ in edges: + matched = False + for cfg_node in graph.index.values(): + if repr(cfg_node) == node_repr: + if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and + self._as_set(next_) == set(map(repr, cfg_node.next))): + matched = True + break + if not matched: + self.fail( + 'match failed for node "%s" in graph:\n%s' % (node_repr, graph)) + + def test_straightline(self): + + def test_fn(a): + a += 1 + a = 2 + a = 3 + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'a += 1'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', 'return'), + ('a = 3', 'return', None), + ), + ) + + def 
test_straightline_no_return(self): + + def test_fn(a, b): + a = b + 1 + a += max(a) + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a, b', 'a = b + 1'), + ('a = b + 1', 'a += max(a)', None), + ), + ) + + def test_unreachable_code(self): + + def test_fn(a): + return + a += 1 # pylint:disable=unreachable + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'return'), + ('a', 'return', None), + (None, 'a += 1', None), + ), + ) + + def test_branch_straightline(self): + + def test_fn(a): + if a > 0: + a = 1 + else: + a += -1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('(a > 0)', 'a = 1', None), + ('(a > 0)', 'a += -1', None), + ), + ) + + def test_branch_nested(self): + + def test_fn(a): + if a > 0: + if a > 1: + a = 1 + else: + a = 2 + else: + if a > 2: + a = 3 + else: + a = 4 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', ('(a > 1)', '(a > 2)')), + ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', None), + ('(a > 1)', 'a = 2', None), + ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')), + ('(a > 2)', 'a = 3', None), + ('(a > 2)', 'a = 4', None), + ), + ) + + def test_branch_straightline_semi(self): + + def test_fn(a): + if a > 0: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', 'a = 1'), + ('(a > 0)', 'a = 1', None), + ), + ) + + def test_branch_return(self): + + def test_fn(a): + if a > 0: + return + else: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', ('return', 'a = 1')), + ('(a > 0)', 'a = 1', 'a = 2'), + ('(a > 0)', 'return', None), + ('a = 1', 'a = 2', None), + ), + ) + + def test_branch_return_minimal(self): + + def test_fn(a): 
+ if a > 0: + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'return'), + ('(a > 0)', 'return', None), + ), + ) + + def test_while_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', None), + ), + ) + + def test_while_else_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_continue(self): + + def test_fn(a): + while a > 0: + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', '(a > 0)'), + ('a = 0', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_else_break(self): + + def test_fn(a): + while a > 0: + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_while_else_return(self): + + def test_fn(a): + while a > 0: + if a > 1: + return + a = 1 + else: # 
pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_while_nested_straightline(self): + + def test_fn(a): + while a > 0: + while a > 1: + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_continue(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 3: + continue + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')), + ('(a > 1)', '(a > 3)', ('continue', 'a = 1')), + ('(a > 3)', 'continue', '(a > 1)'), + ('(a > 3)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_while_nested_break(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 2: + break + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')), + ('(a > 1)', '(a > 2)', ('break', 'a = 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'a = 1', '(a > 1)'), + (('(a > 1)', 'break'), 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + + def test_for_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + a = 2 + + graph, = 
self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', None), + ), + ) + + def test_for_else_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_continue(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', 'range(0, a)'), + ('(a > 1)', 'a = 0', 'a = 1'), + ('a = 0', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_else_break(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + + def test_for_else_return(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', 
('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_for_nested_straightline(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')), + ('range(1, a)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_continue(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 3: + continue + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)', + ('(a > 3)', 'a = 2')), + ('range(1, a)', '(a > 3)', ('continue', 'b += 1')), + ('(a > 3)', 'continue', 'range(1, a)'), + ('(a > 3)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_for_nested_break(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 2: + break + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')), + ('range(1, a)', '(a > 2)', ('break', 'b += 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'b += 1', 'range(1, a)'), + (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + + def test_complex(self): + + def test_fn(a): + b = 0 + while a > 0: + for b in range(0, a): + 
if a > 2: + break + if a > 3: + if a > 4: + continue + else: + max(a) + break + b += 1 + else: # for b in range(0, a): + return a + a = 2 + for a in range(1, a): + return b + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')), + ( + ('(a > 0)', 'continue', 'b += 1'), + 'range(0, a)', + ('(a > 2)', 'return a'), + ), + ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')), + ('(a > 3)', '(a > 4)', ('continue', 'max(a)')), + ('(a > 4)', 'max(a)', 'break'), + ('max(a)', 'break', 'a = 2'), + ('(a > 4)', 'continue', 'range(0, a)'), + ('(a > 3)', 'b += 1', 'range(0, a)'), + ('range(0, a)', 'return a', None), + ('break', 'a = 2', '(a > 0)'), + ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')), + ('range(1, a)', 'return b', None), + ('range(1, a)', 'a = 3', None), + ), + ) + + def test_finally_straightline(self): + + def test_fn(a): + try: + a += 1 + finally: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a += 1', 'a = 2'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_return_finally(self): + + def test_fn(a): + try: + return a + finally: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'return a', 'a = 1'), + ('return a', 'a = 1', None), + (None, 'a = 2', None), + ), + ) + + def test_break_finally(self): + + def test_fn(a): + while a > 0: + try: + break + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'break'), + ('(a > 0)', 'break', 'a = 1'), + ('break', 'a = 1', None), + ), + ) + + def test_continue_finally(self): + + def test_fn(a): + while a > 0: + try: + continue + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + 
self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', 'continue'), + ('(a > 0)', 'continue', 'a = 1'), + ('continue', 'a = 1', '(a > 0)'), + ), + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index 39eca6e444..4acc4ed66a 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -286,7 +286,7 @@ class Forward(object): # TODO(alexbw): see if we can simplify by visiting breadth-first def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagation.""" + """Depth-first walking the CFG, applying dataflow info propagation.""" # node.value is None only for the exit CfgNode. if not node.value: return diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 032b859d46..68ead2f760 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -192,7 +192,7 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). + r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). 
This function computes the Monte-Carlo approximation of an expectation, i.e., diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD new file mode 100644 index 0000000000..5c15d21e35 --- /dev/null +++ b/tensorflow/contrib/bigtable/BUILD @@ -0,0 +1,196 @@ +# Cloud Bigtable client for TensorFlow + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_cc_test", + "tf_py_test", +) + +tf_custom_op_py_library( + name = "bigtable", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_bigtable.so", + ], + kernels = [ + ":bigtable_kernels", + ":bigtable_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +tf_custom_op_library( + name = "python/ops/_bigtable.so", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", + "ops/bigtable_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_ops", + deps = [":bigtable_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "bigtable_ops", + "bigtable_test_ops", + ], +) + +tf_kernel_library( + name = "bigtable_kernels", + srcs = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", 
+ "kernels/bigtable_scan_dataset_op.cc", + ], + deps = [ + ":bigtable_lib_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +# A library for use in the bigtable kernels. +cc_library( + name = "bigtable_lib_cc", + srcs = ["kernels/bigtable_lib.cc"], + hdrs = ["kernels/bigtable_lib.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +cc_library( + name = "bigtable_test_client", + srcs = ["kernels/test_kernels/bigtable_test_client.cc"], + hdrs = ["kernels/test_kernels/bigtable_test_client.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "@com_github_googleapis_googleapis//:bigtable_protos", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + "@com_googlesource_code_re2//:re2", + ], +) + +tf_cc_test( + name = "bigtable_test_client_test", + srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"], + tags = ["manual"], + deps = [ + ":bigtable_test_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_test_ops", + deps = [":bigtable_test_ops_op_lib"], +) + +tf_custom_op_library( + name = "python/kernel_tests/_bigtable_test.so", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + "ops/bigtable_test_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "@com_googlesource_code_re2//:re2", + ], +) + +# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h +cc_library( + name = "bigtable_test_kernels", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + 
":bigtable_lib_cc", + ":bigtable_test_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_googlesource_code_re2//:re2", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "bigtable_test_py", + dso = [ + ":python/kernel_tests/_bigtable_test.so", + ], + kernels = [ + ":bigtable_test_kernels", + ":bigtable_test_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_test_ops", + # "//tensorflow/contrib/util:util_py", + # "//tensorflow/python:framework_for_generated_wrappers", + # "//tensorflow/python:platform", + # "//tensorflow/python:util", + # "//tensorflow/python/data", + ], +) + +tf_py_test( + name = "bigtable_ops_test", + size = "small", + srcs = ["python/kernel_tests/bigtable_ops_test.py"], + additional_deps = [ + ":bigtable", + ":bigtable_test_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md new file mode 100644 index 0000000000..ef3c60069e --- /dev/null +++ b/tensorflow/contrib/bigtable/README.md @@ -0,0 +1,10 @@ +# Bigtable # + +[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + +<!-- TODO(saeta): Document usage / methods / etc. 
--> diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py new file mode 100644 index 0000000000..7df054637c --- /dev/null +++ b/tensorflow/contrib/bigtable/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud Bigtable Client for TensorFlow. + +This contrib package allows TensorFlow to interface directly with Cloud Bigtable +for high-speed data loading. + +@@BigtableClient +@@BigTable + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'BigTable', + 'BigtableClient', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc new file mode 100644 index 0000000000..0c81951d56 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -0,0 +1,313 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +namespace { + +class BigtableClientOp : public OpKernel { + public: + explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); + OP_REQUIRES(ctx, !project_id_.empty(), + errors::InvalidArgument("project_id must be non-empty")); + OP_REQUIRES(ctx, !instance_id_.empty(), + errors::InvalidArgument("instance_id must be non-empty")); + } + + ~BigtableClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<BigtableClientResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. 
+ } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate<BigtableClientResource>( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr<bigtable::DataClient> client = + bigtable::CreateDefaultDataClient( + project_id_, instance_id_, + bigtable::ClientOptions()); + *ret = new BigtableClientResource( + project_id_, instance_id_, std::move(client)); + return Status::OK(); + })); + core::ScopedUnref resource_cleanup(resource); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<BigtableClientResource>())); + } + + private: + string project_id_; + string instance_id_; + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), + BigtableClientOp); + +class BigtableTableOp : public OpKernel { + public: + explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); + OP_REQUIRES(ctx, !table_.empty(), + errors::InvalidArgument("table_name must be non-empty")); + } + + ~BigtableTableOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<BigtableTableResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. 
+ } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + + BigtableClientResource* client_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); + core::ScopedUnref unref_client(client_resource); + + BigtableTableResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate<BigtableTableResource>( + cinfo_.container(), cinfo_.name(), &resource, + [this, client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource(client_resource, table_); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<BigtableTableResource>())); + } + + private: + string table_; // Note: this is const after construction. + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), + BigtableTableOp); + +class ToBigtableOp : public AsyncOpKernel { + public: + explicit ToBigtableOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), + /* num_threads = */ 1, /* low_latency_hint = */ false)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. 
+ thread_pool_->Schedule([this, ctx, done]() { + const Tensor* column_families_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("column_families", &column_families_tensor), done); + OP_REQUIRES_ASYNC( + ctx, column_families_tensor->dims() == 1, + errors::InvalidArgument("`column_families` must be a vector."), done); + + const Tensor* columns_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); + OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, + errors::InvalidArgument("`columns` must be a vector."), + done); + OP_REQUIRES_ASYNC( + ctx, + columns_tensor->NumElements() == + column_families_tensor->NumElements(), + errors::InvalidArgument("len(column_families) != len(columns)"), + done); + + std::vector<string> column_families; + column_families.reserve(column_families_tensor->NumElements()); + std::vector<string> columns; + columns.reserve(column_families_tensor->NumElements()); + for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { + column_families.push_back(column_families_tensor->flat<string>()(i)); + columns.push_back(columns_tensor->flat<string>()(i)); + } + + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr<IteratorBase> iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + done); + + int64 timestamp_int; + OP_REQUIRES_OK_ASYNC( + ctx, ParseScalarArgument<int64>(ctx, "timestamp", &timestamp_int), + done); + OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, + errors::InvalidArgument("timestamp must be >= -1"), + done); + std::chrono::milliseconds timestamp(timestamp_int); + + BigtableTableResource* resource; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); + core::ScopedUnref resource_cleanup(resource); + + std::vector<Tensor> components; +
components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence = false; + do { + ::bigtable::BulkMutation mutation; + // TODO(saeta): Make # of mutations configurable. + for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + if (!end_of_sequence) { + OP_REQUIRES_OK_ASYNC( + ctx, + CreateMutation(std::move(components), column_families, columns, + timestamp, &mutation), + done); + } + components.clear(); + } + grpc::Status mutation_status; + std::vector<::bigtable::FailedMutation> failures = + resource->table().BulkApply(std::move(mutation), mutation_status); + if (!failures.empty()) { + for (const auto& failure : failures) { + LOG(ERROR) << "Failure applying mutation on row (" + << failure.original_index() + << "): " << failure.mutation().row_key() + << " - error: " << failure.status().error_message() + << " (Details: " << failure.status().error_details() + << ")."; + } + } + OP_REQUIRES_ASYNC( + ctx, failures.empty() && mutation_status.ok(), + errors::Unknown("Failure while writing to BigTable: ", + mutation_status.error_code(), " - ", + mutation_status.error_message(), " (", + mutation_status.error_details(), + "), # of mutation failures: ", failures.size(), + ". 
See the log for the specific error details."), + done); + } while (!end_of_sequence); + done(); + }); + } + + private: + static string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; + } + + Status CreateMutation(std::vector<Tensor> tensors, + const std::vector<string>& column_families, + const std::vector<string>& columns, + std::chrono::milliseconds timestamp, + ::bigtable::BulkMutation* bulk_mutation) { + if (tensors.size() != column_families.size() + 1) { + return errors::InvalidArgument( + "Iterator produced a set of Tensors shorter than expected"); + } + ::bigtable::SingleRowMutation mutation( + std::move(tensors[0].scalar<string>()())); + for (size_t i = 1; i < tensors.size(); ++i) { + if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { + return errors::Internal("Output tensor ", i, " was not a scalar"); + } + mutation.emplace_back( + ::bigtable::SetCell(column_families[i - 1], columns[i - 1], timestamp, + std::move(tensors[i].scalar<string>()()))); + } + bulk_mutation->emplace_back(std::move(mutation)); + return Status::OK(); + } + + template <typename T> + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar<T>()(); + return Status::OK(); + } + + std::unique_ptr<thread::ThreadPool> thread_pool_; +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), + ToBigtableOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc 
b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc new file mode 100644 index 0000000000..2514575f30 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status) { + if (status.ok()) { + return Status::OK(); + } + auto grpc_code = status.error_code(); + if (status.error_code() == ::grpc::StatusCode::ABORTED || + status.error_code() == ::grpc::StatusCode::UNAVAILABLE || + status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { + grpc_code = ::grpc::StatusCode::INTERNAL; + } + return Status( + static_cast<::tensorflow::error::Code>(status.error_code()), + strings::StrCat("Error reading from BigTable: ", status.error_message(), + " (Details: ", status.error_details(), ")")); +} + +string RegexFromStringSet(const std::vector<string>& strs) { + CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; + std::unordered_set<string> uniq(strs.begin(), strs.end()); + if (uniq.size() == 1) { + return *uniq.begin(); + } + return str_util::Join(uniq, "|"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h new 
file mode 100644 index 0000000000..54303cdc5e --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ + +// Note: we use bigtable/client/internal/table.h as this is the no-exception API + +#include "google/cloud/bigtable/data_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status); + +string RegexFromStringSet(const std::vector<string>& strs); + +class BigtableClientResource : public ResourceBase { + public: + BigtableClientResource(string project_id, string instance_id, + std::shared_ptr<bigtable::DataClient> client) + : project_id_(std::move(project_id)), + instance_id_(std::move(instance_id)), + client_(std::move(client)) {} + + std::shared_ptr<bigtable::DataClient> get_client() { return client_; } + + string DebugString() override { + return strings::StrCat("BigtableClientResource(project_id: ", project_id_, + ", instance_id: ", instance_id_, ")"); + } + + private: + const string project_id_; + const string instance_id_; + 
std::shared_ptr<bigtable::DataClient> client_; +}; + +class BigtableTableResource : public ResourceBase { + public: + BigtableTableResource(BigtableClientResource* client, string table_name) + : client_(client), + table_name_(std::move(table_name)), + table_(client->get_client(), table_name_) { + client_->Ref(); + } + + ~BigtableTableResource() override { client_->Unref(); } + + ::bigtable::noex::Table& table() { return table_; } + + string DebugString() override { + return strings::StrCat( + "BigtableTableResource(client: ", client_->DebugString(), + ", table: ", table_name_, ")"); + } + + private: + BigtableClientResource* client_; // Owns one ref. + const string table_name_; + ::bigtable::noex::Table table_; +}; + +// BigtableReaderDatasetIterator is an abstract class for iterators from +// datasets that are "readers" (source datasets, not transformation datasets) +// that read from Bigtable. +template <typename Dataset> +class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> { + public: + explicit BigtableReaderDatasetIterator( + const typename DatasetIterator<Dataset>::Params& params) + : DatasetIterator<Dataset>(params), iterator_(nullptr, false) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); + if (iterator_ == reader_->end()) { + grpc::Status status = reader_->Finish(); + if (status.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + return GrpcStatusToTfStatus(status); + } + *end_of_sequence = false; + bigtable::Row& row = *iterator_; + Status s = ParseRow(ctx, row, out_tensors); + // Ensure we always advance.
+ ++iterator_; + return s; + } + + protected: + virtual ::bigtable::RowRange MakeRowRange() = 0; + virtual ::bigtable::Filter MakeFilter() = 0; + virtual Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector<Tensor>* out_tensors) = 0; + + private: + Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (reader_) { + return Status::OK(); + } + + auto rows = MakeRowRange(); + auto filter = MakeFilter(); + + // Note: the this in `this->dataset()` below is necessary due to namespace + // name conflicts. + reader_.reset(new ::bigtable::RowReader( + this->dataset()->table()->table().ReadRows(rows, filter))); + iterator_ = reader_->begin(); + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<::bigtable::RowReader> reader_ GUARDED_BY(mu_); + ::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc new file mode 100644 index 0000000000..4b6d55a2d3 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -0,0 +1,220 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { + public: + using UnaryDatasetOpKernel::UnaryDatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + BigtableTableResource* table; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + + std::vector<string> column_families; + std::vector<string> columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + + const uint64 num_outputs = columns.size() + 1; + std::vector<PartialTensorShape> output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = + new Dataset(ctx, input, table, std::move(column_families), + std::move(columns), output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + BigtableTableResource* table, + std::vector<string> column_families, + std::vector<string> columns, + const DataTypeVector& output_types, + std::vector<PartialTensorShape> output_shapes) + : GraphDatasetBase(ctx), + input_(input), + table_(table), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + output_types_(output_types), + output_shapes_(std::move(output_shapes)), + filter_(MakeFilter(column_families_, 
columns_)) { + table_->Ref(); + input_->Ref(); + } + + ~Dataset() override { + table_->Unref(); + input_->Unref(); + } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableLookupDatasetOp::Dataset"; + } + + private: + static ::bigtable::Filter MakeFilter( + const std::vector<string>& column_families, + const std::vector<string>& columns) { + string column_family_regex = RegexFromStringSet(column_families); + string column_regex = RegexFromStringSet(columns); + + return ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), + ::bigtable::Filter::FamilyRegex(column_family_regex), + ::bigtable::Filter::ColumnRegex(column_regex)); + } + + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // Sequence requests. + std::vector<Tensor> input_tensors; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + if (input_tensors.size() != 1) { + return errors::InvalidArgument( + "Upstream iterator (", dataset()->input_->DebugString(), + ") did not produce a single `tf.string` `tf.Tensor`. 
It " + "produced ", + input_tensors.size(), " tensors."); + } + if (input_tensors[0].NumElements() == 0) { + return errors::InvalidArgument("Upstream iterator (", + dataset()->input_->DebugString(), + ") return an empty set of keys."); + } + if (input_tensors[0].NumElements() == 1) { + // Single key lookup. + ::grpc::Status status; + auto pair = dataset()->table_->table().ReadRow( + input_tensors[0].scalar<string>()(), dataset()->filter_, status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + if (!pair.first) { + return errors::DataLoss("Row key '", + input_tensors[0].scalar<string>()(), + "' not found."); + } + TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + } else { + // Batched get. + return errors::Unimplemented( + "BigtableLookupDataset doesn't yet support batched retrieval."); + } + return Status::OK(); + } + + private: + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector<Tensor>* out_tensors) { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar<string>()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar<string>()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::DataLoss("Column ", dataset()->column_families_[i], + 
":", dataset()->columns_[i], + " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + BigtableTableResource* table_; + const std::vector<string> column_families_; + const std::vector<string> columns_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + const ::bigtable::Filter filter_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), + BigtableLookupDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000..3d5c3cfdaa --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator<Dataset>(params) {} + + ::bigtable::RowRange MakeRowRange() override { + return 
::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::bigtable::Filter MakeFilter() override { + return ::bigtable::Filter::Chain( + ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector<Tensor>* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar<string>()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc new file mode 100644 index 0000000000..7fa06052c5 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableRangeKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = + new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string start_key, string end_key) + : GraphDatasetBase(ctx), + table_(table), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator<Dataset> { + public: + 
explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator<Dataset>(params) {} + + ::bigtable::RowRange MakeRowRange() override { + return ::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + ::bigtable::Filter MakeFilter() override { + return ::bigtable::Filter::Chain( + ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector<Tensor>* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar<string>()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string start_key_; + const string end_key_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), + BigtableRangeKeyDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc new file mode 100644 index 0000000000..11b9bd2bdc --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -0,0 +1,214 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableScanDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix)); + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key)); + + OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), + errors::InvalidArgument( + "Either prefix or start_key must be specified")); + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + std::vector<string> column_families; + std::vector<string> columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + OP_REQUIRES(ctx, !column_families.empty(), + errors::InvalidArgument("`column_families` is empty")); + + float probability = 0; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<float>(ctx, "probability", &probability)); + OP_REQUIRES( + ctx, probability > 0 && probability <= 1, + errors::InvalidArgument( + "Probability outside the range of (0, 1]. 
Got: ", probability)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + const uint64 num_outputs = columns.size() + 1; + std::vector<PartialTensorShape> output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key), + std::move(column_families), std::move(columns), + probability, output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key, + std::vector<string> column_families, + std::vector<string> columns, float probability, + const DataTypeVector& output_types, + std::vector<PartialTensorShape> output_shapes) + : GraphDatasetBase(ctx), + table_(table), + prefix_(std::move(prefix)), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + column_family_regex_(RegexFromStringSet(column_families_)), + column_regex_(RegexFromStringSet(columns_)), + probability_(probability), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + 
return "BigtableScanDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator<Dataset>(params) {} + + ::bigtable::RowRange MakeRowRange() override { + if (!dataset()->prefix_.empty()) { + DCHECK(dataset()->start_key_.empty()); + return ::bigtable::RowRange::Prefix(dataset()->prefix_); + } else { + DCHECK(!dataset()->start_key_.empty()) + << "Both prefix and start_key were empty!"; + return ::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + } + ::bigtable::Filter MakeFilter() override { + // TODO(saeta): Investigate optimal ordering here. + return ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), + ::bigtable::Filter::FamilyRegex(dataset()->column_family_regex_), + ::bigtable::Filter::ColumnRegex(dataset()->column_regex_), + dataset()->probability_ != 1.0 + ? 
::bigtable::Filter::RowSample(dataset()->probability_) + : ::bigtable::Filter::PassAllFilter()); + } + Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row, + std::vector<Tensor>* out_tensors) override { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar<string>()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar<string>()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::InvalidArgument( + "Column ", dataset()->column_families_[i], ":", + dataset()->columns_[i], " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + }; + + BigtableTableResource* table_; + const string prefix_; + const string start_key_; + const string end_key_; + const std::vector<string> column_families_; + const std::vector<string> columns_; + const string column_family_regex_; + const string column_regex_; + const float probability_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), + BigtableScanDatasetOp); + +} // namespace +} // namespace tensorflow diff --git 
a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc new file mode 100644 index 0000000000..0f107f169c --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -0,0 +1,367 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" + +#include "google/bigtable/v2/data.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "re2/re2.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/ptr_util.h" +// #include "util/task/codes.pb.h" + +namespace tensorflow { +namespace { + +void UpdateRow(const ::google::bigtable::v2::Mutation& mut, + std::map<string, string>* row) { + if (mut.has_set_cell()) { + auto col = + strings::Printf("%s:%s", mut.set_cell().family_name().c_str(), + string(mut.set_cell().column_qualifier()).c_str()); + (*row)[col] = string(mut.set_cell().value()); + } else if (mut.has_delete_from_column()) { + auto col = strings::Printf( + "%s:%s", mut.delete_from_column().family_name().c_str(), + string(mut.delete_from_column().column_qualifier()).c_str()); + row->erase(col); + } else if (mut.has_delete_from_family()) { + auto itr = row->lower_bound(mut.delete_from_family().family_name()); + 
auto prefix = + strings::Printf("%s:", mut.delete_from_family().family_name().c_str()); + while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) { + row->erase(itr); + } + } else if (mut.has_delete_from_row()) { + row->clear(); + } else { + LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString(); + } +} + +} // namespace + +class SampleRowKeysResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::SampleRowKeysResponse> { + public: + explicit SampleRowKeysResponse(BigtableTestClient* client) + : client_(client) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::SampleRowKeysResponse(); + resp->set_row_key(client_->table_.rows.begin()->first); + resp->set_offset_bytes(0); + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. +}; + +class ReadRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::ReadRowsResponse> { + public: + ReadRowsResponse(BigtableTestClient* client, + google::bigtable::v2::ReadRowsRequest const& request) + : client_(client), request_(request) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. 
+ return true; + } + + bool Read(google::bigtable::v2::ReadRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + RowFilter filter = MakeRowFilter(); + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::ReadRowsResponse(); + // Send all contents in first response. + for (auto itr = client_->table_.rows.begin(); + itr != client_->table_.rows.end(); ++itr) { + if (filter.AllowRow(itr->first)) { + ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr; + bool sent_first = false; + for (auto col_itr = itr->second.columns.begin(); + col_itr != itr->second.columns.end(); ++col_itr) { + if (filter.AllowColumn(col_itr->first)) { + chunk = resp->add_chunks(); + if (!sent_first) { + sent_first = true; + chunk->set_row_key(itr->first); + } + auto colon_idx = col_itr->first.find(":"); + CHECK(colon_idx != string::npos) + << "No ':' found in: " << col_itr->first; + chunk->mutable_family_name()->set_value( + string(col_itr->first, 0, colon_idx)); + chunk->mutable_qualifier()->set_value( + string(col_itr->first, ++colon_idx)); + if (!filter.strip_values) { + chunk->set_value(col_itr->second); + } + if (filter.only_one_column) { + break; + } + } + } + if (sent_first) { + // We are sending this row, so set the commit flag on the last chunk. + chunk->set_commit_row(true); + } + } + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + struct RowFilter { + std::set<string> row_set; + std::vector<std::pair<string, string>> row_ranges; + double row_sample = 0.0; // Note: currently ignored. 
+ std::unique_ptr<RE2> col_filter; + bool strip_values = false; + bool only_one_column = false; + + bool AllowRow(const string& row) { + if (row_set.find(row) != row_set.end()) { + return true; + } + for (const auto& range : row_ranges) { + if (range.first <= row && range.second > row) { + return true; + } + } + return false; + } + + bool AllowColumn(const string& col) { + if (col_filter) { + return RE2::FullMatch(col, *col_filter); + } else { + return true; + } + } + }; + + RowFilter MakeRowFilter() { + RowFilter filter; + for (auto i = request_.rows().row_keys().begin(); + i != request_.rows().row_keys().end(); ++i) { + filter.row_set.insert(string(*i)); + } + for (auto i = request_.rows().row_ranges().begin(); + i != request_.rows().row_ranges().end(); ++i) { + if (i->start_key_case() != + google::bigtable::v2::RowRange::kStartKeyClosed || + i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) { + LOG(WARNING) << "Skipping row range that cannot be processed: " + << i->ShortDebugString(); + continue; + } + filter.row_ranges.emplace_back(std::make_pair( + string(i->start_key_closed()), string(i->end_key_open()))); + } + if (request_.filter().has_chain()) { + string family_filter; + string qualifier_filter; + for (auto i = request_.filter().chain().filters().begin(); + i != request_.filter().chain().filters().end(); ++i) { + switch (i->filter_case()) { + case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter: + family_filter = i->family_name_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter: + qualifier_filter = i->column_qualifier_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter: + if (i->cells_per_column_limit_filter() != 1) { + LOG(ERROR) << "Unexpected cells_per_column_limit_filter: " + << i->cells_per_column_limit_filter(); + } + break; + case google::bigtable::v2::RowFilter::kStripValueTransformer: + filter.strip_values = i->strip_value_transformer(); + 
break; + case google::bigtable::v2::RowFilter::kRowSampleFilter: + LOG(INFO) << "Ignoring row sample directive."; + break; + case google::bigtable::v2::RowFilter::kPassAllFilter: + break; + case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter: + filter.only_one_column = true; + break; + default: + LOG(WARNING) << "Ignoring unknown filter type: " + << i->ShortDebugString(); + } + } + if (family_filter.empty() || qualifier_filter.empty()) { + LOG(WARNING) << "Missing regex!"; + } else { + string regex = strings::Printf("%s:%s", family_filter.c_str(), + qualifier_filter.c_str()); + filter.col_filter.reset(new RE2(regex)); + } + } else { + LOG(WARNING) << "Read request did not have a filter chain specified: " + << request_.filter().DebugString(); + } + return filter; + } + + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. + const google::bigtable::v2::ReadRowsRequest request_; +}; + +class MutateRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::MutateRowsResponse> { + public: + explicit MutateRowsResponse(size_t num_successes) + : num_successes_(num_successes) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::MutateRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + *resp = google::bigtable::v2::MutateRowsResponse(); + for (size_t i = 0; i < num_successes_; ++i) { + auto entry = resp->add_entries(); + entry->set_index(i); + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. 
+ + private: + const size_t num_successes_; + + mutex mu_; + bool sent_first_message_ = false; +}; + +grpc::Status BigtableTestClient::MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) { + mutex_lock l(mu_); + auto* row = &table_.rows[string(request.row_key())]; + for (int i = 0; i < request.mutations_size(); ++i) { + UpdateRow(request.mutations(i), &row->columns); + } + *response = google::bigtable::v2::MutateRowResponse(); + return grpc::Status::OK; +} +grpc::Status BigtableTestClient::CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "CheckAndMutateRow not implemented."); +} +grpc::Status BigtableTestClient::ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "ReadModifyWriteRow not implemented."); +} +std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>> +BigtableTestClient::ReadRows( + grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) { + return MakeUnique<ReadRowsResponse>(this, request); +} + +std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>> +BigtableTestClient::SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) { + return MakeUnique<SampleRowKeysResponse>(this); +} +std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>> +BigtableTestClient::MutateRows( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) { + mutex_lock l(mu_); + for (auto i = 
request.entries().begin(); i != request.entries().end(); ++i) { + auto* row = &table_.rows[string(i->row_key())]; + for (auto mut = i->mutations().begin(); mut != i->mutations().end(); + ++mut) { + UpdateRow(*mut, &row->columns); + } + } + return MakeUnique<MutateRowsResponse>(request.entries_size()); +} + +std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() { + LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " + "cause a crash!"; + return nullptr; +} +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h new file mode 100644 index 0000000000..dcce6a33a7 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ + +#include "google/cloud/bigtable/data_client.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class BigtableTestClient : public ::bigtable::DataClient { + public: + std::string const& project_id() const override { return project_id_; } + std::string const& instance_id() const override { return instance_id_; } + void reset() override { + mutex_lock l(mu_); + table_ = Table(); + } + + grpc::Status MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) override; + + grpc::Status CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) override; + + grpc::Status ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + + std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>> + ReadRows(grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) override; + std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>> + SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) override; + + std::unique_ptr< + grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>> + MutateRows(grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) override; + + std::shared_ptr<grpc::Channel> Channel() override; + + private: + friend class 
SampleRowKeysResponse; + friend class ReadRowsResponse; + friend class MutateRowsResponse; + + struct Row { + string row_key; + std::map<string, string> columns; + }; + struct Table { + std::map<string, Row> rows; + }; + + mutex mu_; + const std::string project_id_ = "testproject"; + const std::string instance_id_ = "testinstance"; + Table table_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc new file mode 100644 index 0000000000..f9be9ec6e2 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { + +class BigtableTestClientOp : public OpKernel { + public: + explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~BigtableTestClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<BigtableClientResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK(ctx, + mgr->LookupOrCreate<BigtableClientResource>( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr<bigtable::DataClient> client( + new BigtableTestClient()); + // Note: must make explicit copies to sequence + // them before the move of client. 
+ string project_id = client->project_id(); + string instance_id = client->instance_id(); + *ret = new BigtableClientResource( + std::move(project_id), + std::move(instance_id), std::move(client)); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<BigtableClientResource>())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU), + BigtableTestClientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc new file mode 100644 index 0000000000..bd362f7de5 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void WriteCell(const string& row, const string& family, const string& column, + const string& value, ::bigtable::noex::Table* table) { + ::bigtable::SingleRowMutation mut(row); + mut.emplace_back(::bigtable::SetCell(family, column, value)); + table->Apply(std::move(mut)); +} + +TEST(BigtableTestClientTest, EmptyRowRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + ::bigtable::RowSet rowset; + rowset.Append("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; + EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; +} + +TEST(BigtableTestClientTest, SingleRowWriteAndRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + ::bigtable::RowSet rowset("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()); + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, 
MultiRowWriteAndSingleRowRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::bigtable::RowSet rowset("r1"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::bigtable::RowSet rowset("r1", "r2", "r3"); + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + 
EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, ColumnFiltering) { + 
std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), ::bigtable::Filter::FamilyRegex("f1"), + ::bigtable::Filter::ColumnRegex("c1")); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, RowKeys) { + std::shared_ptr<::bigtable::DataClient> client_ptr = + std::make_shared<BigtableTestClient>(); + ::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra 
cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::bigtable::Filter::Chain( + ::bigtable::Filter::Latest(1), ::bigtable::Filter::CellsRowLimit(1), + ::bigtable::Filter::StripValueTransformer()); + auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc new file mode 100644 index 0000000000..17ecc3dcb2 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Op registrations for the Bigtable integration. Every op that produces an
// output declares it with a scalar shape function.

// Outputs a `client` resource; the project and instance are given as attrs.
// TODO(saeta): Add support for setting ClientOptions values.
REGISTER_OP("BigtableClient")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("client: resource")
    .SetShapeFn(shape_inference::ScalarShape);

// Outputs a `table` resource from a `client` resource and a table name attr.
// TODO(saeta): Add support for Application Profiles.
// See https://cloud.google.com/bigtable/docs/app-profiles for more info.
REGISTER_OP("BigtableTable")
    .Input("client: resource")
    .Attr("table_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("table: resource")
    .SetShapeFn(shape_inference::ScalarShape);

// Takes a table, a dataset, column families/columns and a timestamp; declares
// no outputs. Presumably writes the dataset into the table -- the kernel is
// not visible here, confirm against it.
REGISTER_OP("DatasetToBigtable")
    .Input("table: resource")
    .Input("input_dataset: variant")
    .Input("column_families: string")
    .Input("columns: string")
    .Input("timestamp: int64")
    .SetShapeFn(shape_inference::NoOutputs);

// Outputs a dataset `handle` built from a dataset of keys, a table and the
// column families/columns to read.
REGISTER_OP("BigtableLookupDataset")
    .Input("keys_dataset: variant")
    .Input("table: resource")
    .Input("column_families: string")
    .Input("columns: string")
    .Output("handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Outputs a dataset `handle` from a table and a row-key prefix.
REGISTER_OP("BigtablePrefixKeyDataset")
    .Input("table: resource")
    .Input("prefix: string")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

// Outputs a dataset `handle` from a table and a start/end row-key pair.
REGISTER_OP("BigtableRangeKeyDataset")
    .Input("table: resource")
    .Input("start_key: string")
    .Input("end_key: string")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

// Outputs a dataset `handle` for scanning a table by either `prefix` or a
// `start_key`/`end_key` range, restricted to the given column families and
// columns, with `probability` controlling row sampling.
// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
// skip incomplete row.)
REGISTER_OP("BigtableScanDataset")
    .Input("table: resource")
    .Input("prefix: string")
    .Input("start_key: string")
    .Input("end_key: string")
    .Input("column_families: string")
    .Input("columns: string")
    .Input("probability: float")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

}  // namespace tensorflow
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("BigtableTestClient") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py new file mode 100644 index 0000000000..292d8f4e51 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains tests for the bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py new file mode 100644 index 0000000000..d33a66f2df --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bigtable Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import bigtable +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable_test.so")) + + +class BigtableOpsTest(test.TestCase): + COMMON_ROW_KEYS = ["r1", "r2", "r3"] + COMMON_VALUES = ["v1", "v2", "v3"] + + def setUp(self): + self._client = gen_bigtable_test_ops.bigtable_test_client() + table = gen_bigtable_ops.bigtable_table(self._client, "testtable") + self._table = bigtable.BigTable("testtable", None, table) + + def _makeSimpleDataset(self): + output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS) + output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES) + return dataset_ops.Dataset.zip((output_rows, output_values)) + + def _writeCommonValues(self, sess): + output_ds = self._makeSimpleDataset() + write_op = self._table.write(output_ds, 
["cf1"], ["c1"]) + sess.run(write_op) + + def runReadKeyTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected = list(self.COMMON_ROW_KEYS) + expected.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output), + "Unequal at step %d: want: %s, got: %s" % (i, want, output)) + + def testReadPrefixKeys(self): + self.runReadKeyTest(self._table.keys_by_prefix_dataset("r")) + + def testReadRangeKeys(self): + self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) + + def runScanTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_keys.reverse() + expected_values = list(self.COMMON_VALUES) + expected_values.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected_keys.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0])) + want = expected_values.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[1]), + "Unequal values at step: %d: want: %s, got: %s" % (i, want, + output[1])) + + def testScanPrefixStringCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1="c1")) + + def testScanPrefixListCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1=["c1"])) + + def testScanRangeStringCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1")) + + def testScanRangeListCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"])) + + def testLookup(self): + ds = self._table.keys_by_prefix_dataset("r") + ds = ds.apply(self._table.lookup_columns(cf1="c1")) + itr = 
ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_values = list(self.COMMON_VALUES) + expected_tuples = zip(expected_keys, expected_values) + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elem in enumerate(expected_tuples): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elem[0]), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elem[1]), compat.as_bytes(output[1]), + "Unequal values at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1]))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bigtable/python/ops/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py new file mode 100644 index 0000000000..36d75b0d70 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""This module contains the Python API for the Cloud Bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py new file mode 100644 index 0000000000..a54e020ed7 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -0,0 +1,480 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python API for TensorFlow's Bigtable integration. + +TensorFlow has support for reading from and writing to Cloud Bigtable. To use +the Bigtable TensorFlow integration, first create a BigtableClient (which +configures your connection to Cloud Bigtable), and then open a Table. The Table +object then allows you to create numerous @{tf.data.Dataset}s to read data, or +write a @{tf.data.Dataset} object to the underlying Bigtable Table. + +For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import iteritems + +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable.so")) + + +class BigtableClient(object): + """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. + + BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the + `table` method to open a Bigtable Table. + """ + + def __init__(self, project_id, instance_id): + """Creates a BigtableClient that can be used to open connections to tables. + + Args: + project_id: A string representing the GCP project id to connect to. + instance_id: A string representing the Bigtable instance to connect to. + """ + self._project_id = project_id + self._instance_id = instance_id + self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id) + + def table(self, name, snapshot=None): + """Opens a table and returns a `BigTable` object. + + Args: + name: A `tf.string` `tf.Tensor` name of the table to open. + snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to + request the creation of a snapshot. (Note: currently unimplemented.) + + Returns: + A `BigTable` python object representing the operations available on the + table. + """ + # TODO(saeta): Implement snapshot functionality. 
+ table = gen_bigtable_ops.bigtable_table(self._resource, name) + return BigTable(name, snapshot, table) + + +class BigTable(object): + """BigTable is the entrypoint for reading and writing data in Cloud Bigtable. + + This BigTable class is the python representation of the Cloud Bigtable table + within TensorFlow. Methods on this class allow data to be read from and + written to the Cloud Bigtable service in flexible and high performance + manners. + """ + + # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface. + # TODO(saeta): Consider variant tensors instead of resources (while supporting + # connection pooling). + + def __init__(self, name, snapshot, resource): + self._name = name + self._snapshot = snapshot + self._resource = resource + + def lookup_columns(self, *args, **kwargs): + """Retrieves the values of columns for a dataset of keys. + + Example usage: + ``` + table = bigtable_client.table("my_table") + key_dataset = table.keys_by_prefix_dataset("imagenet") + images = key_dataset.apply(table.lookup_columns(("cf1", "image"), + ("cf2", "label"), + ("cf2", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Alternatively, you can use keyword arguments to specify the columns to + capture. Example (same as above, rewritten): + ``` + table = bigtable_client.table("my_table") + key_dataset = table.keys_by_prefix_dataset("imagenet") + images = key_dataset.apply(table.lookup_columns( + cf1="image", cf2=("label", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Note: certain kwargs keys are reserved, and thus some column families cannot + be identified using the kwargs syntax. Instead, please use the args syntax. + This list includes: + - 'name' + This list can change at any time. + + Args: + *args: A list of tuples containing (column family, column name) pairs. 
+ **kwargs: Column families (keys) and their column name(s) (values). + + Returns: + A function that can be passed to `tf.data.Dataset.apply` to retrieve the + values of columns for the rows. + """ + table = self # Capture self + normalized = args + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + def _apply_fn(dataset): + # TODO(saeta): Verify dataset's types are correct! + return _BigtableLookupDataset(dataset, table, normalized) + + return _apply_fn + + def keys_by_range_dataset(self, start, end): + """Retrieves all row keys between start and end. + + Note: it does NOT retrieve the values of columns. + + Args: + start: The start row key. The row keys for rows after start (inclusive) + will be retrieved. + end: (Optional.) The end row key. Rows up to (but not including) end will + be retrieved. If end is None, all subsequent row keys will be retrieved. + + Returns: + A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + of the row keys between `start` and `end`. + """ + # TODO(saeta): Make inclusive / exclusive configurable? + if end is None: + end = "" + return _BigtableRangeKeyDataset(self, start, end) + + def keys_by_prefix_dataset(self, prefix): + """Retrieves the row keys matching a given prefix. + + Args: + prefix: All row keys that begin with `prefix` in the table will be + retrieved. + + Returns: + A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + of the row keys matching that prefix. + """ + return _BigtablePrefixKeyDataset(self, prefix) + + def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-key prefixed by `prefix` will be retrieved. 
+ + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys must match to be retrieved for prefix- + based scans. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If `probability` is a float outside the range (0, 1]. 
+ """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + + def scan_range(self, start, end, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + probability: Probabilistically sample rows. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. 
Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + return _BigtableScanDataset(self, "", start, end, normalized, probability) + + def write(self, dataset, column_families, columns, timestamp=None): + """Writes a dataset to the table. + + Args: + dataset: A @{tf.data.Dataset} to be written to this table. It must produce + a list of number-of-columns+1 elements, all of which must be strings. + The first value will be used as the row key, and subsequent values will + be used as cell values for the corresponding columns from the + corresponding column_families and columns entries. + column_families: A @{tf.Tensor} of `tf.string`s corresponding to the + column families to store the dataset's elements into. + columns: A `tf.Tensor` of `tf.string`s corresponding to the column names + to store the dataset's elements into. + timestamp: (Optional.) An int64 timestamp to write all the values at. + Leave as None to use server-provided timestamps. + + Returns: + A @{tf.Operation} that can be run to perform the write. + + Raises: + ValueError: If there are unexpected or incompatible types, or if the + number of columns and column_families does not match the output of + `dataset`. 
+ """ + if timestamp is None: + timestamp = -1 # Bigtable server provided timestamp. + for tensor_type in nest.flatten(dataset.output_types): + if tensor_type != dtypes.string: + raise ValueError("Not all elements of the dataset were `tf.string`") + for shape in nest.flatten(dataset.output_shapes): + if not shape.is_compatible_with(tensor_shape.scalar()): + raise ValueError("Not all elements of the dataset were scalars") + if len(column_families) != len(columns): + raise ValueError("len(column_families) != len(columns)") + if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + raise ValueError("A column name must be specified for every component of " + "the dataset elements. (e.g.: len(columns) != " + "len(dataset.output_types))") + return gen_bigtable_ops.dataset_to_bigtable( + self._resource, + dataset._as_variant_tensor(), # pylint: disable=protected-access + column_families, + columns, + timestamp) + + +class _BigtableKeyDataset(dataset_ops.Dataset): + """_BigtableKeyDataset is an abstract class representing the keys of a table. + """ + + def __init__(self, table): + """Constructs a _BigtableKeyDataset. + + Args: + table: a Bigtable class. + """ + super(_BigtableKeyDataset, self).__init__() + self._table = table + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.TensorShape([]) + + @property + def output_types(self): + return dtypes.string + + +class _BigtablePrefixKeyDataset(_BigtableKeyDataset): + """_BigtablePrefixKeyDataset represents looking up keys by prefix. 
+ """ + + def __init__(self, table, prefix): + super(_BigtablePrefixKeyDataset, self).__init__(table) + self._prefix = prefix + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_prefix_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix) + + +class _BigtableRangeKeyDataset(_BigtableKeyDataset): + """_BigtableRangeKeyDataset represents looking up keys by range. + """ + + def __init__(self, table, start, end): + super(_BigtableRangeKeyDataset, self).__init__(table) + self._start = start + self._end = end + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_range_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + start_key=self._start, + end_key=self._end) + + +class _BigtableLookupDataset(dataset_ops.Dataset): + """_BigtableLookupDataset represents a dataset that retrieves values for keys. + """ + + def __init__(self, dataset, table, normalized): + self._num_outputs = len(normalized) + 1 # 1 for row key + self._dataset = dataset + self._table = table + self._normalized = normalized + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._as_variant_tensor(), + table=self._table._resource, + column_families=self._column_families, + columns=self._columns) + + +class _BigtableScanDataset(dataset_ops.Dataset): + """_BigtableScanDataset represents a dataset that retrieves keys and values. 
+ """ + + def __init__(self, table, prefix, start, end, normalized, probability): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + self._probability = probability + self._num_outputs = len(normalized) + 1 # 1 for row key + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_scan_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end, + column_families=self._column_families, + columns=self._columns, + probability=self._probability) diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 1a7a3759ba..523a9efcf0 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ ":gen_bigquery_reader_ops", ":gen_gcs_config_ops", + "//tensorflow/contrib/bigtable", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md new file mode 100644 index 0000000000..134ce057f4 --- /dev/null +++ b/tensorflow/contrib/cloud/README.md @@ -0,0 +1,18 @@ +# Cloud # + +## BigTable ## + +[Google Cloud BigTable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! 
:-) + +<!-- TODO(saeta): Document usage / methods / etc. --> + +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. + +<!-- TODO(saeta): Document usage / methods / etc. --> diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index ef7aa7624c..af81106a68 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -18,15 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long,wildcard-import +import os + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * -# pylint: enable=line-too-long,wildcard-import + +if os.name != 'nt': + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + +del os from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', + 'BigTable', + 'BigtableClient', 'BlockCacheParams', 'configure_colab_session', 'configure_gcs', diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index d530572e91..8ff6ebedab 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -86,6 +86,8 @@ tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow tensorflow/contrib/bayesflow/python tensorflow/contrib/bayesflow/python/ops +# tensorflow/contrib/bigtable/python +# tensorflow/contrib/bigtable/python/ops tensorflow/contrib/boosted_trees tensorflow/contrib/boosted_trees/estimator_batch tensorflow/contrib/boosted_trees/kernels diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 
067c299a71..872b016d2b 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -49,43 +49,48 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -if(NOT WIN32) - function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) - if(NOT ARGN) - message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() +function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR) + if(NOT ARGN) + message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") + return() + endif() + + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${ARGN}) + set(ABS_FIL ${ROOT_DIR}/${FIL}) + get_filename_component(FIL_WE ${FIL} NAME_WE) + get_filename_component(FIL_DIR ${ABS_FIL} PATH) + file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") + + # We adjust the path of the gRPC code generation accordingly. 
+ if(WIN32) + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/Release/grpc_cpp_plugin.exe) + else() + set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/grpc_cpp_plugin) endif() - set(${SRCS}) - set(${HDRS}) - foreach(FIL ${ARGN}) - set(ABS_FIL ${ROOT_DIR}/${FIL}) - get_filename_component(FIL_WE ${FIL} NAME_WE) - get_filename_component(FIL_DIR ${ABS_FIL} PATH) - file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) - - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h") - list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") - list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") - - add_custom_command( - OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" - COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} - ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin protoc-gen-grpc=${GRPC_BUILD}/grpc_cpp_plugin -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} - DEPENDS ${ABS_FIL} protobuf grpc - COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" - VERBATIM ) - endforeach() - - set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) - set(${SRCS} ${${SRCS}} PARENT_SCOPE) - set(${HDRS} ${${HDRS}} PARENT_SCOPE) - endfunction() -endif() + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin=protoc-gen-grpc=${GRPC_PROTOC_PLUGIN_PATH} 
-I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} + DEPENDS ${ABS_FIL} protobuf grpc + COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() function(RELATIVE_PROTOBUF_TEXT_GENERATE_CPP SRCS HDRS ROOT_DIR) if(NOT ARGN) @@ -175,17 +180,14 @@ RELATIVE_PROTOBUF_TEXT_GENERATE_CPP(PROTO_TEXT_SRCS PROTO_TEXT_HDRS ${tensorflow_source_dir} ${tf_proto_text_srcs} ) -if(WIN32) - add_library(tf_protos_cc ${PROTO_SRCS} ${PROTO_HDRS}) -else() - file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" - ) - RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS - ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} - ) - add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) -endif() +file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/tensorflow/core/debug/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/master_service.proto" +) +RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS + ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs} +) +add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS}) ######################################################## # tf_core_lib library diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md new file mode 100644 index 0000000000..21fc44febc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -0,0 +1,45 @@ +# RevNet with TensorFlow eager execution + +This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the 
released implementation by the authors. The presented implementation can be run both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster. + +## Content + +- `revnet.py`: The RevNet model. +- `blocks.py`: The relevant reversible blocks. +- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. +- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. +- `config.py`: Configuration file for network architectures and training hyperparameters. +- `main.py`: Main training and evaluation script. +- `ops.py`: Auxiliary downsampling operation. + +## To run +- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- First run + +```bash +python cifar_tfrecords.py --data_dir ${PWD}/cifar +``` +to download the CIFAR datasets and convert them +to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. + +- To train a model run + +```bash +python main.py --data_dir ${PWD}/cifar +``` + +- Optional arguments for `main.py` include + - `train_dir`: Directory to store event files and checkpoints. + - `restore`: Restore the latest checkpoint. + - `validate`: Use validation set for training monitoring. + - `manual_grad`: Use the manually defined gradient map given by the authors. + - `dataset`: Use either `cifar-10` or `cifar-100` + +## Performance +- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100. + +## Reference +The Reversible Residual Network: Backpropagation Without Storing Activations. +Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017. 
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04b..580206943b 100644 --- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,6 +25,8 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; @@ -54,6 +56,14 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final float H_SCALE = 5.0f; private static final float W_SCALE = 5.0f; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + + //Number of threads in the java app + private static final int NUM_THREADS = 4; + + // Config values. private int inputSize; @@ -65,7 +75,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private float[][][] outputLocations; private float[][][] outputClasses; - float[][][][] img; + private ByteBuffer imgData = null; private Interpreter tfLite; @@ -176,9 +186,12 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { } // Pre-allocate buffers. 
- d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel = 4; // Floating point + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; + + d.tfLite.setNumThreads(NUM_THREADS); d.outputLocations = new float[1][NUM_RESULTS][4]; d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; return d; @@ -198,10 +211,11 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); } } Trace.endSection(); // preprocessBitmap @@ -211,7 +225,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { outputLocations = new float[1][NUM_RESULTS][4]; outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map<Integer, Object> outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 1b8a7205e6..8597707b24 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -59,6 +59,7 @@ using reference_ops::Mean; using reference_ops::RankOneSelect; using reference_ops::Relu1; 
using reference_ops::Relu6; +using reference_ops::ReluX; using reference_ops::Select; using reference_ops::SpaceToBatchND; using reference_ops::StridedSlice; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 16901a3e53..9357e7407e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -951,6 +951,19 @@ inline void Relu6(const float* input_data, const RuntimeShape& input_shape, } } +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const uint8 val = input_data[i]; + const uint8 clamped = + val > max_value ? max_value : val < min_value ? 
min_value : val; + output_data[i] = clamped; + } +} + template <FusedActivationFunctionType Ac> void L2Normalization(const float* input_data, const RuntimeShape& input_shape, float* output_data, const RuntimeShape& output_shape) { diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index f54db3af87..c448fb71db 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -991,7 +991,7 @@ TfLiteStatus InterpreterBuilder::operator()( variables.push_back(i); } } - (**interpreter).SetVariables(variables); + (**interpreter).SetVariables(std::move(variables)); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index e7343cb388..681448be20 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -20,8 +20,8 @@ limitations under the License. #include <vector> // Place `<locale>` before <Python.h> to avoid build failures in macOS. -#include <locale> #include <Python.h> +#include <locale> // We forward declare TFLite classes here to avoid exposing them to SWIG. namespace tflite { diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index a4229f91f5..29a1487c1f 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -132,7 +132,7 @@ class TocoConverter(object): Args: - graph_def: TensorFlow GraphDef. + graph_def: Frozen TensorFlow GraphDef. input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -178,7 +178,7 @@ class TocoConverter(object): """Creates a TocoConverter class from a file containing a frozen GraphDef. 
Args: - graph_def_file: Full filepath of file containing TensorFlow GraphDef. + graph_def_file: Full filepath of file containing frozen GraphDef. input_arrays: List of input tensors to freeze graph with. output_arrays: List of output tensors to freeze graph with. input_shapes: Dict of strings representing input tensor names to list of diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 0a60477c6d..9bd1f4f76e 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -225,7 +225,7 @@ def run_main(_): input_file_group.add_argument( "--graph_def_file", type=str, - help="Full filepath of file containing TensorFlow GraphDef.") + help="Full filepath of file containing frozen TensorFlow GraphDef.") input_file_group.add_argument( "--saved_model_dir", type=str, diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md index ee83c7a6e3..2db6a627ab 100644 --- a/tensorflow/contrib/lite/toco/README.md +++ b/tensorflow/contrib/lite/toco/README.md @@ -17,11 +17,12 @@ Usage information is given in these documents: Once an application developer has a trained TensorFlow model, TOCO will accept that model and generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports -[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) -and frozen graphs (models generated via -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)). -The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally -mobile devices, where the TensorFlow Lite interpreter handles them on-device. -This flow is represented in the diagram below. 
+[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. ![drawing](g3doc/toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 0ab024c618..18b7848db8 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -11,8 +11,10 @@ Table of contents: * [Command-line tools](#tools) * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) -* [Convert a TensorFlow GraphDef](#graphdef) -* [Convert a TensorFlow SavedModel](#savedmodel) +* [Basic examples](#basic) + * [Convert a TensorFlow GraphDef](#graphdef) + * [Convert a TensorFlow SavedModel](#savedmodel) + * [Convert a tf.keras model](#keras) * [Quantization](#quantization) * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) * [Use "dummy-quantization" to try out quantized inference on a float @@ -51,7 +53,12 @@ API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the Terminal for additional details on the command-line flags available. There were no command line tools in TensorFlow 1.8. -## Convert a TensorFlow GraphDef <a name="graphdef"></a> +## Basic examples <a name="basic"></a> + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. 
+ +### Convert a TensorFlow GraphDef <a name="graphdef"></a> The follow example converts a basic TensorFlow GraphDef (frozen by [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) @@ -70,7 +77,7 @@ tflite_convert \ The value for `input_shapes` is automatically determined whenever possible. -## Convert a TensorFlow SavedModel <a name="savedmodel"></a> +### Convert a TensorFlow SavedModel <a name="savedmodel"></a> The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite FlatBuffer to perform floating-point inference. @@ -95,6 +102,17 @@ There is currently no support for MetaGraphDefs without a SignatureDef or for MetaGraphDefs that use the [`assets/` directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory). +### Convert a tf.Keras model <a name="keras"></a> + +The following example converts a `tf.keras` model into a TensorFlow Lite +Flatbuffer. The `tf.keras` file must contain both the model and the weights. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --keras_model_file=/tmp/keras_model.h5 +``` + ## Quantization ### Convert a TensorFlow GraphDef for quantized inference <a name="graphdef-quant"></a> diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 2d44b871c6..decc8a45a4 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -19,7 +19,7 @@ Table of contents: The following high level flags specify the details of the input and output files. The flag `--output_file` is always required. Additionally, either -`--graph_def_file` or `--saved_model_dir` is required. +`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required. * `--output_file`. Type: string. Specifies the full path of the output file. * `--graph_def_file`. Type: string. 
Specifies the full path of the input @@ -27,6 +27,8 @@ files. The flag `--output_file` is always required. Additionally, either [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). * `--saved_model_dir`. Type: string. Specifies the full path to the directory containing the SavedModel. +* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file + containing the tf.keras model. * `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of the output file. Allowed values: * `TFLITE`: TensorFlow Lite FlatBuffer format. diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index b04d166f89..3799eac0a1 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -41,9 +41,11 @@ is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is `TocoConverter` provides class methods based on the original format of the model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. Example usages -for simple float-point models are shown in [Basic Examples](#basic). Examples -usages for more complex models is shown in [Complex Examples](#complex). +`TocoConverter.from_saved_model()` is available for SavedModels. +`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in [Basic +Examples](#basic). Examples usages for more complex models is shown in [Complex +Examples](#complex). **NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python interpreter when the conversion fails. This will be remedied as soon as @@ -117,7 +119,7 @@ available by running `help(tf.contrib.lite.TocoConverter)`. 
### Exporting a tf.keras File <a name="basic-keras-file"></a> -The following example shows how to convert a tf.keras model into a TensorFlow +The following example shows how to convert a `tf.keras` model into a TensorFlow Lite FlatBuffer. ```python @@ -128,7 +130,7 @@ tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` -The tf.keras file must contain both the model and the weights. A comprehensive +The `tf.keras` file must contain both the model and the weights. A comprehensive example including model construction can be seen below. ```python diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 38699a62b5..58885b4950 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -59,7 +59,8 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kGreater || type == OperatorType::kGreaterEqual || type == OperatorType::kLess || type == OperatorType::kLessEqual || type == OperatorType::kSelect || - type == OperatorType::kArgMax; + type == OperatorType::kArgMax || type == OperatorType::kRelu || + type == OperatorType::kRelu1 || type == OperatorType::kRelu6; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -325,12 +326,13 @@ bool ChooseQuantizationForOperatorOutput( output, OperatorTypeName(op.type)); return true; } - if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kReshape) || - (op.type == OperatorType::kSplit) || - (op.type == OperatorType::kConcatenation && - model->flags.change_concat_input_ranges())) { + if ((op.type == OperatorType::kConcatenation && + model->flags.change_concat_input_ranges()) || + op.type == OperatorType::kDepthToSpace || + op.type == OperatorType::kSpaceToDepth || + op.type == OperatorType::kReshape || 
op.type == OperatorType::kSplit || + op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 || + op.type == OperatorType::kRelu6) { int data_input_index = 0; if (op.type == OperatorType::kSplit) { data_input_index = 1; diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 89db9ee279..6e7423f85e 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 157ed6a278..3e63e99030 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,17 +22,18 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from 
tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * -from tensorflow.contrib.opt.python.training.ggt import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 8aa40aeb45..b9cf40eb7b 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.training import optimizer from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import adam from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer from tensorflow.python.util.tf_export import tf_export -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import resource_variable_ops class DecoupledWeightDecayExtension(object): @@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object): Args: weight_decay: A `Tensor` or a floating point value, the factor by which a variable is decayed in the update step. - decay_var_list: Optional list or tuple or set of `Variable` objects to + **kwargs: Optional list or tuple or set of `Variable` objects to decay. 
""" self._decay_var_list = None # is set in minimize or apply_gradients @@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object): If decay_var_list is None, all variables in var_list are decayed. For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( @@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object): are decayed. For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. 
+ + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).apply_gradients( @@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer): A new optimizer class that inherits from DecoupledWeightDecayExtension and base_optimizer. """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py index 74d1cdbbda..76d8a5697a 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.opt.python.training import weight_decay_optimizers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam -from tensorflow.contrib.opt.python.training import weight_decay_optimizers WEIGHT_DECAY = 0.01 @@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase): opt = optimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - if not context.executing_eagerly(): with ops.Graph().as_default(): # Shouldn't return non-slot variables from other graphs. 
@@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): @staticmethod def get_optimizer(): - AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( adam.AdamOptimizer) - return AdamW(WEIGHT_DECAY) + return adamw(WEIGHT_DECAY) def testBasic(self): self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", @@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): use_resource=True) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index e69725ff8a..f58268eff5 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import six +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -182,19 +183,20 @@ def dynamic_decode(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - def _is_xla_tensor(tensor): - try: - op = tensor.op - except AttributeError: - return False - if control_flow_util.IsInXLAContext(op): - return True - return False - with variable_scope.variable_scope(scope, "decoder") as varscope: - # Properly cache variable values inside the while_loop - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + # Determine context types. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None + in_while_loop = ( + control_flow_util.GetContainingWhileContext(ctxt) is not None) + # Properly cache variable values inside the while_loop. 
+ # Don't set a caching device when running in a loop, since it is possible + # that train steps could be wrapped in a tf.while_loop. In that scenario + # caching prevents forward computations in loop iterations from re-reading + # the updated weights. + if not context.executing_eagerly() and not in_while_loop: + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) if maximum_iterations is not None: maximum_iterations = ops.convert_to_tensor( @@ -208,9 +210,6 @@ def dynamic_decode(decoder, decoder.output_dtype, decoder.batch_size) - is_xla = False - if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]): - is_xla = True if is_xla and maximum_iterations is None: raise ValueError("maximum_iterations is required for XLA compilation.") if maximum_iterations is not None: diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 3d0308aaf3..2c97834523 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc 
b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 13986127ba..4dc1c551cc 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -142,7 +142,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph( auto n = infer_graph->mutable_node(i); if (n->op() == "TRTEngineOp") { VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); + string container_name = n->attr().at("segment_funcdef_name").s(); TRTCalibrationResource* cres = nullptr; auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); if (!status.ok()) { @@ -168,7 +168,6 @@ tensorflow::Status ConvertCalibGraphToInferGraph( "Can't get TRTCalibrator from resource manager!"); } cres->Unref(); - calib_rm->Cleanup(container_name); } } return tensorflow::Status::OK(); @@ -823,8 +822,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } else { // Graph is not modified. LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() << " nodes failed: " - << status << ". Skipping..."; + << converted_segments.at(i).first.size() + << " nodes failed: " << status << ". 
Skipping..."; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 7684d8d4a2..1a4c0e755d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -46,8 +46,8 @@ const int INT8MODE = 2; struct EngineConnection { EngineConnection(const string& outside, int out_id, int out_port, - const string& inside, int in_id, int in_port, - bool input_edge, int port) + const string& inside, int in_id, int in_port, + bool input_edge, int port) : outside_node_name(outside), outside_id(out_id), outside_port(out_port), diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 75e32559bb..8a17eb02f1 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -319,7 +319,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, default: LOG(ERROR) << "Unknown TRT data type: " << int(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unknown ouput TRT data type! ", static_cast<int>(dtype))); + "Unknown output TRT data type! ", static_cast<int>(dtype))); return; } } @@ -327,8 +327,8 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const size_t binding_index = trt_engine_ptr->getBindingIndex( - output_name.c_str()); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; @@ -371,7 +371,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, default: LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unsupported output data type! 
", int(dtype))); + "Unsupported output data type! ", static_cast<int>(dtype))); return; } } @@ -420,10 +420,10 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { } TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, - OpKernelContext* ctx) { + OpKernelContext* ctx) { static EngineCtxPair null_pair = { - TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr), - TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)}; + TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr), + TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)}; // TODO(sami): This method needs to be re-written to use resource manager and // with LRU mechanism option. tensorflow::mutex_lock lock(engine_mutex_); @@ -450,9 +450,9 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); engine_map_[max_batch_size] = { - std::move(static_engine), - TrtUniquePtrType<nvinfer1::IExecutionContext>( - raw_static_engine->createExecutionContext())}; + std::move(static_engine), + TrtUniquePtrType<nvinfer1::IExecutionContext>( + raw_static_engine->createExecutionContext())}; // Runtime is safe to delete after engine creation serialized_segment_.clear(); if (max_batch_size < batch_size) return null_pair; diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d6628cd1eb..d51a0b59e2 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -221,26 +221,22 @@ std::pair<string, string> calib_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -version_struct get_linked_tensorrt_version() { +version_struct get_linked_tensorrt_version(){ // Return the version at the link time. 
- version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); + version_struct s; s.vmajor = lv[0]; s.vminor = lv[1]; s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT return s; } version_struct get_loaded_tensorrt_version(){ // Return the version from the loaded library. - version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); + version_struct s; s.vmajor = lv[0]; s.vminor = lv[1]; s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT return s; } diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 5210139336..14e025973e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -81,12 +81,17 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. 
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -1414,8 +1419,11 @@ class _ModelFnWrapper(object): if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) + estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c1efc9c0c6..0e6bc03c0b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1923,6 +1923,7 @@ tf_proto_library_cc( srcs = ["protobuf/master_service.proto"], has_services = 1, cc_api_version = 2, + cc_grpc_version = 1, cc_stubby_versions = ["2"], protodeps = [":master_proto"], visibility = [ diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 477a0b670e..6149e5fca8 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -171,7 +171,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { continue; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) + EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) << op.name() << " op does not have api_def_*.pbtxt file. 
" << "Please add api_def_" << op.name() << ".pbtxt file " << "under tensorflow/core/api_def/base_api/ directory."; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 87ba609dd7..f903faf1bd 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1626,15 +1626,6 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options, TF_RETURN_IF_ERROR(CheckNotClosed()); TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()")); - if (!callable_options.run_options() - .debug_options() - .debug_tensor_watch_opts() - .empty()) { - return errors::Unimplemented( - "Debug options are not currently supported via the C++ MakeCallable " - "interface."); - } - std::unique_ptr<ExecutorsAndKeys> ek; std::unique_ptr<FunctionInfo> func_info; RunStateArgs run_state_args(callable_options.run_options().debug_options()); diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index b4bf1c408f..0b096a14a3 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -106,24 +106,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(1, shape.dim(1).size()); if (node->name() == y->name()) { #ifdef INTEL_MKL - // if MKL is used, it goes through various additional - // graph rewrite pass. In TF, everytime a graph pass + // if MKL is used, it goes through various additional + // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated // and deallocated. Each allocation calls the // (FindChunkPtr of BFCAllocator), - // which increments the value of AllocationId. - // Thus AllocationId becomes more than TF if MKL - // is used. Now IDs for MKL are 8 more than TF. 
+ // which increments the value of AllocationId. + // Thus AllocationId becomes more than TF if MKL + // is used. Now IDs for MKL are 8 more than TF. EXPECT_EQ(29, cm->AllocationId(node, 0)); #else EXPECT_EQ(21, cm->AllocationId(node, 0)); -#endif +#endif } else { #ifdef INTEL_MKL EXPECT_EQ(30, cm->AllocationId(node, 0)); #else EXPECT_EQ(22, cm->AllocationId(node, 0)); -#endif +#endif } } EXPECT_LE(0, cm->MaxExecutionTime(node)); diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 0abef01a9a..75f8a19e9c 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -636,12 +636,12 @@ tf_cuda_cc_test( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc:grpc_master_service_impl", "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 62b18a45b1..09e96cbd40 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/allocator.h" @@ -38,6 +37,7 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 4a10d99a60..d6c493c022 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -201,11 +201,11 @@ cc_library( srcs = ["grpc_remote_master.cc"], hdrs = ["grpc_remote_master.h"], deps = [ - ":grpc_master_service_impl", ":grpc_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:master_interface", ], @@ -219,28 +219,18 @@ cc_library( deps = [ ":async_service_interface", ":grpc_call", - ":grpc_master_service_impl", ":grpc_util", "//tensorflow:grpc++", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", "//tensorflow/core/distributed_runtime:master", ], alwayslink = 1, ) cc_library( - name = "grpc_master_service_impl", - srcs = ["grpc_master_service_impl.cc"], - hdrs = ["grpc_master_service_impl.h"], - deps = [ - "//tensorflow:grpc++", - "//tensorflow/core:master_proto_cc", - ], -) - -cc_library( name = "rpc_rendezvous_mgr", srcs = ["rpc_rendezvous_mgr.cc"], hdrs = ["rpc_rendezvous_mgr.h"], diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 127dea2882..2c2c1d484a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -36,12 +36,12 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/master.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc deleted file mode 100644 index 770a0fcf14..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" - -#include "grpcpp/impl/codegen/async_stream.h" -#include "grpcpp/impl/codegen/async_unary_call.h" -#include "grpcpp/impl/codegen/channel_interface.h" -#include "grpcpp/impl/codegen/client_unary_call.h" -#include "grpcpp/impl/codegen/method_handler_impl.h" -#include "grpcpp/impl/codegen/rpc_service_method.h" -#include "grpcpp/impl/codegen/service_type.h" -#include "grpcpp/impl/codegen/sync_stream.h" - -namespace tensorflow { - -namespace grpc { - -static const char* grpcMasterService_method_names[] = { - "/tensorflow.MasterService/CreateSession", - "/tensorflow.MasterService/ExtendSession", - "/tensorflow.MasterService/PartialRunSetup", - "/tensorflow.MasterService/RunStep", - "/tensorflow.MasterService/CloseSession", - "/tensorflow.MasterService/ListDevices", - "/tensorflow.MasterService/Reset", - "/tensorflow.MasterService/MakeCallable", - "/tensorflow.MasterService/RunCallable", - "/tensorflow.MasterService/ReleaseCallable", -}; - -std::unique_ptr<MasterService::Stub> MasterService::NewStub( - const std::shared_ptr< ::grpc::ChannelInterface>& channel, - const ::grpc::StubOptions& options) { - std::unique_ptr<MasterService::Stub> stub(new MasterService::Stub(channel)); - return stub; -} - -MasterService::Stub::Stub( - const std::shared_ptr< ::grpc::ChannelInterface>& channel) - : channel_(channel), - rpcmethod_CreateSession_(grpcMasterService_method_names[0], - ::grpc::internal::RpcMethod::NORMAL_RPC, - channel), - rpcmethod_ExtendSession_(grpcMasterService_method_names[1], - ::grpc::internal::RpcMethod::NORMAL_RPC, - channel), - rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2], - ::grpc::internal::RpcMethod::NORMAL_RPC, - channel), - rpcmethod_RunStep_(grpcMasterService_method_names[3], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - 
rpcmethod_CloseSession_(grpcMasterService_method_names[4], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - rpcmethod_ListDevices_(grpcMasterService_method_names[5], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - rpcmethod_Reset_(grpcMasterService_method_names[6], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - rpcmethod_MakeCallable_(grpcMasterService_method_names[7], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - rpcmethod_RunCallable_(grpcMasterService_method_names[8], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel), - rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9], - ::grpc::internal::RpcMethod::NORMAL_RPC, - channel) {} - -::grpc::Status MasterService::Stub::CreateSession( - ::grpc::ClientContext* context, const CreateSessionRequest& request, - CreateSessionResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_CreateSession_, context, request, response); -} - -::grpc::Status MasterService::Stub::ExtendSession( - ::grpc::ClientContext* context, const ExtendSessionRequest& request, - ExtendSessionResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_ExtendSession_, context, request, response); -} - -::grpc::Status MasterService::Stub::PartialRunSetup( - ::grpc::ClientContext* context, const PartialRunSetupRequest& request, - PartialRunSetupResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_PartialRunSetup_, context, request, response); -} - -::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context, - const RunStepRequest& request, - RunStepResponse* response) { - return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_RunStep_, - context, request, response); -} - -::grpc::Status MasterService::Stub::CloseSession( - ::grpc::ClientContext* context, const CloseSessionRequest& request, - CloseSessionResponse* response) { - return 
::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_CloseSession_, context, request, response); -} - -::grpc::Status MasterService::Stub::ListDevices( - ::grpc::ClientContext* context, const ListDevicesRequest& request, - ListDevicesResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_ListDevices_, context, request, response); -} - -::grpc::Status MasterService::Stub::Reset(::grpc::ClientContext* context, - const ResetRequest& request, - ResetResponse* response) { - return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Reset_, - context, request, response); -} - -::grpc::Status MasterService::Stub::MakeCallable( - ::grpc::ClientContext* context, const MakeCallableRequest& request, - MakeCallableResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_MakeCallable_, context, request, response); -} - -::grpc::Status MasterService::Stub::RunCallable( - ::grpc::ClientContext* context, const RunCallableRequest& request, - RunCallableResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_RunCallable_, context, request, response); -} - -::grpc::Status MasterService::Stub::ReleaseCallable( - ::grpc::ClientContext* context, const ReleaseCallableRequest& request, - ReleaseCallableResponse* response) { - return ::grpc::internal::BlockingUnaryCall( - channel_.get(), rpcmethod_ReleaseCallable_, context, request, response); -} - -MasterService::AsyncService::AsyncService() { - int method_len = sizeof(grpcMasterService_method_names) / - sizeof(grpcMasterService_method_names[0]); - for (int i = 0; i < method_len; ++i) { - AddMethod(new ::grpc::internal::RpcServiceMethod( - grpcMasterService_method_names[i], - ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); - ::grpc::Service::MarkMethodAsync(i); - } -} - -MasterService::AsyncService::~AsyncService() {} - -} // namespace grpc - -} // namespace tensorflow diff --git 
a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h deleted file mode 100644 index 751f2633e7..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ - -#include "grpcpp/impl/codegen/async_stream.h" -#include "grpcpp/impl/codegen/async_unary_call.h" -#include "grpcpp/impl/codegen/proto_utils.h" -#include "grpcpp/impl/codegen/rpc_method.h" -#include "grpcpp/impl/codegen/service_type.h" -#include "grpcpp/impl/codegen/status.h" -#include "grpcpp/impl/codegen/stub_options.h" -#include "grpcpp/impl/codegen/sync_stream.h" - -#include "tensorflow/core/protobuf/master.pb.h" - -namespace grpc { -class CompletionQueue; -class Channel; -class RpcService; -class ServerCompletionQueue; -class ServerContext; -} // namespace grpc - -namespace tensorflow { - -namespace grpc { - -// Implementation of `tensorflow.MasterService`, based on the -// definition in "//tensorflow/core/protobuf/master_service.proto", -// and the gRPC generated stub and service classes. 
-// See that file for the definition of methods and messages. -class MasterService final { - public: - class StubInterface { - public: - virtual ~StubInterface() {} - virtual ::grpc::Status CreateSession(::grpc::ClientContext* context, - const CreateSessionRequest& request, - CreateSessionResponse* response) = 0; - virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context, - const ExtendSessionRequest& request, - ExtendSessionResponse* response) = 0; - virtual ::grpc::Status PartialRunSetup( - ::grpc::ClientContext* context, const PartialRunSetupRequest& request, - PartialRunSetupResponse* response) = 0; - virtual ::grpc::Status RunStep(::grpc::ClientContext* context, - const RunStepRequest& request, - RunStepResponse* response) = 0; - virtual ::grpc::Status CloseSession(::grpc::ClientContext* context, - const CloseSessionRequest& request, - CloseSessionResponse* response) = 0; - virtual ::grpc::Status ListDevices(::grpc::ClientContext* context, - const ListDevicesRequest& request, - ListDevicesResponse* response) = 0; - virtual ::grpc::Status Reset(::grpc::ClientContext* context, - const ResetRequest& request, - ResetResponse* response) = 0; - virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context, - const MakeCallableRequest& request, - MakeCallableResponse* response) = 0; - virtual ::grpc::Status RunCallable(::grpc::ClientContext* context, - const RunCallableRequest& request, - RunCallableResponse* response) = 0; - virtual ::grpc::Status ReleaseCallable( - ::grpc::ClientContext* context, const ReleaseCallableRequest& request, - ReleaseCallableResponse* response) = 0; - }; - class Stub final : public StubInterface { - public: - Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status CreateSession(::grpc::ClientContext* context, - const CreateSessionRequest& request, - CreateSessionResponse* response) override; - ::grpc::Status ExtendSession(::grpc::ClientContext* context, - const ExtendSessionRequest& request, - 
ExtendSessionResponse* response) override; - ::grpc::Status PartialRunSetup(::grpc::ClientContext* context, - const PartialRunSetupRequest& request, - PartialRunSetupResponse* response) override; - ::grpc::Status RunStep(::grpc::ClientContext* context, - const RunStepRequest& request, - RunStepResponse* response) override; - ::grpc::Status CloseSession(::grpc::ClientContext* context, - const CloseSessionRequest& request, - CloseSessionResponse* response) override; - ::grpc::Status ListDevices(::grpc::ClientContext* context, - const ListDevicesRequest& request, - ListDevicesResponse* response) override; - ::grpc::Status Reset(::grpc::ClientContext* context, - const ResetRequest& request, - ResetResponse* response) override; - ::grpc::Status MakeCallable(::grpc::ClientContext* context, - const MakeCallableRequest& request, - MakeCallableResponse* response) override; - ::grpc::Status RunCallable(::grpc::ClientContext* context, - const RunCallableRequest& request, - RunCallableResponse* response) override; - ::grpc::Status ReleaseCallable(::grpc::ClientContext* context, - const ReleaseCallableRequest& request, - ReleaseCallableResponse* response) override; - - private: - std::shared_ptr< ::grpc::ChannelInterface> channel_; - const ::grpc::internal::RpcMethod rpcmethod_CreateSession_; - const ::grpc::internal::RpcMethod rpcmethod_ExtendSession_; - const ::grpc::internal::RpcMethod rpcmethod_PartialRunSetup_; - const ::grpc::internal::RpcMethod rpcmethod_RunStep_; - const ::grpc::internal::RpcMethod rpcmethod_CloseSession_; - const ::grpc::internal::RpcMethod rpcmethod_ListDevices_; - const ::grpc::internal::RpcMethod rpcmethod_Reset_; - const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_; - const ::grpc::internal::RpcMethod rpcmethod_RunCallable_; - const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_; - }; - static std::unique_ptr<Stub> NewStub( - const std::shared_ptr< ::grpc::ChannelInterface>& channel, - const ::grpc::StubOptions& options = 
::grpc::StubOptions()); - - class AsyncService : public ::grpc::Service { - public: - AsyncService(); - virtual ~AsyncService(); - void RequestCreateSession( - ::grpc::ServerContext* context, CreateSessionRequest* request, - ::grpc::ServerAsyncResponseWriter<CreateSessionResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestExtendSession( - ::grpc::ServerContext* context, ExtendSessionRequest* request, - ::grpc::ServerAsyncResponseWriter<ExtendSessionResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(1, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestPartialRunSetup( - ::grpc::ServerContext* context, PartialRunSetupRequest* request, - ::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(2, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestRunStep( - ::grpc::ServerContext* context, RunStepRequest* request, - ::grpc::ServerAsyncResponseWriter<RunStepResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(3, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestCloseSession( - ::grpc::ServerContext* context, CloseSessionRequest* request, - ::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(4, context, request, response, - new_call_cq, notification_cq, tag); - } 
- void RequestListDevices( - ::grpc::ServerContext* context, ListDevicesRequest* request, - ::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(5, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestReset( - ::grpc::ServerContext* context, ResetRequest* request, - ::grpc::ServerAsyncResponseWriter<ResetResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(6, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestMakeCallable( - ::grpc::ServerContext* context, MakeCallableRequest* request, - ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(7, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestRunCallable( - ::grpc::ServerContext* context, RunCallableRequest* request, - ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(8, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestReleaseCallable( - ::grpc::ServerContext* context, ReleaseCallableRequest* request, - ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(9, context, request, response, - new_call_cq, notification_cq, tag); - } - }; -}; - -} // namespace grpc - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ diff 
--git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index b832a2115c..6c2940553c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index ff64d78b79..2c833d11a9 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -289,12 +289,10 @@ Status GrpcServer::Init( nullptr); } - Status GrpcServer::Init( ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func) { - return Init(std::move(service_func), rendezvous_mgr_func, nullptr, - nullptr); + return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr); } Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 115148b84e..b01cfb6426 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -100,6 +100,9 @@ class GrpcServer : 
public ServerInterface { Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func); + Status Init(ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func); + Status Init(); // A subclass can override this method to support secure credentials. diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h index 813ec6eed5..0a8da8b3bf 100644 --- a/tensorflow/core/framework/resource_op_kernel.h +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -43,9 +43,15 @@ template <typename T> class ResourceOpKernel : public OpKernel { public: explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &handle_, nullptr)); + has_resource_type_ = (context->output_type(0) == DT_RESOURCE); + if (!has_resource_type_) { + // The resource variant of the op may be placed on non-CPU devices, but + // this allocation is always on the host. Fortunately we don't need it in + // the resource case. 
+ OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &handle_, nullptr)); + } } // The resource is deleted from the resource manager only when it is private @@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel { return; } - auto h = handle_.AccessTensor(context)->template flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); + if (!has_resource_type_) { + auto h = handle_.AccessTensor(context)->template flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } resource_ = resource; } - if (context->expected_output_dtype(0) == DT_RESOURCE) { + if (has_resource_type_) { OP_REQUIRES_OK(context, MakeResourceHandleToOutput( context, 0, cinfo_.container(), cinfo_.name(), MakeTypeIndex<T>())); @@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel { virtual Status VerifyResource(T* resource) { return Status::OK(); } PersistentTensor handle_ GUARDED_BY(mu_); + + // Is the output of the operator of type DT_RESOURCE? + bool has_resource_type_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index 8002d9291c..4a18efc940 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -57,6 +57,10 @@ class StatsAggregator { // interface. It is possible that not all implementations will support // encoding their state as a protocol buffer. virtual void EncodeToProto(Summary* out_summary) = 0; + + // Increment the `label` cell of metrics mapped with `name` by given `value`. 
+ virtual void IncrementCounter(const string& name, const string& label, + int64 val) = 0; }; // A `StatsAggregatorResource` wraps a shareable `StatsAggregator` as a resource diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index bdeb5c66fc..653b088b1d 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -161,6 +161,8 @@ bool IsExit(const NodeDef& node) { return op == "Exit" || op == "RefExit"; } +bool IsExp(const NodeDef& node) { return node.op() == "Exp"; } + bool IsFill(const NodeDef& node) { return node.op() == "Fill"; } bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 2de7d8cc9a..94439265c9 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -60,6 +60,7 @@ bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); bool IsExit(const NodeDef& node); +bool IsExp(const NodeDef& node); bool IsFill(const NodeDef& node); bool IsFloorDiv(const NodeDef& node); bool IsFloorMod(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d8c5d09c4d..72ca3c3fa2 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -178,6 +178,42 @@ NodeDef* GetTailOfIdempotentChain( is_idempotent_non_branching); } +// GetElementUnexhaustive tries to get the value of an element in a tensor and +// turn it into complex128 type. It only check for a limited number of data +// types, so it's unexhaustive. 
+bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes, + complex128* element) { + if (dtypes.find(t.dtype()) == dtypes.end()) return false; + switch (t.dtype()) { + case DT_BFLOAT16: + *element = complex128(t.flat<bfloat16>()(i)); + return true; + case DT_HALF: + *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); + return true; + case DT_INT32: + *element = complex128(t.flat<int32>()(i)); + return true; + case DT_INT64: + *element = complex128(t.flat<int64>()(i)); + return true; + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return true; + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return true; + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return true; + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return true; + default: + return false; + } +} + // Graph optimizer context extension specific to ArithmeticOptimizer. struct ArithmeticOptimizerContext { explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify) @@ -2361,7 +2397,13 @@ class ConvertPowStage : public ArithmeticOptimizerStage { complex128 prev, curr; for (int i = 0; i < pow.NumElements(); ++i) { - TF_RETURN_IF_ERROR(GetElement(pow, i, &curr)); + if (!GetElementUnexhaustive(pow, i, + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128}, + &curr)) { + // input data type is not supported by Pow. Skip. + return Status::OK(); + } if (i != 0 && curr != prev) { // pow has different values on different elements. Skip. 
return Status::OK(); @@ -2432,31 +2474,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage { } private: - Status GetElement(const Tensor& t, int i, complex128* element) { - switch (t.dtype()) { - case DT_INT32: - *element = complex128(t.flat<int32>()(i)); - return Status::OK(); - case DT_INT64: - *element = complex128(t.flat<int64>()(i)); - return Status::OK(); - case DT_FLOAT: - *element = complex128(t.flat<float>()(i)); - return Status::OK(); - case DT_DOUBLE: - *element = complex128(t.flat<double>()(i)); - return Status::OK(); - case DT_COMPLEX64: - *element = complex128(t.flat<complex64>()(i)); - return Status::OK(); - case DT_COMPLEX128: - *element = t.flat<complex128>()(i); - return Status::OK(); - default: - return errors::InvalidArgument("Invalid data type: ", t.dtype()); - } - } - Status SetElementToOne(int i, Tensor* t) { switch (t->dtype()) { case DT_INT32: @@ -2544,7 +2561,10 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } complex128 element; for (int k = 0; k < constant.NumElements(); ++k) { - if (!GetElement(constant, k, &element)) { + if (!GetElementUnexhaustive(constant, k, + {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128}, + &element)) { // input data type is not supported by log1p. Skip. 
return Status::OK(); } @@ -2569,30 +2589,81 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } return Status::OK(); } +}; - bool GetElement(const Tensor& t, int i, complex128* element) { - switch (t.dtype()) { - case DT_BFLOAT16: - *element = complex128(t.flat<bfloat16>()(i)); - return true; - case DT_HALF: - *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); - return true; - case DT_FLOAT: - *element = complex128(t.flat<float>()(i)); - return true; - case DT_DOUBLE: - *element = complex128(t.flat<double>()(i)); - return true; - case DT_COMPLEX64: - *element = complex128(t.flat<complex64>()(i)); - return true; - case DT_COMPLEX128: - *element = t.flat<complex128>()(i); - return true; - default: - return false; +class ConvertExpm1Stage : public ArithmeticOptimizerStage { + public: + explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {} + ~ConvertExpm1Stage() override = default; + + bool IsSupported(const NodeDef* node) const override { return IsExp(*node); } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + if (!IsSub(*input)) { + return Status::OK(); } + + if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) { + return Status::OK(); + } + + const auto& t = + ctx().graph_properties->GetInputProperties(input->name())[0]; + const auto& c = + ctx().graph_properties->GetInputProperties(input->name())[1]; + for (int k = 0; k < c.shape().dim_size(); ++k) { + // Skip if c shape is not fully determined. 
+ if (c.shape().dim(k).size() < 0) { + return Status::OK(); + } + } + TensorShapeProto broadcast_shape; + if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { + return Status::OK(); + } + if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { + // skip if the non-constant tensor doesn't have the same shape after + // broadcast. + return Status::OK(); + } + if (TensorShape::IsValid(c.shape()) && c.has_value()) { + Tensor constant(c.dtype(), c.shape()); + if (!constant.FromProto(c.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + c.value().DebugString()); + } + complex128 element; + for (int k = 0; k < constant.NumElements(); ++k) { + if (!GetElementUnexhaustive(constant, k, + {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128}, + &element)) { + // input data type is not supported by expm1. Skip. + return Status::OK(); + } + if (element != complex128(1)) { + // current element is not 1. Skip. + return Status::OK(); + } + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x)); + TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y)); + node->set_op("Expm1"); + node->set_input(0, input->input(0)); + node->add_input(AsControlDependency(y->name())); + ForwardControlDependencies(node, {input}); + + AddToOptimizationQueue(node); + AddToOptimizationQueue(input); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + } + return Status::OK(); } }; @@ -2928,6 +2999,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); if (options_.optimize_max_or_min_of_monotonic) pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); + if (options_.convert_expm1) + pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git 
a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 824ef35ef6..45a5f65b81 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -77,6 +77,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool simplify_aggregation = true; bool convert_pow = true; bool convert_log1p = true; + bool convert_expm1 = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d0e6b04679..3f6c04a5b5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -274,6 +274,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.optimize_max_or_min_of_monotonic = true; } + + void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_expm1 = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2533,6 +2538,43 @@ TEST_F(ArithmeticOptimizerTest, Log1p) { CompareGraphs(want, got); } +TEST_F(ArithmeticOptimizerTest, Expm1) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2}); + auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2}); + auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2}); + auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2); + auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3); + Output out1 = ops::Exp(s.WithOpName("out1"), s12); + Output out2 = ops::Exp(s.WithOpName("out2"), s23); + + GrapplerItem item; + item.fetch = {"out1", "out2"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto 
tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(2, tensors_expected.size()); + + GraphDef got; + ArithmeticOptimizer optimizer; + EnableOnlyExpm1(&optimizer); + OptimizeAndPrune(&optimizer, &item, &got); + auto tensors = EvaluateNodes(got, item.fetch); + EXPECT_EQ(2, tensors.size()); + + GraphDef want; + AddNode("x1", "Const", {}, {}, &want); + AddNode("x2", "Const", {}, {}, &want); + AddNode("x3", "Const", {}, {}, &want); + AddNode("s23", "Sub", {"x2", "x3"}, {}, &want); + AddNode("out1", "Expm1", + {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {}, + &want); + AddNode("out2", "Exp", {"s23"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 00f66c9bc1..bc717d5eeb 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -23,9 +23,7 @@ namespace grappler { namespace graph_utils { namespace { -class GraphUtilsTest : public ::testing::Test {}; - -TEST_F(GraphUtilsTest, AddScalarConstNodeBool) { +TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph; NodeDef* bool_node; TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node)); @@ -33,7 +31,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeBool) { EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); } -TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) { +TEST(GraphUtilsTest, AddScalarConstNodeDouble) { GraphDef graph; NodeDef* double_node; TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node)); @@ -41,7 +39,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) { EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); } -TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) { 
+TEST(GraphUtilsTest, AddScalarConstNodeFloat) { GraphDef graph; NodeDef* float_node; TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node)); @@ -49,7 +47,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) { EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); } -TEST_F(GraphUtilsTest, AddScalarConstNodeInt) { +TEST(GraphUtilsTest, AddScalarConstNodeInt) { GraphDef graph; NodeDef* int_node; TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node)); @@ -57,7 +55,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt) { EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); } -TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) { +TEST(GraphUtilsTest, AddScalarConstNodeInt64) { GraphDef graph; NodeDef* int64_node; TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node)); @@ -65,7 +63,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) { EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); } -TEST_F(GraphUtilsTest, AddScalarConstNodeString) { +TEST(GraphUtilsTest, AddScalarConstNodeString) { GraphDef graph; NodeDef* string_node; TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node)); @@ -73,7 +71,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeString) { EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } -TEST_F(GraphUtilsTest, Compare) { +TEST(GraphUtilsTest, Compare) { GraphDef graphA; GraphDef graphB; EXPECT_TRUE(Compare(graphA, graphB)); @@ -88,7 +86,7 @@ TEST_F(GraphUtilsTest, Compare) { EXPECT_TRUE(Compare(graphA, graphB)); } -TEST_F(GraphUtilsTest, ContainsNodeWithName) { +TEST(GraphUtilsTest, ContainsNodeWithName) { GraphDef graph; EXPECT_TRUE(!ContainsNodeWithName("A", graph)); @@ -100,7 +98,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithName) { EXPECT_TRUE(!ContainsNodeWithName("A", graph)); } -TEST_F(GraphUtilsTest, ContainsNodeWithOp) { +TEST(GraphUtilsTest, ContainsNodeWithOp) { GraphDef graph; EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); @@ 
-112,7 +110,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithOp) { EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); } -TEST_F(GraphUtilsTest, FindNodeWithName) { +TEST(GraphUtilsTest, FindNodeWithName) { GraphDef graph; EXPECT_EQ(FindNodeWithName("A", graph), -1); @@ -124,7 +122,7 @@ TEST_F(GraphUtilsTest, FindNodeWithName) { EXPECT_EQ(FindNodeWithName("A", graph), -1); } -TEST_F(GraphUtilsTest, FindNodeWithOp) { +TEST(GraphUtilsTest, FindNodeWithOp) { GraphDef graph; EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); @@ -136,7 +134,7 @@ TEST_F(GraphUtilsTest, FindNodeWithOp) { EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); } -TEST_F(GraphUtilsTest, SetUniqueName) { +TEST(GraphUtilsTest, SetUniqueName) { GraphDef graph; NodeDef* node1; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index d3710a4b5c..3e66d6412a 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -368,6 +368,7 @@ cc_library( cc_library( name = "queue_op", + srcs = ["queue_op.cc"], hdrs = ["queue_op.h"], deps = [ ":queue_base", @@ -1885,9 +1886,10 @@ cc_library( name = "fifo_queue", srcs = ["fifo_queue.cc"], hdrs = ["fifo_queue.h"], - visibility = ["//visibility:private"], + visibility = [":friends"], deps = [ ":queue_base", + ":queue_op", ":typed_queue", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3919,6 +3921,7 @@ tf_cc_test( cc_library( name = "sparse", deps = [ + ":deserialize_sparse_variant_op", ":serialize_sparse_op", ":sparse_add_grad_op", ":sparse_add_op", @@ -4073,6 +4076,15 @@ tf_kernel_library( ) tf_kernel_library( + name = "deserialize_sparse_variant_op", + prefix = "deserialize_sparse_variant_op", + deps = SPARSE_DEPS + [ + ":reshape_util", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_kernel_library( name = "sparse_tensors_map_ops", prefix = "sparse_tensors_map_ops", deps = SPARSE_DEPS, @@ -5083,6 +5095,7 @@ filegroup( "padding_fifo_queue.cc", "padding_fifo_queue_op.cc", "queue_base.cc", + "queue_op.cc", "queue_ops.cc", 
"random_op.cc", "reduction_ops_all.cc", diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index fe1a1ba5a3..a888422d49 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -297,7 +297,8 @@ class ZerosLikeOp : public OpKernel { errors::InvalidArgument("ZerosLike non-scalar Tensor with " "dtype=DT_VARIANT is not supported.")); const Variant& v = input.scalar<Variant>()(); - Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); + Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT, + TensorShape({})); Variant* out_v = &(out.scalar<Variant>()()); OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>( ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v)); diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index c17e9343ea..07cc91f9d5 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -40,9 +40,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES( ctx, window_size > 0, errors::InvalidArgument("Window size must be greater than zero.")); - OP_REQUIRES( - ctx, stride > 0, - errors::InvalidArgument("Stride must be greater than zero.")); + OP_REQUIRES(ctx, stride > 0, + errors::InvalidArgument("Stride must be greater than zero.")); if (stride == window_size) { LOG(WARNING) << "stride: " << stride << " is equal to window_size: " << window_size diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index 33a56b2eb5..b133cfab54 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -20,11 +20,25 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace { +static mutex* get_counters_map_lock() { + static mutex counters_map_lock(LINKER_INITIALIZED); + return &counters_map_lock; +} + +static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() { + static std::unordered_map<string, monitoring::Counter<1>*>* counters_map = + new std::unordered_map<string, monitoring::Counter<1>*>; + return counters_map; +} + class StatsAggregatorImpl : public StatsAggregator { public: StatsAggregatorImpl() {} @@ -61,6 +75,21 @@ class StatsAggregatorImpl : public StatsAggregator { } } + void IncrementCounter(const string& name, const string& label, + int64 val) override { + mutex_lock l(*get_counters_map_lock()); + auto counters_map = get_counters_map(); + if (counters_map->find(name) == counters_map->end()) { + counters_map->emplace( + name, monitoring::Counter<1>::New( + /*streamz name*/ "/tensorflow/" + name, + /*streamz description*/ + name + " generated or consumed by the component.", + /*streamz label name*/ "component_descriptor")); + } + counters_map->at(name)->GetCell(label)->IncrementBy(val); + } + private: mutex mu_; std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 3e0a6ae049..a537e7e68f 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -316,10 +316,14 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel { // changes to parse_example() where it returns stats as well. 
for (int i = 0; i < record_t.size(); ++i) { if (example.ParseFromString(record_t(i))) { + stats_aggregator->IncrementCounter("examples_count", "trainer", + 1); AddStatsFeatures(example, stats_aggregator); } else { SequenceExample sequence_example; if (sequence_example.ParseFromString(record_t(i))) { + stats_aggregator->IncrementCounter("sequence_examples_count", + "trainer", 1); AddStatsFeatures(sequence_example, stats_aggregator); } } @@ -360,8 +364,11 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel { int feature_values_list_size_sum = 0; for (const auto& feature : example.features().feature()) { + stats_aggregator->IncrementCounter("features_count", "trainer", 1); feature_values_list_size_sum += AddStatsFeatureValues(feature.second); } + stats_aggregator->IncrementCounter("feature_values_count", "trainer", + feature_values_list_size_sum); stats_aggregator->AddToHistogram( strings::StrCat(dataset()->tag_, ":feature-values"), {static_cast<double>(feature_values_list_size_sum)}); @@ -378,16 +385,20 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel { int feature_values_list_size_sum = 0; for (const auto& feature : example.context().feature()) { + stats_aggregator->IncrementCounter("features_count", "trainer", 1); feature_values_list_size_sum += AddStatsFeatureValues(feature.second); } for (const auto& feature_list : example.feature_lists().feature_list()) { + stats_aggregator->IncrementCounter("feature_lists_count", "reainer", + 1); for (const auto& feature : feature_list.second.feature()) { feature_values_list_size_sum += AddStatsFeatureValues(feature); } } - + stats_aggregator->IncrementCounter("feature_values_count", "trainer", + feature_values_list_size_sum); stats_aggregator->AddToHistogram( strings::StrCat(dataset()->tag_, ":feature-values"), {static_cast<double>(feature_values_list_size_sum)}); diff --git a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc new file mode 
100644 index 0000000000..fce3029e4e --- /dev/null +++ b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc @@ -0,0 +1,372 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +namespace { + +class DeserializeSparseOp : public OpKernel { + public: + explicit DeserializeSparseOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + + OP_REQUIRES( + context, input.dims() > 0, + errors::InvalidArgument("Serialized sparse should have non-zero rank ", + input.shape().DebugString())); + OP_REQUIRES(context, input.shape().dim_size(input.dims() - 1) == 3, + errors::InvalidArgument( + "Serialized sparse should have 3 as the last dimension ", + input.shape().DebugString())); + + // `input_dims_to_stack` is the number of dimensions that will be added to + // each of the elements before they are 
concatenated into the output. + const int64 input_dims_to_stack = input.dims() - 1; + int num_sparse_tensors = 1; + for (int i = 0; i < input_dims_to_stack; ++i) { + num_sparse_tensors *= input.shape().dim_size(i); + } + + if (num_sparse_tensors == 1 && input_dims_to_stack == 0) { + // Special case with a single sparse tensor, and no dimensions to add + // to the output indices. We can return the boxed tensors directly (after + // validating them). + const Tensor* output_indices; + const Tensor* output_values; + const Tensor* output_shape; + const auto& input_as_vec = input.vec<Variant>(); + int64 total_non_zeros; + OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape( + input_as_vec(1), input_as_vec(2), 0, + &output_shape, &total_non_zeros)); + OP_REQUIRES_OK(context, GetAndValidateSparseTensorIndicesAndValues( + input_as_vec(0), input_as_vec(1), 0, + output_shape->NumElements(), &output_indices, + &output_values)); + context->set_output(0, *output_indices); + context->set_output(1, *output_values); + context->set_output(2, *output_shape); + return; + } + + OP_REQUIRES( + context, num_sparse_tensors > 0, + errors::InvalidArgument( + "Serialized sparse should have at least 1 serialized tensor, " + "but has a zero dimension ", + input.shape().DebugString())); + + const auto& input_as_matrix = input.flat_inner_dims<Variant, 2>(); + + // Compute the output "dense shape" of and number of non-zero elements in + // the stacked sparse tensors. Given an input of shape (S_0, ..., + // S_{input_dims_to_stack-1}, 3), and an element of dense shape (E_0, ... + // E_n), the output dense shape will be (S_0, ..., + // S_{input_dims_to_stack-1}, E_0, ..., E_n). + Tensor* output_shape; + int64 total_non_zeros = 0; + + // Allocate and build the initial output shape based on the element shape of + // the 0th sparse tensor in the input. 
+ // + // NOTE(mrry): We define `element_shape` as a `const Tensor*` rather than a + // `Tensor` to avoid the overhead of allocating and deallocating a `Tensor` + // on the stack. While the per-`Tensor` cost is small, this op can unbox a + // large number of tensors (3 per batch element) and these fixed overheads + // dominate when the number of non-zeros per element is small. + const Tensor* element_shape; + OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape( + input_as_matrix(0, 1), input_as_matrix(0, 2), 0, + &element_shape, &total_non_zeros)); + OP_REQUIRES_OK(context, + context->allocate_output( + 2, {input_dims_to_stack + element_shape->NumElements()}, + &output_shape)); + const auto element_shape_vec = element_shape->vec<int64>(); + auto output_shape_vec = output_shape->vec<int64>(); + output_shape_vec(0) = num_sparse_tensors; + for (int64 j = 0; j < input_dims_to_stack; ++j) { + output_shape_vec(j) = input.dim_size(j); + } + for (int64 j = 0; j < element_shape->NumElements(); ++j) { + output_shape_vec(j + input_dims_to_stack) = element_shape_vec(j); + } + + // Accumulate the number of non-zero elements from the remaining sparse + // tensors, and validate that they have compatible dense shapes. + // + // NOTE(mrry): For compatibility with the implementations of + // DeserializeManySparse, and many ops that generate SparseTensors to batch + // that do not have a fixed dense_shape (e.g. `tf.parse_single_example()`), + // we compute the maximum in each dimension to find the smallest dense_shape + // that bounds all of the input SparseTensors. 
+ for (int i = 1; i < num_sparse_tensors; ++i) { + int64 num_non_zeros; + OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape( + input_as_matrix(i, 1), input_as_matrix(i, 2), + i, &element_shape, &num_non_zeros)); + total_non_zeros += num_non_zeros; + OP_REQUIRES( + context, + output_shape->NumElements() - input_dims_to_stack == + element_shape->NumElements(), + errors::InvalidArgument( + "Inconsistent shape across SparseTensors: rank prior to " + "SparseTensor[", + i, "] was: ", output_shape->NumElements() - input_dims_to_stack, + " but rank of SparseTensor[", i, + "] is: ", element_shape->NumElements())); + const auto element_shape_vec = element_shape->vec<int64>(); + for (int j = 0; j < element_shape->NumElements(); ++j) { + output_shape_vec(j + input_dims_to_stack) = std::max( + output_shape_vec(j + input_dims_to_stack), element_shape_vec(j)); + } + } + + // Compute the output "indices" matrix and "values" vector. + Tensor* output_indices; + Tensor* output_values; + + const int output_rank = output_shape->NumElements(); + OP_REQUIRES_OK(context, + context->allocate_output( + 0, {static_cast<int64>(total_non_zeros), output_rank}, + &output_indices)); + OP_REQUIRES_OK( + context, context->allocate_output( + 1, {static_cast<int64>(total_non_zeros)}, &output_values)); + + // The bulk of the work in this method involves building the output indices + // in a tight loop. For cache friendliness, we generate the indices in the + // order that they will be laid out in memory. We use raw pointers instead + // of Eigen element/slice indexing methods, to access the underlying index + // buffer to minimize the amount of work in that tight loop. 
+ int64* output_indices_data = output_indices->matrix<int64>().data(); + size_t current_row = 0; + + for (int i = 0; i < num_sparse_tensors; ++i) { + const Tensor* element_indices; + const Tensor* element_values; + OP_REQUIRES_OK(context, this->GetAndValidateSparseTensorIndicesAndValues( + input_as_matrix(i, 0), input_as_matrix(i, 1), + i, output_rank - input_dims_to_stack, + &element_indices, &element_values)); + + const size_t num_index_rows = element_values->NumElements(); + + // An empty sparse tensor in the input will generate no data + // in the output. We short-circuit the rest of the iteration to avoid + // triggering assertions in the Eigen when manipulating empty tensors (or + // slices of tensors). + if (num_index_rows == 0) continue; + + const size_t start_row = current_row; + const size_t next_start_row = current_row + num_index_rows; + + // NOTE(mrry): If the element is a scalar SparseTensor, + // `element_indices` will be an empty tensor, and this pointer will not + // be valid. However, we will not dereference the pointer in that case, + // because `input_dims_to_stack == output_rank`. + const int64* element_indices_data = + element_indices->matrix<int64>().data(); + + // Build the submatrix of `output_indices` for the i^th sparse tensor + // in the input. + // + // Each row of `output_indices` comprises `input_dims_to_stack` indices + // based on the position of the i^th sparse tensor in the input tensor, + // followed by the indices from the corresponding row in + // `element_indices`. + if (input_dims_to_stack == 1 && output_rank == 2) { + // We specialize this case because the compiler can generate + // more efficient code when the number of indices for each element is + // known statically. 
Since the most common use of this op is to + // serialize batches of SparseTensors, and the most common source of + // SparseTensors is the `tf.parse_single_example()` op, which generates + // 1-D SparseTensors, we statically unroll the loop for the rank 2 + // output case. + for (; current_row < next_start_row; ++current_row) { + *output_indices_data++ = i; + *output_indices_data++ = *element_indices_data++; + } + } else { + // `sparse_tensor_index` is the tuple of indices that correspond to + // mapping the flat element index (`i`) back onto the stacked + // coordinates implied by the position of the i^th sparse tensor in the + // input tensor. + // + // We build `sparse_tensor_index` in reverse (innermost/minor dimension + // to outermost/major dimension). The `cumulative_product` represents + // the size of the inner subtensor for which `sparse_tensor_index` has + // already been built. + gtl::InlinedVector<int64, 4> sparse_tensor_index(input_dims_to_stack); + int cumulative_product = 1; + for (size_t j = 0; j < sparse_tensor_index.size(); ++j) { + size_t reverse_index = sparse_tensor_index.size() - j - 1; + sparse_tensor_index[reverse_index] = + (i / cumulative_product) % input.dim_size(reverse_index); + cumulative_product *= input.dim_size(reverse_index); + } + for (; current_row < next_start_row; ++current_row) { + for (int64 sparse_tensor_index_component : sparse_tensor_index) { + *output_indices_data++ = sparse_tensor_index_component; + } + for (size_t k = input_dims_to_stack; k < output_rank; ++k) { + *output_indices_data++ = *element_indices_data++; + } + } + } + + // Build the subvector of `output_values` for the i^th sparse tensor + // in the input. + // + // NOTE(mrry): There is a potential optimization here where we use a T* + // to represent the current position in `output_values`, but it would + // require some rejigging of the template parameters. 
+ // NOTE(mrry): Another potential optimization: if we know that this + // operation consumes its input, we could std::move non-primitive elements + // into the output and avoid a copy. + Eigen::DSizes<Eigen::DenseIndex, 1> values_start(start_row); + Eigen::DSizes<Eigen::DenseIndex, 1> values_sizes(num_index_rows); + +#define HANDLE_TYPE(T) \ + case DataTypeToEnum<T>::value: { \ + output_values->vec<T>().slice(values_start, values_sizes) = \ + element_values->vec<T>(); \ + break; \ + } + switch (dtype_) { + TF_CALL_ALL_TYPES(HANDLE_TYPE); + TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + default: + OP_REQUIRES_OK( + context, errors::Unimplemented( + "DeserializeSparse Unhandled data type: ", dtype_)); + } + } + } + + private: + Status GetAndValidateSparseTensorShape(const Variant& serialized_values, + const Variant& serialized_shape, + int index, const Tensor** output_shape, + int64* output_num_non_zeros) { + // Deserialize and validate the shape. + *output_shape = serialized_shape.get<Tensor>(); + if (*output_shape == nullptr) { + return errors::InvalidArgument( + "Could not get a tensor from serialized_sparse[", index, ", 2]"); + } + if ((*output_shape)->dtype() != DT_INT64) { + return errors::InvalidArgument( + "Expected serialized_sparse[", index, + ", 2] to be a vector of DT_INT64 but received dtype ", + DataTypeString((*output_shape)->dtype())); + } + if (!TensorShapeUtils::IsVector((*output_shape)->shape())) { + return errors::InvalidArgument( + "Expected serialized_sparse[", index, + ", 2] to be a shape vector but its shape is ", + (*output_shape)->shape().DebugString()); + } + *output_num_non_zeros = serialized_values.get<Tensor>()->NumElements(); + return Status::OK(); + } + + Status GetAndValidateSparseTensorIndicesAndValues( + const Variant& serialized_indices, const Variant& serialized_values, + int index, int expected_rank, const Tensor** output_indices, + const Tensor** output_values) { + // Deserialize and validate the indices. 
+ *output_indices = serialized_indices.get<Tensor>(); + if (*output_indices == nullptr) { + return errors::InvalidArgument( + "Could not get a tensor from serialized_sparse[", index, ", 0]"); + } + if ((*output_indices)->dtype() != DT_INT64) { + return errors::InvalidArgument( + "Expected serialized_sparse[", index, + ", 0] to be a matrix of DT_INT64 but received dtype ", + DataTypeString((*output_indices)->dtype())); + } + if (!TensorShapeUtils::IsMatrix((*output_indices)->shape())) { + return errors::InvalidArgument( + "Expected serialized_sparse[", index, + ", 0] to represent an index matrix but received shape ", + (*output_indices)->shape().DebugString()); + } + int64 num_entries = (*output_indices)->dim_size(0); + int rank = (*output_indices)->dim_size(1); + if (rank != expected_rank) { + return errors::InvalidArgument( + "Expected column counts of SparseTensor[", index, + "].indices to match size of SparseTensor[", index, + "].shape but they do not: ", rank, " vs. ", expected_rank); + } + + // Deserialize and validate the values. + *output_values = serialized_values.get<Tensor>(); + if (*output_values == nullptr) { + return errors::InvalidArgument( + "Could not get a tensor from serialized_sparse[", index, ", 1]"); + } + if (!TensorShapeUtils::IsVector((*output_values)->shape())) { + return errors::InvalidArgument( + "Expected serialized_sparse[", index, + ", 1] to represent a values vector but received shape ", + (*output_values)->shape().DebugString()); + } + if (dtype_ != (*output_values)->dtype()) { + return errors::InvalidArgument( + "Requested SparseTensor of type ", DataTypeString(dtype_), + " but SparseTensor[", index, + "].values.dtype() == ", DataTypeString((*output_values)->dtype())); + } + if (num_entries != (*output_values)->dim_size(0)) { + return errors::InvalidArgument( + "Expected row counts of SparseTensor[", index, + "].indices and SparseTensor[", index, + "].values to match but they do not: ", num_entries, " vs. 
", + (*output_values)->dim_size(0)); + } + + return Status::OK(); + } + + DataType dtype_; +}; + +REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") + .Device(DEVICE_CPU) + .TypeConstraint<Variant>("Tserialized"), + DeserializeSparseOp) + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index a23478af5b..d6e859f1aa 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { return Status::OK(); } +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context) + : TypedQueueOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); +} + +Status FIFOQueueOp::CreateResource(QueueInterface** ret) { + FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, + component_shapes_, cinfo_.name()); + return CreateTypedQueue(queue, ret); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index f01d70924d..697ee81c39 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_ -#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_ +#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ #include <deque> #include <vector> @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/typed_queue.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > { TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue); }; +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class FIFOQueueOp : public TypedQueueOp { + public: + explicit FIFOQueueOp(OpKernelConstruction* context); + + private: + Status CreateResource(QueueInterface** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + std::vector<TensorShape> component_shapes_; + TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); +}; + } // namespace tensorflow -#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_ +#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc index b35bdbb2f0..80869768f1 100644 --- a/tensorflow/core/kernels/fifo_queue_op.cc +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -13,50 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/data_flow_ops.cc. 
- -#include <deque> -#include <vector> - #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fifo_queue.h" -#include "tensorflow/core/kernels/queue_base.h" -#include "tensorflow/core/kernels/queue_op.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { -// Defines a FIFOQueueOp, which produces a Queue (specifically, one -// backed by FIFOQueue) that persists across different graph -// executions, and sessions. Running this op produces a single-element -// tensor of handles to Queues in the corresponding device. 
-class FIFOQueueOp : public TypedQueueOp { - public: - explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); - } - - private: - Status CreateResource(QueueInterface** ret) override - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, - component_shapes_, cinfo_.name()); - return CreateTypedQueue(queue, ret); - } - - std::vector<TensorShape> component_shapes_; - TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); -}; - REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp); REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp); diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index cede0b9dd6..1d0edb10b3 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -70,23 +70,25 @@ struct MklConvFwdParams { memory::dims padding_left; memory::dims padding_right; - MklConvFwdParams(memory::dims src_dims, - memory::dims filter_dims, memory::dims bias_dims, - memory::dims dst_dims, memory::dims strides, - memory::dims dilations, memory::dims padding_left, - memory::dims padding_right) : - src_dims(src_dims), filter_dims(filter_dims), - bias_dims(bias_dims), dst_dims(dst_dims), - strides(strides), dilations(dilations), - padding_left(padding_left), padding_right(padding_right) { - } + MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims, + memory::dims bias_dims, memory::dims dst_dims, + memory::dims strides, memory::dims dilations, + memory::dims padding_left, memory::dims padding_right) + : src_dims(src_dims), + filter_dims(filter_dims), + bias_dims(bias_dims), + dst_dims(dst_dims), + strides(strides), + dilations(dilations), + padding_left(padding_left), + padding_right(padding_right) {} }; template <typename T> -class MklConv2DFwdPrimitive: public MklPrimitive { +class MklConv2DFwdPrimitive : 
public MklPrimitive { public: - explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) : - cpu_engine_(engine::cpu, 0) { + explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) + : cpu_engine_(engine::cpu, 0) { context_.fwd_stream.reset(new stream(stream::kind::eager)); // create conv primitive if (context_.conv_fwd == nullptr) { @@ -101,8 +103,8 @@ class MklConv2DFwdPrimitive: public MklPrimitive { // filter_data: input data buffer of filter (weights) // bias_data: input data buffer of bias // dst_data: output data buffer of dst - void Execute(const T* src_data, const T* filter_data, - const T* bias_data, const T* dst_data) { + void Execute(const T* src_data, const T* filter_data, const T* bias_data, + const T* dst_data) { context_.src_mem->set_data_handle( static_cast<void*>(const_cast<T*>(src_data))); context_.filter_mem->set_data_handle( @@ -126,8 +128,7 @@ class MklConv2DFwdPrimitive: public MklPrimitive { // src_data: input data buffer of src // filter_data: input data buffer of filter (weights) // dst_data: output data buffer of dst - void Execute(const T* src_data, const T* filter_data, - const T* dst_data) { + void Execute(const T* src_data, const T* filter_data, const T* dst_data) { context_.src_mem->set_data_handle( static_cast<void*>(const_cast<T*>(src_data))); context_.filter_mem->set_data_handle( @@ -142,13 +143,9 @@ class MklConv2DFwdPrimitive: public MklPrimitive { context_.dst_mem->set_data_handle(DummyData); } - memory::format GetSrcMemoryFormat() const { - return context_.src_fmt; - } + memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } - memory::format GetFilterMemoryFormat() const { - return context_.filter_fmt; - } + memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; } std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetPrimitiveDesc() const { @@ -184,43 +181,50 @@ class MklConv2DFwdPrimitive: public MklPrimitive { std::shared_ptr<mkldnn::stream> fwd_stream; 
std::vector<mkldnn::primitive> fwd_primitives; - ConvFwdContext() : - src_fmt(memory::format::any), filter_fmt(memory::format::any), - src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr), - dst_mem(nullptr), fwd_desc(nullptr), - src_md(nullptr), filter_md(nullptr), bias_md(nullptr), - fwd_pd(nullptr), conv_fwd(nullptr), fwd_stream(nullptr) { - } + ConvFwdContext() + : src_fmt(memory::format::any), + filter_fmt(memory::format::any), + src_mem(nullptr), + filter_mem(nullptr), + bias_mem(nullptr), + dst_mem(nullptr), + fwd_desc(nullptr), + src_md(nullptr), + filter_md(nullptr), + bias_md(nullptr), + fwd_pd(nullptr), + conv_fwd(nullptr), + fwd_stream(nullptr) {} }; void Setup(const MklConvFwdParams& convFwdDims) { // create memory descriptors for convolution data w/ no specified format - context_.src_md.reset(new memory::desc({convFwdDims.src_dims}, - MklDnnType<T>(), memory::format::any)); + context_.src_md.reset(new memory::desc( + {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any)); - context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims}, - MklDnnType<T>(), memory::format::any)); + context_.filter_md.reset(new memory::desc( + {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any)); - context_.dst_md.reset(new memory::desc({convFwdDims.dst_dims}, - MklDnnType<T>(), memory::format::any)); + context_.dst_md.reset(new memory::desc( + {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any)); if (!convFwdDims.bias_dims.empty()) - context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, - MklDnnType<T>(), memory::format::any)); + context_.bias_md.reset(new memory::desc( + {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any)); // create a convolution if (!convFwdDims.bias_dims.empty()) { - context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward, - convolution_direct, *context_.src_md, *context_.filter_md, - *context_.bias_md, *context_.dst_md, + context_.fwd_desc.reset(new 
convolution_forward::desc( + prop_kind::forward, convolution_direct, *context_.src_md, + *context_.filter_md, *context_.bias_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); } else { - context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward, - convolution_direct, *context_.src_md, *context_.filter_md, - *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, - convFwdDims.padding_left, convFwdDims.padding_right, - padding_kind::zero)); + context_.fwd_desc.reset(new convolution_forward::desc( + prop_kind::forward, convolution_direct, *context_.src_md, + *context_.filter_md, *context_.dst_md, convFwdDims.strides, + convFwdDims.dilations, convFwdDims.padding_left, + convFwdDims.padding_right, padding_kind::zero)); } context_.fwd_pd.reset(new convolution_forward::primitive_desc( @@ -234,24 +238,26 @@ class MklConv2DFwdPrimitive: public MklPrimitive { context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); // create memory primitive based on dummy data - context_.src_mem.reset(new memory( - context_.fwd_pd.get()->src_primitive_desc(), DummyData)); - context_.filter_mem.reset(new memory( - context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); - context_.dst_mem.reset(new memory( - context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); + context_.filter_mem.reset( + new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); // create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory({{{convFwdDims.bias_dims}, - MklDnnType<T>(), memory::format::x}, cpu_engine_}, DummyData)); - context_.conv_fwd.reset(new convolution_forward( - *context_.fwd_pd, *context_.src_mem, 
*context_.filter_mem, - *context_.bias_mem, *context_.dst_mem)); + context_.bias_mem.reset(new memory( + {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x}, + cpu_engine_}, + DummyData)); + context_.conv_fwd.reset(new convolution_forward( + *context_.fwd_pd, *context_.src_mem, *context_.filter_mem, + *context_.bias_mem, *context_.dst_mem)); } else { - context_.conv_fwd.reset(new convolution_forward( - *context_.fwd_pd, *context_.src_mem, - *context_.filter_mem, *context_.dst_mem)); + context_.conv_fwd.reset( + new convolution_forward(*context_.fwd_pd, *context_.src_mem, + *context_.filter_mem, *context_.dst_mem)); } context_.fwd_primitives.push_back(*context_.conv_fwd); @@ -266,19 +272,19 @@ template <typename T> class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { public: static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) { - MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; - - // try to find a suitable one in pool - conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*> ( - MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd( - convFwdDims)); - - if (conv2d_fwd == nullptr) { - conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims); - MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd( - convFwdDims, conv2d_fwd); - } - return conv2d_fwd; + MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; + + // try to find a suitable one in pool + conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>( + MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd( + convFwdDims)); + + if (conv2d_fwd == nullptr) { + conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims); + MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims, + conv2d_fwd); + } + return conv2d_fwd; } private: @@ -312,7 +318,7 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { return this->GetOp(key); } - void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive *op) { + void SetConv2DFwd(const 
MklConvFwdParams& convFwdDims, MklPrimitive* op) { std::string key = CreateKey(convFwdDims); this->SetOp(key, op); } @@ -865,22 +871,24 @@ class MklConv2DOp : public OpKernel { dilations[kDilationW] -= 1; // get a conv2d fwd from primitive pool - MklConv2DFwdPrimitive<T> *conv2d_fwd = nullptr; + MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; if (biasEnabled) { memory::dims bias_dims = {}; conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims, - dst_dims_mkl_order, strides, dilations, padding_left, padding_right); + dst_dims_mkl_order, strides, dilations, + padding_left, padding_right); conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); } else { MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS, - dst_dims_mkl_order, strides, dilations, padding_left, padding_right); + dst_dims_mkl_order, strides, dilations, + padding_left, padding_right); conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); } // allocate output tensors output_tensor and filter_out_tensor - std::shared_ptr<mkldnn::convolution_forward::primitive_desc> - conv_fwd_pd = conv2d_fwd->GetPrimitiveDesc(); + std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd = + conv2d_fwd->GetPrimitiveDesc(); AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt, &dst_tensor); Tensor* filter_out_tensor = nullptr; @@ -892,26 +900,24 @@ class MklConv2DOp : public OpKernel { // check whether src/filter need reorder std::vector<primitive> net; - T *src_data = nullptr; + T* src_data = nullptr; if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) { src.SetUsrMem(src_md, &src_tensor); - src.CheckReorderToOpMem( - conv_fwd_pd.get()->src_primitive_desc(), &net); + src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc(), &net); src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); } else { - src_data = static_cast<T*>(const_cast<T*>( - src_tensor.flat<T>().data())); + 
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); } - T *filter_data = nullptr; + T* filter_data = nullptr; if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) { filter.SetUsrMem(filter_md, &filter_tensor); - filter.CheckReorderToOpMem( - conv_fwd_pd.get()->weights_primitive_desc(), - filter.GetTensorBuffer(filter_out_tensor), &net); + filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(), + filter.GetTensorBuffer(filter_out_tensor), + &net); filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle()); } else { - filter_data = static_cast<T*>(const_cast<T*>( - filter_tensor.flat<T>().data())); + filter_data = + static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data())); } stream(stream::kind::eager).submit(net).wait(); diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc new file mode 100644 index 0000000000..53f431ef3c --- /dev/null +++ b/tensorflow/core/kernels/queue_op.cc @@ -0,0 +1,367 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/queue_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); + if (capacity_ < 0) { + capacity_ = QueueBase::kUnbounded; + } + OP_REQUIRES_OK(context, + context->GetAttr("component_types", &component_types_)); +} + +void QueueOp::Compute(OpKernelContext* context) { + ResourceOpKernel<QueueInterface>::Compute(context); + mutex_lock l(mu_); + if (resource_ && context->track_allocations()) { + context->record_persistent_memory_allocation(resource_->MemoryUsed()); + } +} + +Status QueueOp::VerifyResource(QueueInterface* queue) { + return queue->MatchesNodeDef(def()); +} + + +QueueOpKernel::QueueOpKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + +void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) { + QueueInterface* queue; + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), + callback); + } + ComputeAsync(ctx, queue, [callback, queue]() { + queue->Unref(); + callback(); + }); +} + +QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); + // TODO(keveman): Enable timeout. 
+ OP_REQUIRES(context, timeout_ == -1, + errors::InvalidArgument("Timeout not supported yet.")); +} + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +EnqueueOp::EnqueueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); + queue->TryEnqueue(tuple, ctx, callback); +} + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. 
+EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); + queue->TryEnqueueMany(tuple, ctx, callback); +} + +EnqueueManyOp::~EnqueueManyOp() = default; + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. 
+DequeueOp::DequeueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), + callback); + } + + queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} + +DequeueOp::~DequeueOp() = default; + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. 
+DequeueManyOp::DequeueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat<int32>()(0); + + OP_REQUIRES_ASYNC(ctx, num_elements >= 0, + errors::InvalidArgument("DequeueManyOp requested ", + num_elements, " < 0 elements"), + callback); + + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } + + queue->TryDequeueMany( + num_elements, ctx, false /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} + +DequeueManyOp::~DequeueManyOp() = default; + +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. 
+// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + +void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat<int32>()(0); + + OP_REQUIRES_ASYNC(ctx, num_elements >= 0, + errors::InvalidArgument("DequeueUpToOp requested ", + num_elements, " < 0 elements"), + callback); + + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } + + queue->TryDequeueMany( + num_elements, ctx, true /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); +} 
+ +DequeueUpToOp::~DequeueUpToOp() = default; + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +QueueCloseOp::QueueCloseOp(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", + &cancel_pending_enqueues_)); +} + +void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + queue->Close(ctx, cancel_pending_enqueues_, callback); +} + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. +QueueSizeOp::QueueSizeOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + +void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + Tensor* Tqueue_size = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); + Tqueue_size->flat<int32>().setConstant(queue->size()); + callback(); +} + +QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + +void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) { + Tensor* Tqueue_is_closed = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); + Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed()); + callback(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index 6c19f9841c..2efd838a5f 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -13,12 +13,13 @@ See the License for the specific 
language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_ -#define TENSORFLOW_KERNELS_QUEUE_OP_H_ +#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ #include <deque> #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -32,22 +33,9 @@ namespace tensorflow { // Defines a QueueOp, an abstract class for Queue construction ops. class QueueOp : public ResourceOpKernel<QueueInterface> { public: - QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - if (capacity_ < 0) { - capacity_ = QueueBase::kUnbounded; - } - OP_REQUIRES_OK(context, - context->GetAttr("component_types", &component_types_)); - } + QueueOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override { - ResourceOpKernel<QueueInterface>::Compute(context); - mutex_lock l(mu_); - if (resource_ && context->track_allocations()) { - context->record_persistent_memory_allocation(resource_->MemoryUsed()); - } - } + void Compute(OpKernelContext* context) override; protected: // Variables accessible by subclasses @@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> { DataTypeVector component_types_; private: - Status VerifyResource(QueueInterface* queue) override { - return queue->MatchesNodeDef(def()); - } + Status VerifyResource(QueueInterface* queue) override; }; class TypedQueueOp : public QueueOp { @@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp { } }; +// Queue manipulator kernels + +class QueueOpKernel : public AsyncOpKernel { + public: + explicit QueueOpKernel(OpKernelConstruction* context); + + void 
ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final; + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) = 0; +}; + +class QueueAccessOpKernel : public QueueOpKernel { + public: + explicit QueueAccessOpKernel(OpKernelConstruction* context); + + protected: + int64 timeout_; +}; + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +class EnqueueOp : public QueueAccessOpKernel { + public: + explicit EnqueueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); +}; + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. +class EnqueueManyOp : public QueueAccessOpKernel { + public: + explicit EnqueueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~EnqueueManyOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); +}; + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. 
The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. +class DequeueOp : public QueueAccessOpKernel { + public: + explicit DequeueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); +}; + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +class DequeueManyOp : public QueueAccessOpKernel { + public: + explicit DequeueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueManyOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); +}; + +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. 
+// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +class DequeueUpToOp : public QueueAccessOpKernel { + public: + explicit DequeueUpToOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueUpToOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); +}; + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +class QueueCloseOp : public QueueOpKernel { + public: + explicit QueueCloseOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + bool cancel_pending_enqueues_; + TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); +}; + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. 
+class QueueSizeOp : public QueueOpKernel { + public: + explicit QueueSizeOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); +}; + +class QueueIsClosedOp : public QueueOpKernel { + public: + explicit QueueIsClosedOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp); +}; + } // namespace tensorflow -#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_ +#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index 46a02854d7..c4d404259b 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -13,437 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/data_flow_ops.cc. 
- #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -class QueueOpKernel : public AsyncOpKernel { - public: - explicit QueueOpKernel(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { - QueueInterface* queue; - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), - callback); - } - ComputeAsync(ctx, queue, [callback, queue]() { - queue->Unref(); - callback(); - }); - } - - protected: - virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) = 0; -}; - -class QueueAccessOpKernel : public QueueOpKernel { - public: - explicit QueueAccessOpKernel(OpKernelConstruction* context) - : QueueOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); - // TODO(keveman): Enable timeout. - OP_REQUIRES(context, timeout_ == -1, - errors::InvalidArgument("Timeout not supported yet.")); - } - - protected: - int64 timeout_; -}; - -// Defines an EnqueueOp, the execution of which enqueues a tuple of -// tensors in the given Queue. -// -// The op has 1 + k inputs, where k is the number of components in the -// tuples stored in the given Queue: -// - Input 0: queue handle. -// - Input 1: 0th element of the tuple. -// - ... -// - Input (1+k): kth element of the tuple. 
-class EnqueueOp : public QueueAccessOpKernel { - public: - explicit EnqueueOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - DataTypeVector expected_inputs; - if (ctx->input_dtype(0) == DT_RESOURCE) { - expected_inputs.push_back(DT_RESOURCE); - } else { - expected_inputs.push_back(DT_STRING_REF); - } - for (DataType dt : queue->component_dtypes()) { - expected_inputs.push_back(dt); - } - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), - callback); - - QueueInterface::Tuple tuple; - OpInputList components; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), - callback); - for (const Tensor& Tcomponent : components) { - tuple.push_back(Tcomponent); - } - - OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); - queue->TryEnqueue(tuple, ctx, callback); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp); -// Defines an EnqueueManyOp, the execution of which slices each -// component of a tuple of tensors along the 0th dimension, and -// enqueues tuples of slices in the given Queue. -// -// The op has 1 + k inputs, where k is the number of components in the -// tuples stored in the given Queue: -// - Input 0: queue handle. -// - Input 1: 0th element of the tuple. -// - ... -// - Input (1+k): kth element of the tuple. -// -// N.B. All tuple components must have the same size in the 0th -// dimension. 
-class EnqueueManyOp : public QueueAccessOpKernel { - public: - explicit EnqueueManyOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - DataTypeVector expected_inputs; - if (ctx->input_dtype(0) == DT_RESOURCE) { - expected_inputs.push_back(DT_RESOURCE); - } else { - expected_inputs.push_back(DT_STRING_REF); - } - for (DataType dt : queue->component_dtypes()) { - expected_inputs.push_back(dt); - } - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), - callback); - - QueueInterface::Tuple tuple; - OpInputList components; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), - callback); - for (const Tensor& Tcomponent : components) { - tuple.push_back(Tcomponent); - } - - OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); - queue->TryEnqueueMany(tuple, ctx, callback); - } - - ~EnqueueManyOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), EnqueueManyOp); REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU), EnqueueManyOp); -// Defines a DequeueOp, the execution of which dequeues a tuple of -// tensors from the given Queue. -// -// The op has one input, which is the handle of the appropriate -// Queue. The op has k outputs, where k is the number of components in -// the tuples stored in the given Queue, and output i is the ith -// component of the dequeued tuple. 
-class DequeueOp : public QueueAccessOpKernel { - public: - explicit DequeueOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), - callback); - } - - queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp); -// Defines a DequeueManyOp, the execution of which concatenates the -// requested number of elements from the given Queue along the 0th -// dimension, and emits the result as a single tuple of tensors. -// -// The op has two inputs: -// - Input 0: the handle to a queue. -// - Input 1: the number of elements to dequeue. -// -// The op has k outputs, where k is the number of components in the -// tuples stored in the given Queue, and output i is the ith component -// of the dequeued tuple. 
-class DequeueManyOp : public QueueAccessOpKernel { - public: - explicit DequeueManyOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - const Tensor& Tnum_elements = ctx->input(1); - int32 num_elements = Tnum_elements.flat<int32>()(0); - - OP_REQUIRES_ASYNC(ctx, num_elements >= 0, - errors::InvalidArgument("DequeueManyOp requested ", - num_elements, " < 0 elements"), - callback); - - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_RESOURCE, DT_INT32}, - queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); - } - - queue->TryDequeueMany( - num_elements, ctx, false /* allow_small_batch */, - [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), - callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueManyOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU), DequeueManyOp); -// Defines a DequeueUpToOp, the execution of which concatenates the -// requested number of elements from the given Queue along the 0th -// dimension, and emits the result as a single tuple of tensors. -// -// The difference between this op and DequeueMany is the handling when -// the Queue is closed. 
While the DequeueMany op will return if there -// an error when there are less than num_elements elements left in the -// closed queue, this op will return between 1 and -// min(num_elements, elements_remaining_in_queue), and will not block. -// If there are no elements left, then the standard DequeueMany error -// is returned. -// -// This op only works if the underlying Queue implementation accepts -// the allow_small_batch = true parameter to TryDequeueMany. -// If it does not, an errors::Unimplemented exception is returned. -// -// The op has two inputs: -// - Input 0: the handle to a queue. -// - Input 1: the number of elements to dequeue. -// -// The op has k outputs, where k is the number of components in the -// tuples stored in the given Queue, and output i is the ith component -// of the dequeued tuple. -// -// The op has one attribute: allow_small_batch. If the Queue supports -// it, setting this to true causes the queue to return smaller -// (possibly zero length) batches when it is closed, up to however -// many elements are available when the op executes. In this case, -// the Queue does not block when closed. 
-class DequeueUpToOp : public QueueAccessOpKernel { - public: - explicit DequeueUpToOp(OpKernelConstruction* context) - : QueueAccessOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - const Tensor& Tnum_elements = ctx->input(1); - int32 num_elements = Tnum_elements.flat<int32>()(0); - - OP_REQUIRES_ASYNC(ctx, num_elements >= 0, - errors::InvalidArgument("DequeueUpToOp requested ", - num_elements, " < 0 elements"), - callback); - - if (ctx->input_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_RESOURCE, DT_INT32}, - queue->component_dtypes()), - callback); - } else { - OP_REQUIRES_OK_ASYNC(ctx, - ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); - } - - queue->TryDequeueMany( - num_elements, ctx, true /* allow_small_batch */, - [ctx, callback](const QueueInterface::Tuple& tuple) { - if (!ctx->status().ok()) { - callback(); - return; - } - OpOutputList output_components; - OP_REQUIRES_OK_ASYNC( - ctx, ctx->output_list("components", &output_components), - callback); - for (int i = 0; i < ctx->num_outputs(); ++i) { - output_components.set(i, tuple[i]); - } - callback(); - }); - } - - ~DequeueUpToOp() override {} - - private: - TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU), DequeueUpToOp); REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU), DequeueUpToOp); -// Defines a QueueCloseOp, which closes the given Queue. Closing a -// Queue signals that no more elements will be enqueued in it. -// -// The op has one input, which is the handle of the appropriate Queue. 
-class QueueCloseOp : public QueueOpKernel { - public: - explicit QueueCloseOp(OpKernelConstruction* context) - : QueueOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", - &cancel_pending_enqueues_)); - } - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - queue->Close(ctx, cancel_pending_enqueues_, callback); - } - - private: - bool cancel_pending_enqueues_; - TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp); -// Defines a QueueSizeOp, which computes the number of elements in the -// given Queue, and emits it as an output tensor. -// -// The op has one input, which is the handle of the appropriate Queue; -// and one output, which is a single-element tensor containing the current -// size of that Queue. -class QueueSizeOp : public QueueOpKernel { - public: - explicit QueueSizeOp(OpKernelConstruction* context) - : QueueOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - Tensor* Tqueue_size = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); - Tqueue_size->flat<int32>().setConstant(queue->size()); - callback(); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp); -class QueueIsClosedOp : public QueueOpKernel { - public: - explicit QueueIsClosedOp(OpKernelConstruction* context) - : QueueOpKernel(context) {} - - protected: - void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, - DoneCallback callback) override { - Tensor* Tqueue_is_closed = nullptr; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, TensorShape({}), 
&Tqueue_is_closed)); - Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed()); - callback(); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp); -}; - REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU), QueueIsClosedOp); REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 15004ae4df..2da83a0288 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -24,6 +24,12 @@ limitations under the License. // non-GPU targets. This only breaks in clang, because it's more strict for // template code and CudaAtomicMax is used in template context. +// This file requires the following include because it uses CudaAtomicMax: +// #include "tensorflow/core/util/cuda_kernel_helper.h" + +// Unfortunately we can't add the #include, since it breaks compilation for +// non-GPU targets. This only breaks in clang, because it's more strict for +// template code and CudaAtomicMax is used in template context. 
// This file requires the following include because it uses CudaAtomicMax: // #include "tensorflow/core/util/cuda_kernel_helper.h" diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 4ad653601a..4fea57e6b7 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -559,16 +559,4 @@ REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU), DeserializeSparseOp<string>) -template <> -Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized, - Tensor* result) { - *result = *serialized.get<Tensor>(); - return Status::OK(); -} - -REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") - .Device(DEVICE_CPU) - .TypeConstraint<Variant>("Tserialized"), - DeserializeSparseOp<Variant>) - } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 37803ec775..5aa5d20b1a 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -735,6 +735,7 @@ class TensorArrayPackOrGatherOp : public OpKernel { TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK); +TF_CALL_variant(REGISTER_GATHER_AND_PACK); REGISTER_GATHER_AND_PACK(quint8); REGISTER_GATHER_AND_PACK(qint8); REGISTER_GATHER_AND_PACK(qint32); diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 7fd5809ca4..eadea18f76 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -73,9 +73,6 @@ void VariableOp::Compute(OpKernelContext* ctx) { // here is valid because it owns a ref on var. 
ctx->set_output_ref(0, var->mu(), var->tensor()); if (ctx->track_allocations() && var->tensor()->IsInitialized()) { - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes()); } var->Unref(); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 96944f27cd..b5e42f5384 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -1851,7 +1851,7 @@ class MklPrimitiveFactory { } private: - static inline std::unordered_map<std::string, MklPrimitive*> &GetHashMap() { + static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() { static thread_local std::unordered_map<std::string, MklPrimitive*> map_; return map_; } diff --git a/tensorflow/docs_src/api_guides/python/spectral_ops.md b/tensorflow/docs_src/api_guides/python/spectral_ops.md index 022c471ef1..dd13802f00 100644 --- a/tensorflow/docs_src/api_guides/python/spectral_ops.md +++ b/tensorflow/docs_src/api_guides/python/spectral_ops.md @@ -23,3 +23,4 @@ that you can use to transform Tensors of real and complex signals. ## Discrete Cosine Transforms * @{tf.spectral.dct} +* @{tf.spectral.idct} diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md new file mode 100644 index 0000000000..bd2a80d9ef --- /dev/null +++ b/tensorflow/docs_src/get_started/index.md @@ -0,0 +1,29 @@ +# Get Started + +If you are new to machine learning, we recommend taking the following online +course prior to diving into TensorFlow documentation: + + * [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), + which introduces machine learning concepts and encourages experimentation + with existing TensorFlow code. + +TensorFlow is a tool for machine learning. While it contains a wide range of +functionality, TensorFlow is mainly designed for deep neural network models. 
+ +The easiest way to get started with TensorFlow is by using Eager Execution. + + * @{$get_started/eager} is for anyone new to machine learning or TensorFlow. + +TensorFlow provides many APIs. The remainder of this section focuses on the +Estimator API which provides scalable, high-performance models. See the +@{$estimators} guide. + +For more advanced users: + + * The @{$low_level_intro$Low Level Introduction} demonstrates how to use + TensorFlow outside of the Estimator framework, for debugging and + experimentation. + * The @{$guide$Programmer's Guide} details major + TensorFlow components. + * The @{$tutorials$Tutorials} provide walkthroughs of a variety of + TensorFlow models. diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index ce43d09b63..4c4f3f3934 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -2010,13 +2010,35 @@ Slice(b, {2, 1}, {4, 3}) produces: See also [`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). -Sorts the elements in the operand. +There are two versions of the Sort instruction: a single-operand and a +two-operand version. <b>`Sort(operand)`</b> Arguments | Type | Semantics +--------- | ------- | -------------------- +`operand` | `XlaOp` | The operand to sort. + +Sorts the elements in the operand in ascending order. The operand must be rank-1. +If the operand's elements have floating point type, and the operand contains +NaN elements, the order of elements in the output is implementation-defined. + +<b>`Sort(key, value)`</b> + +Sorts both the key and the value operands. The keys are sorted as in the +single-operand version. The values are sorted according to the order of their +corresponding keys.
For example, if the inputs are `keys = [3, 1]` and +`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`. +The sort is not guaranteed to be stable, that is, if the keys array contains +duplicates, the order of their corresponding values may not be preserved. + +Arguments | Type | Semantics --------- | ------- | ------------------- -`operand` | `XlaOp` | The operand to sort +`keys` | `XlaOp` | The sort keys. +`values` | `XlaOp` | The values to sort. + +The `keys` and `values` operands must both be rank-1, and must have the same +dimensions, but may have different element types. ## Transpose diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go index 35b0cb352e..ea8af221ae 100644 --- a/tensorflow/go/attrs_test.go +++ b/tensorflow/go/attrs_test.go @@ -28,7 +28,7 @@ func TestOperationAttrs(t *testing.T) { i := 0 makeConst := func(v interface{}) Output { op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) - i += 1 + i++ if err != nil { t.Fatal(err) } @@ -71,6 +71,7 @@ func TestOperationAttrs(t *testing.T) { "boundaries": []float32(nil), }, }, + /* TODO(ashankar): debug this issue and add it back later. { Name: "list(type),list(shape)", Type: "InfeedEnqueueTuple", @@ -111,6 +112,7 @@ func TestOperationAttrs(t *testing.T) { "device_ordinal": int64(0), }, }, + */ { Name: "list(int),int", Type: "StringToHashBucketStrong", diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2df69ee299..d5bd99bdd9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -36,20 +36,21 @@ namespace java { namespace { constexpr const char kLicense[] = - "/* Copyright 2018 The TensorFlow Authors.
All Rights Reserved.\n" - "\n" - "Licensed under the Apache License, Version 2.0 (the \"License\");\n" - "you may not use this file except in compliance with the License.\n" - "You may obtain a copy of the License at\n" - "\n" - " http://www.apache.org/licenses/LICENSE-2.0\n" - "\n" - "Unless required by applicable law or agreed to in writing, software\n" - "distributed under the License is distributed on an \"AS IS\" BASIS,\n" - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "See the License for the specific language governing permissions and\n" - "limitations under the License.\n" - "=======================================================================*/\n"; + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/" + "\n"; // There is three different modes to render an op class, depending on the // number and type of outputs it has: diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 3524160d87..796d6a62dc 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,6 +15,18 @@ limitations under the License. 
package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -23,7 +35,6 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; - import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; - /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor { // generated our code, flag the location of each such class. 
if (hasRun) { for (Element e : annotated) { - error(e, "The Operator processor has already processed @Operator annotated sources\n" + - "and written out an Ops API. It cannot process additional @Operator sources.\n" + - "One reason this can happen is if other annotation processors generate\n" + - "new @Operator source files."); + error( + e, + "The Operator processor has already processed @Operator annotated sources\n" + + "and written out an Ops API. It cannot process additional @Operator sources.\n" + + "One reason this can happen is if other annotation processors generate\n" + + "new @Operator source files."); } return true; } @@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor { return Collections.singleton("org.tensorflow.op.annotation.Operator"); } - private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); - private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); private static final TypeName T_STRING = ClassName.get(String.class); @@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor { private void write(TypeSpec spec) { try { - JavaFile.builder("org.tensorflow.op", spec) - .skipJavaLangImports(true) - .build() - .writeTo(filer); + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { throw new AssertionError(e); } } private void writeApi(Multimap<String, 
MethodSpec> groupedMethods) { - Map<String, ClassName> groups = new HashMap<String, ClassName>(); - + Map<String, ClassName> groups = new HashMap<>(); + // Generate a API class for each group collected other than the default one (= empty string) - for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) { + for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) { if (!entry.getKey().isEmpty()) { TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); write(groupClass); @@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor { } private boolean collectOpsMethods( - RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) { + RoundEnvironment roundEnv, + Multimap<String, MethodSpec> groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. 
if (!(e instanceof TypeElement)) { - error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString()); + error( + e, + "@Operator can only be applied to classes, but this is a %s", + e.getKind().toString()); result = false; continue; } @@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor { } return result; } - - private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { + + private void collectOpMethods( + Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { AnnotationMirror am = getAnnotationMirror(opClass, annotation); String groupName = getAnnotationElementValueAsString("group", am); String methodName = getAnnotationElementValueAsString("name", am); ClassName opClassName = ClassName.get(opClass); if (Strings.isNullOrEmpty(methodName)) { - methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); } - // Build a method for each @Operator found in the class path. There should be one method per operation factory called + // Build a method for each @Operator found in the class path. 
There should be one method per + // operation factory called // "create", which takes in parameter a scope and, optionally, a list of arguments for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { - if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); groupedMethods.put(groupName, method); } } } - private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName) - .addModifiers(Modifier.PUBLIC) - .returns(TypeName.get(factoryMethod.getReturnType())) - .varargs(factoryMethod.isVarArgs()) - .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); - for (TypeParameterElement tp: factoryMethod.getTypeParameters()) { + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); builder.addTypeVariable(tvn); } - for (TypeMirror thrownType: factoryMethod.getThrownTypes()) { + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { builder.addException(TypeName.get(thrownType)); } StringBuilder call = new StringBuilder("return $T.create(scope"); @@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor { call.append(")"); builder.addStatement(call.toString(), opClassName); return builder.build(); - } - + } + private String buildOpMethodJavadoc(ClassName 
opClassName, ExecutableElement factoryMethod) { StringBuilder javadoc = new StringBuilder(); - javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n"); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); - // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the // 'scope' parameter that is implicitly passed by this API Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); boolean firstParam = true; @@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor { } else { javadoc.append(tag).append('\n'); } - } + } javadoc.append("@see {@link ").append(opClassName).append("}\n"); return javadoc.toString(); } - + private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope"); - + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + TypeSpec.Builder builder = TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + - "@see {@link $T}\n", group, T_GRAPH, T_OPS) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); builder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + 
FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); return builder.build(); } - private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { + private static TypeSpec buildTopClass( + Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addModifiers(Modifier.PRIVATE) - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope", T_SCOPE); + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); } TypeSpec.Builder opsBuilder = TypeSpec.classBuilder("Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" + - "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" + - "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + - "try (Graph g = new Graph()) {\n" + - " Ops ops = new Ops(g);\n" + - " // Operations are typed classes with convenience\n" + - " // builders in Ops.\n" + - " Constant three = ops.constant(3);\n" + - " // Single-result operations implement the Operand\n" + - " // interface, so this works too.\n" + - " Operand four = ops.constant(4);\n" + - " // Most builders are found within a group, and accept\n" + - " // Operand types as operands\n" + - " Operand nine = ops.math().add(four, ops.constant(5));\n" + - " // Multi-result operations however offer methods to\n" + - " // select a particular result for use.\n" + - " Operand result = \n" + - " ops.math().add(ops.array().unique(s, a).y(), b);\n" + - " // Optional attributes\n" + - " ops.math().matMul(a, 
b, MatMul.transposeA(true));\n" + - " // Naming operators\n" + - " ops.withName(“foo”).constant(5); // name “foo”\n" + - " // Names can exist in a hierarchy\n" + - " Ops sub = ops.withSubScope(“sub”);\n" + - " sub.withName(“bar”).constant(4); // “sub/bar”\n" + - "}\n" + - "}</pre>\n", T_GRAPH, T_OPERATOR) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n<p>\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + + "try (Graph g = new Graph()) {\n" + + " Ops ops = new Ops(g);\n" + + " // Operations are typed classes with convenience\n" + + " // builders in Ops.\n" + + " Constant three = ops.constant(3);\n" + + " // Single-result operations implement the Operand\n" + + " // interface, so this works too.\n" + + " Operand four = ops.constant(4);\n" + + " // Most builders are found within a group, and accept\n" + + " // Operand types as operands\n" + + " Operand nine = ops.math().add(four, ops.constant(5));\n" + + " // Multi-result operations however offer methods to\n" + + " // select a particular result for use.\n" + + " Operand result = \n" + + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " // Optional attributes\n" + + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " // Naming operators\n" + + " ops.withName(“foo”).constant(5); // name “foo”\n" + + " // Names can exist in a hierarchy\n" + + " Ops sub = ops.withSubScope(“sub”);\n" + + " sub.withName(“bar”).constant(4); // “sub/bar”\n" + + "}\n" + + "}</pre>\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withSubScope") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "childScopeName") - .returns(T_OPS) - .addStatement("return 
new $T(scope.withSubScope(childScopeName))", T_OPS) - .addJavadoc( - "Returns an API that adds operations to the graph with the provided name prefix.\n\n" + - "@see {@link $T#withSubScope(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withName") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "opName") - .returns(T_OPS) - .addStatement("return new Ops(scope.withName(opName))") - .addJavadoc( - "Returns an API that uses the provided name for an op.\n\n" + - "@see {@link $T#withName(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); opsBuilder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); opsBuilder.addMethod( MethodSpec.methodBuilder("scope") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(T_SCOPE) - .addStatement("return scope") - .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { 
opsBuilder.addField( FieldSpec.builder(entry.getValue(), entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .build()); - + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + opsBuilder.addMethod( MethodSpec.methodBuilder(entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(entry.getValue()) - .addStatement("return $L", entry.getKey()) - .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); } opsBuilder.addMethod( MethodSpec.methodBuilder("create") - .addModifiers(Modifier.PUBLIC, Modifier.STATIC) - .addParameter(T_GRAPH, "graph") - .returns(T_OPS) - .addStatement("return new Ops(new $T(graph))", T_SCOPE) - .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); return opsBuilder.build(); } @@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor { return am; } } - throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element " - + element.getSimpleName()); + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); } - + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { - for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) { + for (Map.Entry<? extends ExecutableElement, ? 
extends AnnotationValue> entry : + am.getElementValues().entrySet()) { if (entry.getKey().getSimpleName().contentEquals(elementName)) { return entry.getValue().getValue().toString(); } diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD new file mode 100644 index 0000000000..5f55b22818 --- /dev/null +++ b/tensorflow/python/compat/BUILD @@ -0,0 +1,10 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "compat", + srcs = ["compat.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], +) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py new file mode 100644 index 0000000000..e05ad55447 --- /dev/null +++ b/tensorflow/python/compat/compat.py @@ -0,0 +1,81 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for API compatibility between TensorFlow release versions. + +See +@{$guide/version_compat#backward_and_partial_forward_compatibility} +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime + +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1) + + +def forward_compatible(year, month, day): + """Return true if the forward compatibility window has expired. 
+ + Forward-compatibility refers to scenarios where the producer of a TensorFlow + model (a GraphDef or SavedModel) is compiled against a version of the + TensorFlow library newer than what the consumer was compiled against. The + "producer" is typically a Python program that constructs and trains a model + while the "consumer" is typically another program that loads and serves the + model. + + TensorFlow has been supporting a 3 week forward-compatibility window for + programs compiled from source at HEAD. + + For example, consider the case where a new operation `MyNewAwesomeAdd` is + created with the intent of replacing the implementation of an existing Python + wrapper - `tf.add`. The Python wrapper implementation should change from + something like: + + ```python + def add(inputs, name=None): + return gen_math_ops.add(inputs, name) + ``` + + to: + + ```python + from tensorflow.python.compat import compat + + def add(inputs, name=None): + if compat.forward_compatible(year, month, day): + # Can use the awesome new implementation. + return gen_math_ops.my_new_awesome_add(inputs, name) + # To maintain forward compatibility, use the old implementation. + return gen_math_ops.add(inputs, name) + ``` + + Where `year`, `month`, and `day` specify the date beyond which binaries + that consume a model are expected to have been updated to include the + new operations. This date is typically at least 3 weeks beyond the date + the code that adds the new operation is committed. + + Args: + year: A year (e.g., 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Returns: + True if the caller can expect that serialized TensorFlow graphs produced + can be consumed by programs that are compiled with the TensorFlow library + source code after (year, month, day).
+ """ + return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 6941cacf23..c025dc8aa5 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -454,6 +454,17 @@ py_binary( ], ) +py_binary( + name = "debug_keras", + srcs = ["examples/debug_keras.py"], + srcs_version = "PY2AND3", + deps = [ + ":debug_py", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + py_test( name = "common_test", size = "small", @@ -1086,6 +1097,7 @@ py_test( "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -1096,6 +1108,7 @@ sh_test( data = [ ":debug_errors", ":debug_fibonacci", + ":debug_keras", ":debug_mnist", ":debug_tflearn_iris", ":offline_analyzer", diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py new file mode 100644 index 0000000000..3272d85ade --- /dev/null +++ b/tensorflow/python/debug/examples/debug_keras.py @@ -0,0 +1,89 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""tfdbg example: debugging tf.keras models training on tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf + +from tensorflow.python import debug as tf_debug + + +def main(_): + # Create a dummy dataset. + num_examples = 8 + steps_per_epoch = 2 + input_dims = 3 + output_dims = 1 + xs = np.zeros([num_examples, input_dims]) + ys = np.zeros([num_examples, output_dims]) + dataset = tf.data.Dataset.from_tensor_slices( + (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch)) + + sess = tf.Session() + if FLAGS.debug: + # Use the command-line interface (CLI) of tfdbg. + sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) + elif FLAGS.tensorboard_debug_address: + # Use the TensorBoard Debugger Plugin (GUI of tfdbg). + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) + tf.keras.backend.set_session(sess) + + # Create a dummy model. + model = tf.keras.Sequential([ + tf.keras.layers.Dense(1, input_shape=[input_dims])]) + model.compile(loss="mse", optimizer="sgd") + + # Train the model using the dummy dataset created above. + model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training. 
" + "Mutually exclusive with the --tensorboard_debug_address flag.") + parser.add_argument( + "--ui_type", + type=str, + default="curses", + help="Command-line user interface type (curses | readline).") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") + parser.add_argument( + "--epochs", + type=int, + default=2, + help="Number of epochs to train the model for.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh index e9c45a7e6e..2d35b2d8bb 100755 --- a/tensorflow/python/debug/examples/examples_test.sh +++ b/tensorflow/python/debug/examples/examples_test.sh @@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors" DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist" DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris" + DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras" OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer" else DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci" DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors" DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist" DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris" + DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras" OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer" fi @@ -96,6 +98,11 @@ if 
[[ -d "${CUSTOM_DUMP_ROOT}" ]]; then exit 1 fi +# Test debugging of tf.keras. +cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline +run -f has_inf_or_nan +EOF + # Test offline_analyzer. echo echo "Testing offline_analyzer" diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index c530204bbf..b9524ce649 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._default_session_context_manager = None + # A cache for callables created from CallableOptions. + self._cached_callables_from_options = dict() + @property def graph(self): return self._sess.graph @@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): """Wrapper around Session.run() that inserts tensor watch options. Args: @@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. callable_runner: A `callable` returned by `Session.make_callable()`. If not `None`, `fetches` and `feed_dict` must both be `None`. - callable_runner_args: An optional list of arguments to `callable_runner`. + Mutually exclusive with `callable_options`. + callable_runner_args: An optional list of arguments to `callable_runner` + or for `callable_options`. + callable_options: An instance of `config_pb2.CallableOptions`, to be + used with `Session._make_callable_from_options()`. Mutually exclusive + with `callable_runner`. Returns: Simply forwards the output of the wrapped `Session.run()` call. @@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface): ValueError: On invalid `OnRunStartAction` value. 
Or if `callable_runner` is not `None` and either or both of `fetches` and `feed_dict` is `None`. """ - if not callable_runner: + if callable_runner and callable_options: + raise ValueError( + "callable_runner and callable_options are mutually exclusive, but " + "are both specified in this call to BaseDebugWrapperSession.run().") + + if not (callable_runner or callable_options): self.increment_run_call_count() - else: - if fetches or feed_dict: - raise ValueError( - "callable_runner and fetches/feed_dict are mutually exclusive, but " - "are used simultaneously.") + elif callable_runner and (fetches or feed_dict): + raise ValueError( + "callable_runner and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") empty_fetches = not nest.flatten(fetches) if empty_fetches: @@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface): if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) + elif callable_options: + # pylint:disable=protected-access + return self._sess._make_callable_from_options( + callable_options)(*callable_runner_args) + # pylint:enable=protected-access else: return self._sess.run(fetches, feed_dict=feed_dict, @@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface): if run_start_resp.action == OnRunStartAction.DEBUG_RUN: # Decorate RunOption to fill in debugger tensor watch specifications. - decorated_run_options = options or config_pb2.RunOptions() + decorated_run_options = None + if callable_options: + callable_options_id = id(callable_options) + if callable_options_id not in self._cached_callables_from_options: + # Make a copy of callable_options to avoid mutating it. 
+ new_callable_options = config_pb2.CallableOptions() + new_callable_options.CopyFrom(callable_options) + decorated_run_options = new_callable_options.run_options + else: + decorated_run_options = options or config_pb2.RunOptions() + run_metadata = run_metadata or config_pb2.RunMetadata() - self._decorate_run_options_for_debug( - decorated_run_options, - run_start_resp.debug_urls, - debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), - tolerate_debug_op_creation_failures=( - run_start_resp.tolerate_debug_op_creation_failures)) + if decorated_run_options: + self._decorate_run_options_for_debug( + decorated_run_options, + run_start_resp.debug_urls, + debug_ops=run_start_resp.debug_ops, + node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, + op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=( + run_start_resp.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + run_start_resp.tolerate_debug_op_creation_failures)) # Invoke the run() method of the wrapped Session. Catch any TensorFlow # runtime errors. 
@@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface): retvals = callable_runner(*callable_runner_args, options=decorated_run_options, run_metadata=run_metadata) + elif callable_options: + # pylint:disable=protected-access + if callable_options_id in self._cached_callables_from_options: + callable_object = self._cached_callables_from_options[ + callable_options_id] + else: + callable_object = self._sess._make_callable_from_options( + new_callable_options) + self._cached_callables_from_options[ + callable_options_id] = callable_object + # pylint:enable=protected-access + retvals = callable_object( + *callable_runner_args, run_metadata=run_metadata) else: retvals = self._sess.run(fetches, feed_dict=feed_dict, @@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata=kwargs.get("run_metadata", None), callable_runner=runner, callable_runner_args=runner_args) + return wrapped_runner + def _make_callable_from_options(self, callable_options): + def wrapped_runner(*feed_values, **kwargs): + return self.run(None, + run_metadata=kwargs.get("run_metadata", None), + callable_options=callable_options, + callable_runner_args=feed_values) return wrapped_runner @property diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 1f9c8fa5a9..85944fa611 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): if self._send_traceback_and_source_code: self._sent_graph_version = publish_traceback( self._grpc_debug_server_urls, self.graph, feed_dict, fetches, @@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=options, run_metadata=run_metadata, 
callable_runner=callable_runner, - callable_runner_args=callable_runner_args) + callable_runner_args=callable_runner_args, + callable_options=callable_options) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 4e551ab995..668ffb57f1 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -596,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # Register tab completion for the filter names. curses_cli.register_tab_comp_context(["run", "r"], list(self._tensor_filters.keys())) - if self._feed_dict: + if self._feed_dict and hasattr(self._feed_dict, "keys"): # Register tab completion for feed_dict keys. feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index b06fa26a93..05c9eaa4d2 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -21,7 +21,10 @@ import os import shutil import tempfile +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import debugger_cli_common @@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): dtypes.float32, shape=([5, 5]), name="sparse_placeholder") self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph) - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + 
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config_proto = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config_proto) # Initialize variable. self.sess.run(variables.global_variables_initializer()) @@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertAllClose(42.0, tensor_runner(41.0, 1.0)) self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"])) + def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + for _ in range(2): + callable_output = sess_callable() + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + + def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + a = math_ops.add(ph1, ph1, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.fetch.append("callable_b") + sess_callable = 
wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array([10.5, -10.5], dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value) + self.assertAllClose( + np.array([42.0, -42.0], dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2") + a = math_ops.add(ph1, ph2, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.feed.append("callable_ph2") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array(5.0, dtype=np.float32) + ph2_value = np.array(16.0, dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value, ph2_value) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + 
self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + run_metadata = config_pb2.RunMetadata() + # Call the callable with a custom run_metadata. + callable_output = sess_callable(run_metadata=run_metadata) + # Verify that step_stats is populated in the custom run_metadata. + self.assertTrue(run_metadata.step_stats) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(1, len(debug_dumps)) + debug_dump = debug_dumps[0] + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + def testRuntimeErrorShouldBeCaught(self): wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a81ef90513..7edcb0931d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -782,6 +782,9 @@ class _PolymorphicFunction(object): kwd_values = _deterministic_dict_values(kwds) inputs = args + kwd_values signature = tuple(_cache_key(x) for x in inputs) + # The graph, or whether we're executing eagerly, should be a part of the + # signature so we don't improperly capture tensors such as variables. 
+ signature += tuple([context.executing_eagerly() or ops.get_default_graph()]) if signature not in self._arguments_to_functions: graph_function = _trace_and_define_function( diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index ad00adbabb..cf32f6e7fb 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -105,6 +105,18 @@ class FunctionTest(test.TestCase): self.assertAllEqual(grads.eval(), 2.0) self.assertEqual(grads.shape, v.shape) + def testGraphEagerIsolation(self): + + @function.defun + def f(): + v = resource_variable_ops.ResourceVariable(1.0) + return v.read_value() + + self.assertAllEqual(f(), 1.0) + + with ops.Graph().as_default(): + self.assertEqual(f().shape, ()) + def testBasicDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 733c7fb95d..2a0e4e7617 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase): dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) self.assertEqual(5, scores['global_step']) + def test_wrong_shape_throws_reasonable_error(self): + """Make sure we are helpful when model_fns change. 
See b/110263146.""" + def _get_model_fn(val=1): + def _model_fn(features, labels, mode): + del features, labels # unused + variables.Variable(val, name='weight') + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + return _model_fn + + model_fn_1 = _get_model_fn() + model_fn_2 = _get_model_fn(val=[1]) + + est1 = estimator.Estimator(model_fn=model_fn_1) + est1.train(dummy_input_fn, steps=5) + est2 = estimator.Estimator( + model_fn=model_fn_2, model_dir=est1.model_dir) + + expected_msg = 'Restoring from checkpoint failed.*a mismatch between' + with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg): + est2.train(dummy_input_fn, steps=1,) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py index 2a1c8d5f51..a96b581960 100644 --- a/tensorflow/python/keras/datasets/mnist.py +++ b/tensorflow/python/keras/datasets/mnist.py @@ -50,5 +50,5 @@ def load_data(path='mnist.npz'): with np.load(path) as f: x_train, y_train = f['x_train'], f['y_train'] x_test, y_test = f['x_test'], f['y_test'] - + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index 5e95cd4340..d5ccd44604 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -854,7 +854,16 @@ def load_weights_from_hdf5_group_by_name(f, layers): str(len(weight_values)) + ' element(s).') # Set values. 
for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: + raise ValueError('Layer #' + str(k) +' (named "' + layer.name + + '"), weight ' + str(symbolic_weights[i]) + + ' has shape {}'.format(K.int_shape( + symbolic_weights[i])) + + ', but the saved weight has shape ' + + str(weight_values[i].shape) + '.') + + else: + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 1a0aa60609..030328f2a6 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import os import shutil import tempfile - from absl.testing import parameterized import numpy as np @@ -31,6 +30,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import saving from tensorflow.python.keras.engine import training from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops @@ -248,6 +248,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) + def test_sequential_weight_loading_group_name_with_incorrect_length(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + 
ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, use_bias=False, + input_dim=input_dim, name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named \"d1\"\) expects 1 ' + r'weight\(s\), but the saved weights have 2 ' + r'element\(s\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + + def test_sequential_weight_loading_group_name_with_incorrect_shape(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, + name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named "d1"\), weight 
' + r'<tf\.Variable \'d1_1\/kernel:0\' ' + r'shape=\(3, 10\) dtype=float32> has ' + r'shape \(3, 10\), but the saved weight has ' + r'shape \(3, 5\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + class TestWholeModelSaving(test.TestCase): diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py index cb86a69990..b244beb5b5 100644 --- a/tensorflow/python/keras/estimator/__init__.py +++ b/tensorflow/python/keras/estimator/__init__.py @@ -25,7 +25,7 @@ from tensorflow.python.util.tf_export import tf_export # everything will work as normal. try: - import tensorflow.python.estimator.keras as keras_lib # pylint: disable=g-import-not-at-top + from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top model_to_estimator = tf_export('keras.estimator.model_to_estimator')( keras_lib.model_to_estimator) except Exception: # pylint: disable=broad-except diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py index 93b2ff4561..97d7e2d8f9 100644 --- a/tensorflow/python/kernel_tests/dct_ops_test.py +++ b/tensorflow/python/kernel_tests/dct_ops_test.py @@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name fftpack = try_import("scipy.fftpack") +def _np_dct2(signals, norm=None): + """Computes the DCT-II manually with NumPy.""" + # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 + dct_size = signals.shape[-1] + dct = np.zeros_like(signals) + for k in range(dct_size): + phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) + dct[..., k] = np.sum(signals * phi, axis=-1) + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + if norm == "ortho": + # The orthonormal scaling includes a factor of 0.5 which we combine with + # the overall scaling of 2.0 to cancel. 
+ dct[..., 0] *= np.sqrt(1.0 / dct_size) + dct[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + dct *= 2.0 + return dct + + +def _np_dct3(signals, norm=None): + """Computes the DCT-III manually with NumPy.""" + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + dct_size = signals.shape[-1] + signals = np.array(signals) # make a copy so we can modify + if norm == "ortho": + signals[..., 0] *= np.sqrt(4.0 / dct_size) + signals[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + signals *= 2.0 + dct = np.zeros_like(signals) + # X_k = 0.5 * x_0 + + # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1 + half_x0 = 0.5 * signals[..., 0] + for k in range(dct_size): + phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size) + dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1) + return dct + + +NP_DCT = {2: _np_dct2, 3: _np_dct3} +NP_IDCT = {2: _np_dct3, 3: _np_dct2} + + class DCTOpsTest(test.TestCase): - def _np_dct2(self, signals, norm=None): - """Computes the DCT-II manually with NumPy.""" - # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 - dct_size = signals.shape[-1] - dct = np.zeros_like(signals) - for k in range(dct_size): - phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) - dct[..., k] = np.sum(signals * phi, axis=-1) - # SciPy's `dct` has a scaling factor of 2.0 which we follow. - # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src - if norm == "ortho": - # The orthonormal scaling includes a factor of 0.5 which we combine with - # the overall scaling of 2.0 to cancel. 
- dct[..., 0] *= np.sqrt(1.0 / dct_size) - dct[..., 1:] *= np.sqrt(2.0 / dct_size) - else: - dct *= 2.0 - return dct - - def _compare(self, signals, norm, atol=5e-4, rtol=5e-4): - """Compares the DCT to SciPy (if available) and a NumPy implementation.""" - np_dct = self._np_dct2(signals, norm) - tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval() + def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4): + """Compares (I)DCT to SciPy (if available) and a NumPy implementation.""" + np_dct = NP_DCT[dct_type](signals, norm) + tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval() self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) + np_idct = NP_IDCT[dct_type](signals, norm) + tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval() + self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) if fftpack: - scipy_dct = fftpack.dct(signals, type=2, norm=norm) + scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm) self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) + scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) + self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol) + # Verify inverse(forward(s)) == s, up to a normalization factor. 
+ tf_idct_dct = spectral_ops.idct( + tf_dct, type=dct_type, norm=norm).eval() + tf_dct_idct = spectral_ops.dct( + tf_idct, type=dct_type, norm=norm).eval() + if norm is None: + tf_idct_dct *= 0.5 / signals.shape[-1] + tf_dct_idct *= 0.5 / signals.shape[-1] + self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol) + self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) def test_random(self): """Test randomly generated batches of data.""" with spectral_ops_test_util.fft_kernel_label_map(): with self.test_session(use_gpu=True): - for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]): + for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]): signals = np.random.rand(*shape).astype(np.float32) for norm in (None, "ortho"): - self._compare(signals, norm) + self._compare(signals, norm, 2) + self._compare(signals, norm, 3) def test_error(self): signals = np.random.rand(10) # Unsupported type. with self.assertRaises(ValueError): - spectral_ops.dct(signals, type=3) + spectral_ops.dct(signals, type=1) # Unknown normalization. with self.assertRaises(ValueError): spectral_ops.dct(signals, norm="bad") diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index 98354083c7..d4621d61ee 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -30,8 +30,8 @@ limitations under the License. #endif // Place `<locale>` before <Python.h> to avoid build failure in macOS. -#include <locale> #include <Python.h> +#include <locale> #include "numpy/arrayobject.h" #include "numpy/ufuncobject.h" diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 572693b1cf..6b6c82015f 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/python/lib/core/py_util.h" // Place `<locale>` before <Python.h> to avoid build failure in macOS. 
-#include <locale> #include <Python.h> +#include <locale> #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 2c7751f792..a2eae452ae 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -57,6 +57,7 @@ ops.NotDifferentiable('NonMaxSuppression') ops.NotDifferentiable('NonMaxSuppressionV2') +# pylint: disable=invalid-name def _assert(cond, ex_type, msg): """A polymorphic assert, works with tensors and boolean expressions. @@ -1070,15 +1071,16 @@ def resize_images(images, @tf_export('image.resize_image_with_pad') -def resize_image_with_pad(image, target_height, target_width, +def resize_image_with_pad(image, + target_height, + target_width, method=ResizeMethod.BILINEAR): - """ - Resizes and pads an image to a target width and height. + """Resizes and pads an image to a target width and height. Resizes an image to a target width and height by keeping the aspect ratio the same without distortion. If the target dimensions don't match the image dimensions, the image - is resized and then padded with zeroes to match requested + is resized and then padded with zeroes to match requested dimensions. 
Args: @@ -1139,10 +1141,10 @@ def resize_image_with_pad(image, target_height, target_width, ratio = max_(f_width / f_target_width, f_height / f_target_height) resized_height_float = f_height / ratio resized_width_float = f_width / ratio - resized_height = math_ops.cast(math_ops.floor(resized_height_float), - dtype=dtypes.int32) - resized_width = math_ops.cast(math_ops.floor(resized_width_float), - dtype=dtypes.int32) + resized_height = math_ops.cast( + math_ops.floor(resized_height_float), dtype=dtypes.int32) + resized_width = math_ops.cast( + math_ops.floor(resized_width_float), dtype=dtypes.int32) padding_height = (f_target_height - resized_height_float) / 2 padding_width = (f_target_width - resized_width_float) / 2 @@ -1154,13 +1156,13 @@ def resize_image_with_pad(image, target_height, target_width, # Resize first, then pad to meet requested dimensions resized = resize_images(image, [resized_height, resized_width], method) - padded = pad_to_bounding_box(resized, p_height, p_width, - target_height, target_width) + padded = pad_to_bounding_box(resized, p_height, p_width, target_height, + target_width) if padded.get_shape().ndims is None: raise ValueError('padded contains no shape.') - _, padded_height, padded_width, _ = _ImageDimensions(padded, rank=4) + _ImageDimensions(padded, rank=4) if not is_batch: padded = array_ops.squeeze(padded, squeeze_dims=[0]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 8e40de140d..cf9761803b 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -2731,7 +2731,7 @@ class ResizeImageWithPadTest(test_util.TensorFlowTestCase): try: self._ResizeImageWithPad(x, target_height, target_width, use_tensor_inputs) - except Exception as e: + except Exception as e: # pylint: disable=broad-except if err_msg not in str(e): raise else: diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 
45e3bd65d2..6b709e5e7f 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -237,8 +237,8 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): def testApproximateEqualShape(self): for dtype in [np.float32, np.double]: - x = np.array([1, 2], dtype=np.float32) - y = np.array([[1, 2]], dtype=np.float32) + x = np.array([1, 2], dtype=dtype) + y = np.array([[1, 2]], dtype=dtype) # The inputs 'x' and 'y' must have the same shape. with self.assertRaisesRegexp( ValueError, "Shapes must be equal rank, but are 1 and 2"): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 215140e987..deba133fb9 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops @@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape): return shape +def _should_cache(): + """Returns True if a default caching device should be set, otherwise False.""" + if context.executing_eagerly(): + return False + # Don't set a caching device when running in a loop, since it is possible that + # train steps could be wrapped in a tf.while_loop. In that scenario caching + # prevents forward computations in loop iterations from re-reading the + # updated weights. 
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingWhileContext(ctxt) is None + + # pylint: disable=unused-argument def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, @@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1015,7 +1028,7 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1228,7 +1241,7 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 6efcd39f13..9a10abfcf7 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -201,8 +201,8 @@ def einsum(equation, *inputs, **kwargs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. 
""" - equation = equation.replace(" ", "") - + equation = equation.replace(' ', '') + name = kwargs.pop('name', None) if kwargs: raise TypeError('invalid keyword arguments for this function: ' + ', '.join( diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 28054f50ef..293aace728 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm): raise NotImplementedError("The DCT length argument is not implemented.") if axis != -1: raise NotImplementedError("axis must be -1. Got: %s" % axis) - if dct_type != 2: - raise ValueError("Only the Type II DCT is supported.") + if dct_type not in (2, 3): + raise ValueError("Only Types II and III (I)DCT are supported.") if norm not in (None, "ortho"): raise ValueError( "Unknown normalization. Expected None or 'ortho', got: %s" % norm) @@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm): def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. - Currently only Type II is supported. Implemented using a length `2N` padded - @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606 + Currently only Types II and III are supported. Type II is implemented using a + length `2N` padded @{tf.spectral.rfft}, as described here: + https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward + inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}). @compatibility(scipy) - Equivalent to scipy.fftpack.dct for the Type-II DCT. + Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT. https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html @end_compatibility Args: input: A `[..., samples]` `float32` `Tensor` containing the signals to take the DCT of. 
- type: The DCT type to perform. Must be 2. + type: The DCT type to perform. Must be 2 or 3. n: For future expansion. The length of the transform. Must be `None`. axis: For future expansion. The axis to compute the DCT along. Must be `-1`. norm: The normalization to apply. `None` for no normalization or `'ortho'` @@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl A `[..., samples]` `float32` `Tensor` containing the DCT of `input`. Raises: - ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or - `norm` is not `None` or `'ortho'`. + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform """ @@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1] axis_dim_float = _math_ops.to_float(axis_dim) - scale = 2.0 * _math_ops.exp(_math_ops.complex( - 0.0, -_math.pi * _math_ops.range(axis_dim_float) / - (2.0 * axis_dim_float))) - - # TODO(rjryan): Benchmark performance and memory usage of the various - # approaches to computing a DCT via the RFFT. - dct2 = _math_ops.real( - rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) - - if norm == "ortho": - n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) - n2 = n1 * _math_ops.sqrt(2.0) - # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. - weights = _array_ops.pad( - _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], - constant_values=n2) - dct2 *= weights - - return dct2 + if type == 2: + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + + # TODO(rjryan): Benchmark performance and memory usage of the various + # approaches to computing a DCT via the RFFT. 
+ dct2 = _math_ops.real( + rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) + + if norm == "ortho": + n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(2.0) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + dct2 *= weights + + return dct2 + + elif type == 3: + if norm == "ortho": + n1 = _math_ops.sqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(0.5) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + input *= weights + else: + input *= axis_dim_float + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, + _math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + dct3 = _math_ops.real( + irfft( + scale * _math_ops.complex(input, 0.0), + fft_length=[2 * axis_dim]))[..., :axis_dim] + + return dct3 + + +# TODO(rjryan): Implement `type`, `n` and `axis` parameters. +@tf_export("spectral.idct") +def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin + """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. + + Currently only Types II and III are supported. Type III is the inverse of + Type II, and vice versa. + + Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is + not `'ortho'`. That is: + `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`. + When `norm='ortho'`, we have: + `signal == idct(dct(signal, norm='ortho'), norm='ortho')`. + + @compatibility(scipy) + Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT. + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html + @end_compatibility + + Args: + input: A `[..., samples]` `float32` `Tensor` containing the signals to take + the DCT of. + type: The IDCT type to perform. Must be 2 or 3. + n: For future expansion. 
The length of the transform. Must be `None`. + axis: For future expansion. The axis to compute the DCT along. Must be `-1`. + norm: The normalization to apply. `None` for no normalization or `'ortho'` + for orthonormal normalization. + name: An optional name for the operation. + + Returns: + A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`. + + Raises: + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. + + [idct]: + https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms + """ + _validate_dct_arguments(type, n, axis, norm) + inverse_type = {2: 3, 3: 2}[type] + return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 53ed89e4ab..1ee975fbe4 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,7 +22,6 @@ from __future__ import print_function import collections import os.path import re -import sys import time import uuid @@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): ckpt = CheckpointState() text_format.Merge(file_content, ckpt) if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) # For relative model_checkpoint_path and all_model_checkpoint_paths, # prepend checkpoint_dir. if not os.path.isabs(ckpt.model_checkpoint_path): @@ -1706,12 +1705,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. 
""" if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1719,23 +1723,24 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. - should_reraise = False + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: reader = pywrap_tensorflow.NewCheckpointReader(save_path) object_graph_string = reader.get_tensor( checkpointable.OBJECT_GRAPH_PROTO_KEY) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception, but do it outside the except - # block so the object graph lookup isn't included in the stack trace. - should_reraise = True - if should_reraise: - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. 
Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1747,6 +1752,11 @@ class Saver(object): self._restore_from_object_based_checkpoint( sess=sess, save_path=save_path, object_graph_string=object_graph_string) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") def _restore_from_object_based_checkpoint(self, sess, save_path, object_graph_string): @@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): return meta_graph_filename +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. 
" + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index f235300eb5..ae9c244aaf 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -24,10 +24,8 @@ import math import os import random import shutil -import sys import tempfile import time -import traceback import numpy as np import six @@ -369,8 +367,8 @@ class SaverTest(test.TestCase): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): with self.test_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) - with self.assertRaisesRegexp(errors.NotFoundError, - "Failed to find any matching files for"): + with self.assertRaisesRegexp( + ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") def testInt64(self): @@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase): errors.NotFoundError, "Key b not found in checkpoint"): b_saver.restore(sess=sess, save_path=save_path) - def testCheckpointNotFoundErrorRaised(self): - # Restore does some tricky exception handling to figure out if it should - # load an object-based checkpoint. Tests that the exception handling isn't - # too broad. - a = resource_variable_ops.ResourceVariable(1., name="a") - saver = saver_module.Saver([a]) - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.NotFoundError, - "Failed to find any matching files for path_which_does_not_exist"): - saver.restore(sess=sess, save_path="path_which_does_not_exist") - try: - saver.restore(sess=sess, save_path="path_which_does_not_exist") - except errors.NotFoundError: - # Make sure we don't have a confusing "During handling of the above - # exception" block in Python 3. 
- # pylint: disable=no-value-for-parameter - exception_string = "\n".join( - traceback.format_exception(*sys.exc_info())) - # pylint: enable=no-value-for-parameter - self.assertNotIn("NewCheckpointReader", exception_string) + with self.assertRaises(errors.NotFoundError) as cs: + b_saver.restore(sess=sess, save_path=save_path) + + # Make sure we don't have a confusing "During handling of the above + # exception" block in Python 3. + self.assertNotIn("NewCheckpointReader", cs.exception.message) + + def testGraphChangedForRestoreErrorRaised(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + with ops_lib.Graph().as_default() as g: + a = variables.Variable(1., name="a") + a_saver = saver_module.Saver([a]) + + with self.test_session(graph=g) as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + + with ops_lib.Graph().as_default() as g: + a = variables.Variable([1.], name="a") + a_saver = saver_module.Saver([a]) + with self.test_session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "a mismatch between the current graph and the graph"): + a_saver.restore(sess=sess, save_path=save_path) def testLoadFromObjectBasedGraph(self): checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt index 4f306540cc..6a421ef12d 100644 --- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt @@ -17,6 +17,10 @@ tf_module { argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { + name: "idct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } + member_method { name: "ifft" argspec: "args=[\'input\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le index e879c34bbd..f496ac59b6 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le @@ -8,7 +8,6 @@ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa RUN /install/install_deb_packages.sh RUN apt-get update && apt-get install -y libopenblas-dev -RUN /install/install_hdf5_ppc64le.sh RUN /install/install_pip_packages.sh RUN /install/install_bazel_from_source.sh RUN /install/install_proto3.sh diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le index 8967138747..3eddc56550 100644 --- a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le +++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le @@ -14,7 +14,6 @@ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa RUN /install/install_deb_packages.sh RUN apt-get update && apt-get install -y libopenblas-dev -RUN /install/install_hdf5_ppc64le.sh RUN /install/install_pip_packages.sh RUN /install/install_bazel_from_source.sh RUN /install/install_golang_ppc64le.sh diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 05676f9551..f0a437c183 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -349,12 +349,12 @@ do_external_licenses_check(){ # Blacklist echo ${MISSING_LICENSES_FILE} - grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} # Whitelist echo ${EXTRA_LICENSE_FILE} 
- grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 05c23cd3ee..173f418dc8 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -115,6 +115,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", @@ -156,6 +157,7 @@ genrule( "//third_party/fft2d:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index a0caf42331..c9d53f46c3 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -130,6 +130,8 @@ filegroup( "@astor_archive//:LICENSE", "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_github_googleapis_googleapis//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index c630ca04b8..1236de2657 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -53,7 +53,7 @@ REQUIRED_PACKAGES = [ 'gast >= 0.2.0', 'numpy >= 1.13.3', 'six >= 1.10.0', - 'protobuf >= 3.6.0', + 'protobuf >= 3.4.0', 'setuptools <= 39.1.0', 
'tensorboard >= 1.8.0, < 1.9.0', 'termcolor >= 1.1.0', diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index cae6f51eb5..172eed0b57 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -107,11 +107,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "eigen_archive", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", - "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz", + "https://bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz", ], - sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9", - strip_prefix = "eigen-eigen-fd6845384b86", + sha256 = "8bbe676d69e7f59070c83a949454b8b6344034e0ebbf686b337528e5dc04c7de", + strip_prefix = "eigen-eigen-e5e305a158a0", build_file = clean_dep("//third_party:eigen.BUILD"), ) @@ -142,11 +142,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "ortools_archive", urls = [ - "https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz", - "https://github.com/google/or-tools/archive/v6.7.2.tar.gz", + "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", + # Please uncomment me, when the next upgrade happens. Then + # remove the whitelist entry in third_party/repo.bzl. 
+ # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", ], - sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9", - strip_prefix = "or-tools-6.7.2/src", + sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", + strip_prefix = "or-tools-253f7955c6a1fd805408fba2e42ac6d45b312d15/src", build_file = clean_dep("//third_party:ortools.BUILD"), ) @@ -162,6 +164,27 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ) tf_http_archive( + name = "com_github_googlecloudplatform_google_cloud_cpp", + urls = [ + "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz", + "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz", + ], + sha256 = "edb347aae9869ffdcf8df6288335bcc535fec46da946b385c16968e96a74b208", + strip_prefix = "google-cloud-cpp-f9ff105957965bcf87f7cb9a93e951c3d08d1734", + ) + + tf_http_archive( + name = "com_github_googleapis_googleapis", + urls = [ + "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + ], + sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378", + strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb", + build_file = clean_dep("//third_party:googleapis.BUILD"), + ) + + tf_http_archive( name = "gemmlowp", urls = [ "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", @@ -231,11 +254,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "org_sqlite", urls = [ - "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip", - "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip", + 
"https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip", + "https://www.sqlite.org/2018/sqlite-amalgamation-3230100.zip", ], - sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6", - strip_prefix = "sqlite-amalgamation-3240000", + sha256 = "4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc", + strip_prefix = "sqlite-amalgamation-3230100", build_file = clean_dep("//third_party:sqlite.BUILD"), ) @@ -426,11 +449,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "grpc", urls = [ - "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.12.1.tar.gz", - "https://github.com/grpc/grpc/archive/v1.12.1.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", + "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", ], - sha256 = "f6afbfafa8e7b524727d1ff37ff22fe9c3dcca07bd864e7a9d1efabf1d15d13c", - strip_prefix = "grpc-1.12.1", + sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673", + strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f", ) @@ -660,12 +683,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "cython", - sha256 = "05e3eb7f06043f5ff2028338370329e71c29f57315e95f4dc6ad7c4971dd4c6f", + sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5", urls = [ - "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.3.tar.gz", - "https://github.com/cython/cython/archive/0.28.3.tar.gz", + "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", + "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", ], - strip_prefix = "cython-0.28.3", + strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", build_file = clean_dep("//third_party:cython.BUILD"), delete = ["BUILD.bazel"], ) @@ -673,11 +696,11 @@ def 
tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "bazel_toolchains", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2cec6c9f6d12224e93d9b3f337b24e41602de3ba.tar.gz", - "https://github.com/bazelbuild/bazel-toolchains/archive/2cec6c9f6d12224e93d9b3f337b24e41602de3ba.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", ], - strip_prefix = "bazel-toolchains-2cec6c9f6d12224e93d9b3f337b24e41602de3ba", - sha256 = "9b8d85b61d8945422e86ac31e4d4d2d967542c080d1da1b45364da7fd6bdd638", + strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43", + sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556", ) tf_http_archive( diff --git a/third_party/googleapis.BUILD b/third_party/googleapis.BUILD new file mode 100644 index 0000000000..95e999af18 --- /dev/null +++ b/third_party/googleapis.BUILD @@ -0,0 +1,45 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//visibility:public"]) +licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + +load("@protobuf_archive//:protobuf.bzl", "cc_proto_library") + +cc_proto_library( + name = "bigtable_protos", + srcs = [ + "google/bigtable/admin/v2/bigtable_instance_admin.proto", + "google/bigtable/admin/v2/bigtable_table_admin.proto", + "google/bigtable/admin/v2/common.proto", + "google/bigtable/admin/v2/instance.proto", + "google/bigtable/admin/v2/table.proto", + "google/bigtable/v2/bigtable.proto", + "google/bigtable/v2/data.proto", + "google/iam/v1/iam_policy.proto", + "google/iam/v1/policy.proto", + "google/longrunning/operations.proto", + "google/rpc/status.proto", + "google/rpc/error_details.proto", + "google/api/annotations.proto", + "google/api/auth.proto", + "google/api/http.proto", + ], + include = ".", + protoc = "@protobuf_archive//:protoc", + default_runtime = "@protobuf_archive//:protobuf", + deps = ["@protobuf_archive//:cc_wkt_protos"], + use_grpc_plugin = True, +) |