From f35dc0a522ae630902baa5be16d2a53b59266770 Mon Sep 17 00:00:00 2001 From: Bruno Goncalves <882745+brunomorishita@users.noreply.github.com> Date: Sat, 28 Apr 2018 19:24:22 -0300 Subject: Fix cmake library path for libpng16.a --- tensorflow/contrib/cmake/external/png.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index ad2af01bc0..1a147e9c8e 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== include (ExternalProject) +include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) @@ -35,7 +36,7 @@ if(WIN32) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a) endif() set(png_HEADERS -- cgit v1.2.3 From f78fd433118830482dddbf6055751898a19265de Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Wed, 13 Jun 2018 17:28:23 +0800 Subject: Specify endianness in expected_result array to fix #15767. 
--- tensorflow/python/kernel_tests/decode_raw_op_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py index 122a9ed469..0bd8bc3c7b 100644 --- a/tensorflow/python/kernel_tests/decode_raw_op_test.py +++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py @@ -79,7 +79,7 @@ class DecodeRawOpTest(test.TestCase): decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16) self.assertEqual([None, None], decode.get_shape().as_list()) - expected_result = np.matrix([[1, -2, -3, 4]], dtype=np.float16) + expected_result = np.matrix([[1, -2, -3, 4]], dtype=" Date: Wed, 8 Aug 2018 14:34:16 -0700 Subject: Add deprecation warning to tf.gfile.FastGFile. Fixes #12663. --- tensorflow/python/platform/gfile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index 45de047894..510701e344 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -33,6 +33,7 @@ from tensorflow.python.lib.io.file_io import rename as Rename from tensorflow.python.lib.io.file_io import stat as Stat from tensorflow.python.lib.io.file_io import walk as Walk # pylint: enable=unused-import +from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -52,6 +53,7 @@ class GFile(_FileIO): @tf_export('gfile.FastGFile') +@deprecated(None, 'Use tf.gfile.GFile.') class FastGFile(_FileIO): """File I/O wrappers without thread locking. -- cgit v1.2.3 From 6c14d85b41c565ed9dabc3677aedf76757097242 Mon Sep 17 00:00:00 2001 From: rasmi Date: Wed, 8 Aug 2018 16:35:12 -0700 Subject: Changed order of export and deprecated decorators. 
--- tensorflow/python/platform/gfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index 510701e344..ac53609434 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -52,8 +52,8 @@ class GFile(_FileIO): super(GFile, self).__init__(name=name, mode=mode) -@tf_export('gfile.FastGFile') @deprecated(None, 'Use tf.gfile.GFile.') +@tf_export('gfile.FastGFile') class FastGFile(_FileIO): """File I/O wrappers without thread locking. -- cgit v1.2.3 From c3c6c45987692e8bc73eff2f10f9ec1a82f55287 Mon Sep 17 00:00:00 2001 From: rasmi Date: Thu, 9 Aug 2018 10:27:37 -0700 Subject: Moved @deprecated decorator to __init__ --- tensorflow/python/platform/gfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index ac53609434..5927bc2409 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -52,7 +52,6 @@ class GFile(_FileIO): super(GFile, self).__init__(name=name, mode=mode) -@deprecated(None, 'Use tf.gfile.GFile.') @tf_export('gfile.FastGFile') class FastGFile(_FileIO): """File I/O wrappers without thread locking. @@ -64,6 +63,7 @@ class FastGFile(_FileIO): invocations in network filesystems). 
""" + @deprecated(None, 'Use tf.gfile.GFile.') def __init__(self, name, mode='r'): super(FastGFile, self).__init__(name=name, mode=mode) -- cgit v1.2.3 From 22ebbbc60e5d94d67cdf6c26b44919f7dbb8f600 Mon Sep 17 00:00:00 2001 From: feiquan Date: Mon, 13 Aug 2018 23:44:38 +0800 Subject: extends the tensor index operator to support character access --- tensorflow/contrib/autograph/operators/slices.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index 04fbeb2f6e..d878bddf3c 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import gen_string_ops # TODO(mdan): Support extended slices. 
@@ -57,6 +58,8 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) + if target.dtype == dtypes.string: + return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) else: @@ -81,6 +84,10 @@ def _tf_tensor_get_item(target, i): """Overload of get_item that stages a Tensor (not Tensor list) read.""" return target[i] +def _tf_tensor_string_get_item(target, i): + """Overload of get_item that stages a Tensor string read.""" + x = gen_string_ops.substr(target, i, 1) + return x def _py_get_item(target, i): """Overload of get_item that executes a Python list modification.""" -- cgit v1.2.3 From 349d81c80a5b64ae09a36624571ec24d9e7a8b1d Mon Sep 17 00:00:00 2001 From: feiquan Date: Tue, 14 Aug 2018 00:07:28 +0800 Subject: add test for gen_item_tensor_string --- tensorflow/contrib/autograph/operators/slices_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index d4aacb9d20..9c0b2c77a1 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -46,6 +46,13 @@ class SlicesTest(test.TestCase): with self.test_session() as sess: self.assertAllEqual(sess.run(t), [3, 4]) + def test_get_item_tensor_string(self): + initial_str = constant_op.constant("abcd") + t = slices.get_item(initial_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) + + with self.test_session() as sess: + self.assertEqual(sess.run(t), b"b") + if __name__ == '__main__': test.main() -- cgit v1.2.3 From 48aef32dcd356fa6bae490fa1c853b9b2cdd4846 Mon Sep 17 00:00:00 2001 From: kouml Date: Wed, 15 Aug 2018 02:27:32 +0900 Subject: removing redundant semicolon --- tensorflow/contrib/lite/toco/python/toco_from_protos_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py index 3761e0095e..75c1c8970c 100644 --- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py @@ -50,7 +50,7 @@ class TocoFromProtosTest(googletest.TestCase): toco_flags.output_format = toco_flags_pb2.TFLITE toco_flags.inference_input_type = types_pb2.FLOAT toco_flags.inference_type = types_pb2.FLOAT - toco_flags.allow_custom_ops = True; + toco_flags.allow_custom_ops = True model_flags = model_flags_pb2.ModelFlags() input_array = model_flags.input_arrays.add() input_array.name = TensorName(in_tensor) -- cgit v1.2.3 From f2134cbd2ec4dd98f9f20ac41e4f46cdd0246af2 Mon Sep 17 00:00:00 2001 From: feiquan Date: Wed, 15 Aug 2018 08:47:22 +0800 Subject: use get_item_tensor_string for string with rank 0 --- tensorflow/contrib/autograph/operators/slices_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 9c0b2c77a1..5300428462 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -53,6 +53,12 @@ class SlicesTest(test.TestCase): with self.test_session() as sess: self.assertEqual(sess.run(t), b"b") + initial_list_str = constant_op.constant(["abcd", "bcde"]) + t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) + + with self.test_session() as sess: + self.assertEqual(sess.run(t), b"bcde") + if __name__ == '__main__': test.main() -- cgit v1.2.3 From 1843dc2bef2beabc1ac6765c14e03b1a07823bef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Jul 2018 14:43:28 -0700 Subject: Network.to_json should handle numpy.ndarray correctly --- tensorflow/python/keras/engine/network.py | 5 ++++- tensorflow/python/keras/engine/topology_test.py | 22 
++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 708fa1c807..3cdd714d7e 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1574,7 +1574,10 @@ class Network(base_layer.Layer): def get_json_type(obj): # If obj is any numpy type if type(obj).__module__ == np.__name__: - return obj.item() + if isinstance(obj, np.ndarray): + return obj.tolist() + else: + return obj.item() # If obj is a python 'type' if type(obj).__name__ == type.__name__: diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 079c8dae71..3dfa933913 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -913,6 +913,28 @@ class TopologyConstructionTest(test.TestCase): self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4) + def test_constant_initializer_with_numpy(self): + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.Dense( + 2, + input_shape = (3,), + kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) + ) + ) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + json_str = model.to_json() + keras.models.model_from_json(json_str) + + if yaml is not None: + yaml_str = model.to_yaml() + keras.models.model_from_yaml(yaml_str) + + class DeferredModeTest(test.TestCase): def testDeferredTensorAttributes(self): -- cgit v1.2.3 From 5ef4de5b01d10c4dae86a1e69cf1296671d55e47 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 15 Aug 2018 17:40:22 -0700 Subject: Fix bad indentation --- tensorflow/python/keras/engine/topology_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py 
index 3dfa933913..25ae3a61c3 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -918,11 +918,11 @@ class TopologyConstructionTest(test.TestCase): with self.test_session(): model = keras.models.Sequential() model.add( - keras.layers.Dense( - 2, - input_shape = (3,), - kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) - ) + keras.layers.Dense( + 2, + input_shape = (3,), + kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) + ) ) model.add(keras.layers.Dense(3)) model.compile(loss='mse', optimizer='sgd', metrics=['acc']) -- cgit v1.2.3 From 4a1fdff581db18e3262daebbc1f9543936bf47d1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 16 Aug 2018 13:14:34 -0700 Subject: Reorg code to escape bad indentation. --- tensorflow/python/keras/engine/topology_test.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 25ae3a61c3..1fcd77d7f6 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -912,18 +912,13 @@ class TopologyConstructionTest(test.TestCase): assert out.shape == (4, 3, 2, 1) self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4) - def test_constant_initializer_with_numpy(self): with self.test_session(): + initializer = keras.initializers.Constant(np.ones((3, 2))) model = keras.models.Sequential() - model.add( - keras.layers.Dense( - 2, - input_shape = (3,), - kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) - ) - ) + model.add(keras.layers.Dense(2, input_shape=(3,), + kernel_initializer=initializer)) model.add(keras.layers.Dense(3)) model.compile(loss='mse', optimizer='sgd', metrics=['acc']) -- cgit v1.2.3 From 4c2f6aeaaf4aeafccc85a289a5a105d52738b410 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Aug 2018 17:06:47 -0400 Subject: Simplyfing the 
evaluation step by taking argmax of the softmax of the predictions instead of tf.multinomial --- .../examples/generative_examples/image_captioning_with_attention.ipynb | 2 +- .../eager/python/examples/generative_examples/text_generation.ipynb | 2 +- .../eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 315d7a4893..e0f7137184 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 40bc098724..b13e5aae9b 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden 
state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index f1e1f99c57..3e02d9fbb0 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", -- cgit v1.2.3 From c36ff7ae1d667979fa49899bf97de26cf35321de Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Aug 2018 20:44:14 -0400 Subject: Removing tf.nn.softmax --- .../examples/generative_examples/image_captioning_with_attention.ipynb | 2 +- .../eager/python/examples/generative_examples/text_generation.ipynb | 2 +- .../eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index e0f7137184..5c753ec0f5 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if 
index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index b13e5aae9b..e0d5e494d4 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 3e02d9fbb0..560fc8c5a2 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", -- cgit v1.2.3 From e357bcea4b10d5e5cbc3a4ba59385e832401ba8d Mon Sep 17 00:00:00 2001 From: Dao Zhang Date: Thu, 23 Aug 2018 20:11:10 +0800 Subject: merge_repeated option is confusing I have the same question with [WIP: Remove invalid merge_repeated option from CTC beam decoder](https://github.com/tensorflow/tensorflow/pull/15586), it's a pity I haven't seen any changes for so long. 
Generally I use the default value of merge_repeated (True), but I found it confusing: it gave me the wrong answer, as explained well in [WIP: Remove invalid merge_repeated option from CTC beam decoder](https://github.com/tensorflow/tensorflow/pull/15586). Also, the top path in ctc_beam_search_decoder is similar to the sequence in ctc_greedy_decoder, which is confusing; I have found that the project [CRNN](https://github.com/Belval/CRNN/blob/master/CRNN/crnn.py) (line 167) and some other projects use the wrong setting. So I think it is better to add an explanation here; this does not conflict with the existing code. --- tensorflow/python/ops/ctc_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 908e793902..6bfe405b2b 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, If `merge_repeated` is `True`, merge repeated classes in the output beams. This means that if consecutive entries in a beam are the same, - only the first of these is emitted. That is, when the top path - is `A B B B B`, the return value is: + only the first of these is emitted. That is, when the sequence is `A B B * B * B` (where '*' + is the blank label), the return value is: * `A B` if `merge_repeated = True`. - * `A B B B B` if `merge_repeated = False`. + * `A B B B` if `merge_repeated = False`. Args: inputs: 3-D `float` `Tensor`, size -- cgit v1.2.3 From 512f95d4b5e350fa0709aeef975730f22112b970 Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Fri, 24 Aug 2018 11:34:10 -0700 Subject: [Intel MKL] Adding cc tests to the MKL public CI tests. 
--- tensorflow/tools/ci_build/linux/cpu/run_mkl.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh index 2a9f295188..7be5f454ec 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh @@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py # Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution # in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads # caused by executing multiple tests concurrently. -bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --config=mkl --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... -- cgit v1.2.3 From a7deb79f258a5dded26fcf85e9416f8463def451 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 11 Jul 2018 11:24:58 +0800 Subject: [XLA/AOT] Build LLVM with Bazel on Windows --- third_party/llvm/llvm.bzl | 170 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 141 insertions(+), 29 deletions(-) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index d493a3c476..626e0db3b1 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -150,6 +150,35 @@ def expand_cmake_vars(name, src, dst, cmake_vars): # The set of CMake variables common to all targets. 
cmake_vars = { + # LLVM features + "ENABLE_BACKTRACES": 1, + "LLVM_BINDIR": "/dev/null", + "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, + "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, + "LLVM_ENABLE_THREADS": 1, + "LLVM_ENABLE_ZLIB": 1, + "LLVM_HAS_ATOMICS": 1, + "LLVM_INCLUDEDIR": "/dev/null", + "LLVM_INFODIR": "/dev/null", + "LLVM_MANDIR": "/dev/null", + "LLVM_NATIVE_TARGET": 1, + "LLVM_NATIVE_TARGETINFO": 1, + "LLVM_NATIVE_TARGETMC": 1, + "LLVM_NATIVE_ASMPRINTER": 1, + "LLVM_NATIVE_ASMPARSER": 1, + "LLVM_NATIVE_DISASSEMBLER": 1, + "LLVM_PREFIX": "/dev/null", + "LLVM_VERSION_MAJOR": 0, + "LLVM_VERSION_MINOR": 0, + "LLVM_VERSION_PATCH": 0, + "PACKAGE_NAME": "llvm", + "PACKAGE_STRING": "llvm tensorflow-trunk", + "PACKAGE_VERSION": "tensorflow-trunk", + "RETSIGTYPE": "void", +} + +# The set of CMake variables common to POSIX targets. +posix_cmake_vars = { # Headers "HAVE_DIRENT_H": 1, "HAVE_DLFCN_H": 1, @@ -206,32 +235,8 @@ cmake_vars = { "HAVE__UNWIND_BACKTRACE": 1, # LLVM features - "ENABLE_BACKTRACES": 1, - "LLVM_BINDIR": "/dev/null", - "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, - "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, - "LLVM_ENABLE_THREADS": 1, - "LLVM_ENABLE_ZLIB": 1, - "LLVM_HAS_ATOMICS": 1, - "LLVM_INCLUDEDIR": "/dev/null", - "LLVM_INFODIR": "/dev/null", - "LLVM_MANDIR": "/dev/null", - "LLVM_NATIVE_TARGET": 1, - "LLVM_NATIVE_TARGETINFO": 1, - "LLVM_NATIVE_TARGETMC": 1, - "LLVM_NATIVE_ASMPRINTER": 1, - "LLVM_NATIVE_ASMPARSER": 1, - "LLVM_NATIVE_DISASSEMBLER": 1, "LLVM_ON_UNIX": 1, - "LLVM_PREFIX": "/dev/null", - "LLVM_VERSION_MAJOR": 0, - "LLVM_VERSION_MINOR": 0, - "LLVM_VERSION_PATCH": 0, "LTDL_SHLIB_EXT": ".so", - "PACKAGE_NAME": "llvm", - "PACKAGE_STRING": "llvm tensorflow-trunk", - "PACKAGE_VERSION": "tensorflow-trunk", - "RETSIGTYPE": "void", } # CMake variables specific to the Linux platform @@ -247,6 +252,40 @@ darwin_cmake_vars = { "HAVE_MALLOC_MALLOC_H": 1, } +# CMake variables specific to the Windows platform. 
+win32_cmake_vars = { + # Headers + "HAVE_ERRNO_H": 1, + "HAVE_EXECINFO_H": 1, + "HAVE_FCNTL_H": 1, + "HAVE_FENV_H": 1, + "HAVE_INTTYPES_H": 1, + "HAVE_MALLOC_H": 1, + "HAVE_SIGNAL_H": 1, + "HAVE_STDINT_H": 1, + "HAVE_SYS_STAT_H": 1, + "HAVE_SYS_TYPES_H": 1, + "HAVE_ZLIB_H": 1, + + # Features + "BACKTRACE_HEADER": "execinfo.h", + "HAVE_GETCWD": 1, + "HAVE_INT64_T": 1, + "HAVE_STRERROR": 1, + "HAVE_STRTOLL": 1, + "HAVE_SYSCONF": 1, + "HAVE_UINT64_T": 1, + "HAVE__CHSIZE_S": 1, + "HAVE___CHKSTK": 1, + + # MSVC specific + "stricmp": "_stricmp", + "strdup": "_strdup", + + # LLVM features + "LTDL_SHLIB_EXT": ".dll", +} + # Select a set of CMake variables based on the platform. # TODO(phawkins): use a better method to select the right host triple, rather # than hardcoding x86_64. @@ -265,6 +304,13 @@ llvm_all_cmake_vars = select({ linux_cmake_vars, ), ), + "@org_tensorflow//tensorflow:windows": cmake_var_string( + _dict_add( + cmake_vars, + llvm_target_cmake_vars("X86", "x86_64-pc-win32"), + win32_cmake_vars, + ), + ), "//conditions:default": cmake_var_string( _dict_add( cmake_vars, @@ -274,23 +320,89 @@ llvm_all_cmake_vars = select({ ), }) -llvm_linkopts = ["-ldl", "-lm", "-lpthread"] +llvm_linkopts = select({ + "@org_tensorflow//tensorflow:windows": [], + "//conditions:default": ["-ldl", "-lm", "-lpthread"], +}) -llvm_defines = [ - "LLVM_ENABLE_STATS", +llvm_defines = select({ + "@org_tensorflow//tensorflow:windows": [ + "_CRT_SECURE_NO_DEPRECATE", + "_CRT_SECURE_NO_WARNINGS", + "_CRT_NONSTDC_NO_DEPRECATE", + "_CRT_NONSTDC_NO_WARNINGS", + "_SCL_SECURE_NO_DEPRECATE", + "_SCL_SECURE_NO_WARNINGS", + "UNICODE", + "_UNICODE", + ], + "//conditions:default": ["_DEBUG"], +}) + [ "__STDC_LIMIT_MACROS", "__STDC_CONSTANT_MACROS", "__STDC_FORMAT_MACROS", - "_DEBUG", "LLVM_BUILD_GLOBAL_ISEL", ] -llvm_copts = [] +llvm_copts = select({ + "@org_tensorflow//tensorflow:windows": [ + "-Zc:inline", + "-Zc:strictStrings", + "-Zc:rvalueCast", + "-Oi", + "-wd4141", + "-wd4146", + "-wd4180", + 
"-wd4244", + "-wd4258", + "-wd4267", + "-wd4291", + "-wd4345", + "-wd4351", + "-wd4355", + "-wd4456", + "-wd4457", + "-wd4458", + "-wd4459", + "-wd4503", + "-wd4624", + "-wd4722", + "-wd4800", + "-wd4100", + "-wd4127", + "-wd4512", + "-wd4505", + "-wd4610", + "-wd4510", + "-wd4702", + "-wd4245", + "-wd4706", + "-wd4310", + "-wd4701", + "-wd4703", + "-wd4389", + "-wd4611", + "-wd4805", + "-wd4204", + "-wd4577", + "-wd4091", + "-wd4592", + "-wd4319", + "-wd4324", + "-w14062", + "-we4238", + ], + "//conditions:default": [], +}) # Platform specific sources for libSupport. def llvm_support_platform_specific_srcs_glob(): return select({ + "@org_tensorflow//tensorflow:windows": native.glob([ + "lib/Support/Windows/*.inc", + "lib/Support/Windows/*.h" + ]), "//conditions:default": native.glob([ "lib/Support/Unix/*.inc", "lib/Support/Unix/*.h", -- cgit v1.2.3 From 4a4ce8c6bff872f2a5522b289845491ea2da6f1e Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 11 Jul 2018 11:32:54 +0800 Subject: Add back LLVM_ENABLE_STATS --- third_party/llvm/llvm.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 626e0db3b1..6da3e0755c 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -338,6 +338,7 @@ llvm_defines = select({ ], "//conditions:default": ["_DEBUG"], }) + [ + "LLVM_ENABLE_STATS", "__STDC_LIMIT_MACROS", "__STDC_CONSTANT_MACROS", "__STDC_FORMAT_MACROS", -- cgit v1.2.3 From d0b4230bc3052f080c901f7d999cf848c7d81450 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Sat, 11 Aug 2018 18:11:47 +0800 Subject: Actually add posix_cmake_vars --- third_party/llvm/llvm.bzl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 6da3e0755c..586935b6e6 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -294,6 +294,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("X86", "x86_64-apple-darwin"), + 
posix_cmake_vars, darwin_cmake_vars, ), ), @@ -301,6 +302,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu"), + posix_cmake_vars, linux_cmake_vars, ), ), @@ -315,6 +317,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu"), + posix_cmake_vars, linux_cmake_vars, ), ), -- cgit v1.2.3 From b146281fd7f11325251fb085aca6bda8e2d77bfd Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Mon, 27 Aug 2018 11:33:21 -0700 Subject: [Intel MKL] Using default CPU allocator for small allocations in MklCPUAllocator This PR adds support to use default CPU allocator for handling small-size allocations. We found that BFC allocator does not do well on small allocations, but is good for large allocations. --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 177 +++++++++++++++++++-- 1 file changed, 168 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 99bd43e090..2778213a82 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/platform/mutex.h" #ifndef INTEL_MKL_DNN_ONLY #include "i_malloc.h" @@ -48,6 +50,120 @@ class MklSubAllocator : public SubAllocator { void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } }; +/// CPU allocator that handles small-size allocations by calling +/// suballocator directly. Mostly, it is just a wrapper around a suballocator +/// (that calls malloc and free directly) with support for bookkeeping. 
+class MklSmallSizeAllocator : public VisitableAllocator { + public: + MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, + const string& name) : sub_allocator_(sub_allocator), + name_(name) { + stats_.bytes_limit = total_memory; + } + ~MklSmallSizeAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator); + + inline string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = nullptr; + if ((ptr = sub_allocator_->Alloc(alignment, num_bytes)) != nullptr) { + std::pair map_val(ptr, num_bytes); + mutex_lock l(mutex_); + // Check that insertion in the hash map was successful. + CHECK_EQ(map_.insert(map_val).second, true); + // Increment statistics for small-size allocations. + IncrementStats(num_bytes); + // Call alloc visitors. + for (const auto& visitor : alloc_visitors_) { + visitor(ptr, num_bytes); + } + } + return ptr; + } + + void DeallocateRaw(void* ptr) override { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + mutex_lock l(mutex_); + auto map_iter = map_.find(ptr); + if (map_iter != map_.end()) { + // Call free visitors. 
+ size_t dealloc_bytes = map_iter->second; + for (const auto& visitor : free_visitors_) { + visitor(ptr, dealloc_bytes); + } + sub_allocator_->Free(ptr, dealloc_bytes); + DecrementStats(dealloc_bytes); + map_.erase(map_iter); + } + } + + inline bool IsSmallSizeAllocation(const void* ptr) const { + mutex_lock l(mutex_); + return map_.find(ptr) != map_.end(); + } + + void GetStats(AllocatorStats* stats) override { + mutex_lock l(mutex_); + *stats = stats_; + } + + void ClearStats() override { + mutex_lock l(mutex_); + stats_.Clear(); + } + + void AddAllocVisitor(Visitor visitor) override { + mutex_lock l(mutex_); + alloc_visitors_.push_back(visitor); + } + + void AddFreeVisitor(Visitor visitor) override { + mutex_lock l(mutex_); + free_visitors_.push_back(visitor); + } + + private: + /// Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) { + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, + stats_.bytes_in_use); + stats_.max_alloc_size = std::max(alloc_size, + static_cast(stats_.max_alloc_size)); + } + + /// Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) { + stats_.bytes_in_use -= dealloc_size; + } + + SubAllocator* sub_allocator_; // Not owned by this class. + + /// Mutex for protecting updates to map of allocations. + mutable mutex mutex_; + + /// Allocator name + string name_; + + /// Hash map to keep track of "small" allocations + /// We do not use BFC allocator for small allocations. + std::unordered_map map_ GUARDED_BY(mutex_); + + /// Allocator stats for small allocs + AllocatorStats stats_ GUARDED_BY(mutex_); + + /// Visitors + std::vector alloc_visitors_ GUARDED_BY(mutex_); + std::vector free_visitors_ GUARDED_BY(mutex_); +}; + /// CPU allocator for MKL that wraps BFC allocator and intercepts /// and redirects memory allocation calls from MKL. 
class MklCPUAllocator : public VisitableAllocator { @@ -62,7 +178,10 @@ class MklCPUAllocator : public VisitableAllocator { MklCPUAllocator() { TF_CHECK_OK(Initialize()); } - ~MklCPUAllocator() override { delete allocator_; } + ~MklCPUAllocator() override { + delete small_size_allocator_; + delete large_size_allocator_; + } Status Initialize() { VLOG(2) << "MklCPUAllocator: In MklCPUAllocator"; @@ -96,7 +215,11 @@ class MklCPUAllocator : public VisitableAllocator { } VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; - allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes, + + sub_allocator_ = new MklSubAllocator(); + small_size_allocator_ = new MklSmallSizeAllocator(sub_allocator_, + max_mem_bytes, kName); + large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator @@ -112,23 +235,52 @@ class MklCPUAllocator : public VisitableAllocator { inline string Name() override { return kName; } inline void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return allocator_->AllocateRaw(alignment, num_bytes); + // If the allocation size is less than threshold, call small allocator, + // otherwise call large-size allocator (BFC). We found that BFC allocator + // does not deliver good performance for small allocations when + // inter_op_parallelism_threads is high. + return (num_bytes < kSmallAllocationsThreshold) ? + small_size_allocator_->AllocateRaw(alignment, num_bytes) : + large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { - allocator_->DeallocateRaw(ptr); + // Check if ptr is for "small" allocation. If it is, then call Free + // directly. Otherwise, call BFC to handle free. 
+ if (small_size_allocator_->IsSmallSizeAllocation(ptr)) { + small_size_allocator_->DeallocateRaw(ptr); + } else { + large_size_allocator_->DeallocateRaw(ptr); + } } - void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); } + void GetStats(AllocatorStats* stats) override { + AllocatorStats l_stats, s_stats; + small_size_allocator_->GetStats(&s_stats); + large_size_allocator_->GetStats(&l_stats); + + // Combine statistics from small-size and large-size allocator. + stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs; + stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; + stats->max_bytes_in_use = l_stats.max_bytes_in_use + + s_stats.max_bytes_in_use; + stats->max_alloc_size = std::max(l_stats.max_alloc_size, + s_stats.max_alloc_size); + } - void ClearStats() override { allocator_->ClearStats(); } + void ClearStats() override { + small_size_allocator_->ClearStats(); + large_size_allocator_->ClearStats(); + } void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); + small_size_allocator_->AddAllocVisitor(visitor); + large_size_allocator_->AddAllocVisitor(visitor); } void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); + small_size_allocator_->AddFreeVisitor(visitor); + large_size_allocator_->AddFreeVisitor(visitor); } private: @@ -165,7 +317,14 @@ class MklCPUAllocator : public VisitableAllocator { /// The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; - VisitableAllocator* allocator_; // owned by this class + VisitableAllocator* large_size_allocator_; // owned by this class + MklSmallSizeAllocator* small_size_allocator_; // owned by this class. + + SubAllocator* sub_allocator_; // not owned by this class + + /// Size in bytes that defines the upper-bound for "small" allocations. + /// Any allocation below this threshold is "small" allocation. 
+ static constexpr const size_t kSmallAllocationsThreshold = 4096; }; } // namespace tensorflow -- cgit v1.2.3 From f7d27bc67e5d89e5f4bb6d6a0a198c28fa8af46f Mon Sep 17 00:00:00 2001 From: Sangjung Woo Date: Thu, 30 Aug 2018 17:17:23 +0900 Subject: fix the comparison error when building a CPP API application When building a CPP API application with "-Wall -Werror" option , `error: comparison between signed and unsigned integer expressions' occurs since return type of num_elements() is 'int64' instead of 'size_t' in ops.h to express -1. This patch fixes this bug by explicit type casting. * related issue: https://github.com/tensorflow/tensorflow/issues/20428 Signed-off-by: Sangjung Woo --- tensorflow/cc/framework/ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2..0717e7dd4b 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); -- cgit v1.2.3 From 74af314e4573e168d38072f646495034412ff061 Mon Sep 17 00:00:00 2001 From: 在原佐为 Date: Mon, 3 Sep 2018 10:09:05 +0800 Subject: use single quotation marks for single-line strings --- tensorflow/contrib/autograph/operators/slices_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 5300428462..329d9f1f43 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -47,13 +47,13 @@ class 
SlicesTest(test.TestCase): self.assertAllEqual(sess.run(t), [3, 4]) def test_get_item_tensor_string(self): - initial_str = constant_op.constant("abcd") + initial_str = constant_op.constant('abcd') t = slices.get_item(initial_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: - self.assertEqual(sess.run(t), b"b") + self.assertEqual(sess.run(t), b'b') - initial_list_str = constant_op.constant(["abcd", "bcde"]) + initial_list_str = constant_op.constant(['abcd', 'bcde']) t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: -- cgit v1.2.3 From 752e94a7d73a5c11a1b51b08bc170b0d91724a1c Mon Sep 17 00:00:00 2001 From: 在原佐为 Date: Mon, 3 Sep 2018 10:09:44 +0800 Subject: use single quotation marks for single-line strings --- tensorflow/contrib/autograph/operators/slices_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 329d9f1f43..2c5ffed4f2 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -57,7 +57,7 @@ class SlicesTest(test.TestCase): t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: - self.assertEqual(sess.run(t), b"bcde") + self.assertEqual(sess.run(t), b'bcde') if __name__ == '__main__': -- cgit v1.2.3 From d118516dd6c5b9fd2f0bfa2b870e7cfb5063e7dc Mon Sep 17 00:00:00 2001 From: Roger Xin Date: Mon, 3 Sep 2018 11:52:42 -0400 Subject: Fix issues in maxout layer --- tensorflow/contrib/layers/python/layers/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 04668f112d..a82d4c1951 100644 --- 
a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -3109,7 +3109,7 @@ def maxout(inputs, num_units, axis=-1, scope=None): inputs: Tensor input num_units: Specifies how many features will remain after maxout in the `axis` dimension (usually channel). - This must be multiple of number of `axis`. + This must be a factor of number of features. axis: The dimension where max pooling will be performed. Default is the last dimension. scope: Optional scope for variable_scope. @@ -3128,7 +3128,7 @@ def maxout(inputs, num_units, axis=-1, scope=None): raise ValueError('number of features({}) is not ' 'a multiple of num_units({})'.format( num_channels, num_units)) - shape[axis] = -1 + shape[axis] = num_units shape += [num_channels // num_units] # Dealing with batches with arbitrary sizes -- cgit v1.2.3 From ce9e5b035b32ef02cd7d10f6ffdd27cc2a75664d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:40:41 +0000 Subject: Fix syntax error in single_image_random_dot_stereograms caused by locale This fix tries to address the issue raised in 21164 where the single_image_random_dot_stereograms in different locale (like de_DE) caused syntax error in python like: ``` File "", line 28 def single_image_random_dot_stereograms(depth_values, hidden_surface_removal=True, convergence_dots_size=8, dots_per_inch=72, eye_separation=2,5, mu=0,333299994, normalize=True, normalize_max=-100, normalize_min=100, border_level=0, number_colors=256, output_image_shape=[1024, 768, 1], output_data_window=[1022, 757], name=None): ^ SyntaxError: invalid syntax ``` The issue was that the float to string conversion in python_op_gen_internal.cc triggered snprintf (in `FloatToBuffer`) which is local dependent and generates something like `eye_separatiion=2,5` in DE locale. This fix replaced the float to string conversion with locale-independent ``` std::ostringstream s; s.imbue(std::locale::classic()); ``` This fix fixes 21164. 
Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index f2270342b0..8ddd1e6432 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -435,7 +435,10 @@ string AttrValueToPython(const string& type, const AttrValue& value, if (std::isnan(value.f()) || std::isinf(value.f())) { return strings::StrCat("float('", value.f(), "')"); } else { - return strings::StrCat(value.f()); + std::ostringstream s; + s.imbue(std::locale::classic()); + s << value.f(); + return s.str(); } } else if (type == "bool") { return value.b() ? "True" : "False"; -- cgit v1.2.3 From a8a0ec4a2eaf37c853afe410964978715c3d02bb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:55:43 +0000 Subject: Add precision to match the existing behavior. Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index 8ddd1e6432..dafaf2fd3a 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen_internal.h" #include +#include +#include #include #include #include "tensorflow/core/framework/api_def.pb.h" @@ -435,9 +437,11 @@ string AttrValueToPython(const string& type, const AttrValue& value, if (std::isnan(value.f()) || std::isinf(value.f())) { return strings::StrCat("float('", value.f(), "')"); } else { + // Use locale-independent conversion. 
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big"); std::ostringstream s; s.imbue(std::locale::classic()); - s << value.f(); + s << std::setprecision(FLT_DIG) << value.f(); return s.str(); } } else if (type == "bool") { -- cgit v1.2.3 From 569426a13fbae66c0acd7ed728a62f413407b898 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:58:35 +0000 Subject: Sanitize with clang-foramt Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index dafaf2fd3a..7c4941a586 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen_internal.h" -#include #include +#include #include #include #include -- cgit v1.2.3 From ad997f1c24829dbe3c687d449a757202c401bb6f Mon Sep 17 00:00:00 2001 From: 在原佐为 Date: Tue, 4 Sep 2018 23:25:30 +0800 Subject: only apply _string_get_item for string with rank 0 --- tensorflow/contrib/autograph/operators/slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index d878bddf3c..a885bdab5b 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -58,7 +58,7 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) - if target.dtype == dtypes.string: + elif target.dtype == dtypes.string and target.get_shape() == (): # target is string with rank 0 return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) -- cgit v1.2.3 From 9c7ca4c83b2e98517d0ccbba81b6b7fbc178d731 Mon Sep 
17 00:00:00 2001 From: 在原佐为 Date: Wed, 5 Sep 2018 08:15:42 +0800 Subject: use ndims --- tensorflow/contrib/autograph/operators/slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index a885bdab5b..4b3f7ebee8 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -58,7 +58,7 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) - elif target.dtype == dtypes.string and target.get_shape() == (): # target is string with rank 0 + elif target.dtype == dtypes.string and target.shape.ndims == 0: # target is string with rank 0 return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) -- cgit v1.2.3 From f00855ee9c8ae8878a2feca7c2c8a23e4b9c6c11 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 5 Sep 2018 06:06:23 +0000 Subject: Update include order of the header files in python_op_gen_internal.cc, to conform to `Experimental clang-format Check` Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index 7c4941a586..f6aef5bc50 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -23,12 +23,12 @@ limitations under the License. 
#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -- cgit v1.2.3 From 32e96b1dc588cccf4e008259f831c4e50d948dc7 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Wed, 5 Sep 2018 15:46:09 +0800 Subject: ENH: add gradient for broadcast_to --- .../python/kernel_tests/broadcast_to_ops_test.py | 20 ++++++++++++++++++++ tensorflow/python/ops/array_grad.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 6a1bd958ba..282a619094 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test as test_lib @@ -81,5 +82,24 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testGradient(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 
* v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 6ae869b89e..ade86e85bf 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad): indices = op.inputs[1] updates_grad = array_ops.gather_nd(grad, indices) return [grad, None, updates_grad] + + +@ops.RegisterGradient("BroadcastTo") +def _BroadcastToGrad(op, grad): + input_value = op.inputs[0] + broadcast_shape = op.inputs[1] + # Assign ids for each position in input_value. + input_value_shape = array_ops.shape(input_value) + input_value_size = array_ops.size(input_value) + ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape) + broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape) + # Group by ids and sum its gradients. 
+ grad_flatten = array_ops.reshape(grad, [-1]) + broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1]) + updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten, + broadcast_ids_flatten, + input_value_size) + updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape) + return [updates_grad, None] -- cgit v1.2.3 From 24bd1154b3c83cbf07883010240c3d1d13e25833 Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Wed, 5 Sep 2018 15:28:00 -0700 Subject: Addressing review comments --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 55 +++++++++++++--------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 2778213a82..553f07020e 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -50,9 +50,9 @@ class MklSubAllocator : public SubAllocator { void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } }; -/// CPU allocator that handles small-size allocations by calling -/// suballocator directly. Mostly, it is just a wrapper around a suballocator -/// (that calls malloc and free directly) with support for bookkeeping. +// CPU allocator that handles small-size allocations by calling +// suballocator directly. Mostly, it is just a wrapper around a suballocator +// (that calls malloc and free directly) with support for bookkeeping. 
class MklSmallSizeAllocator : public VisitableAllocator { public: MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, @@ -67,12 +67,12 @@ class MklSmallSizeAllocator : public VisitableAllocator { inline string Name() override { return name_; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { - void* ptr = nullptr; - if ((ptr = sub_allocator_->Alloc(alignment, num_bytes)) != nullptr) { + void* ptr = sub_allocator_->Alloc(alignment, num_bytes); + if (ptr != nullptr) { std::pair map_val(ptr, num_bytes); mutex_lock l(mutex_); // Check that insertion in the hash map was successful. - CHECK_EQ(map_.insert(map_val).second, true); + CHECK(map_.insert(map_val).second); // Increment statistics for small-size allocations. IncrementStats(num_bytes); // Call alloc visitors. @@ -100,6 +100,9 @@ class MklSmallSizeAllocator : public VisitableAllocator { sub_allocator_->Free(ptr, dealloc_bytes); DecrementStats(dealloc_bytes); map_.erase(map_iter); + } else { + LOG(ERROR) << "tried to deallocate invalid pointer"; + return; } } @@ -129,8 +132,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { } private: - /// Increment statistics for the allocator handling small allocations. - inline void IncrementStats(size_t alloc_size) { + // Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) GUARDED_BY(mutex_) { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, @@ -139,27 +142,27 @@ class MklSmallSizeAllocator : public VisitableAllocator { static_cast(stats_.max_alloc_size)); } - /// Decrement statistics for the allocator handling small allocations. - inline void DecrementStats(size_t dealloc_size) { + // Decrement statistics for the allocator handling small allocations. 
+ inline void DecrementStats(size_t dealloc_size) GUARDED_BY(mutex_) { stats_.bytes_in_use -= dealloc_size; } SubAllocator* sub_allocator_; // Not owned by this class. - /// Mutex for protecting updates to map of allocations. + // Mutex for protecting updates to map of allocations. mutable mutex mutex_; - /// Allocator name + // Allocator name string name_; - /// Hash map to keep track of "small" allocations - /// We do not use BFC allocator for small allocations. + // Hash map to keep track of "small" allocations + // We do not use BFC allocator for small allocations. std::unordered_map map_ GUARDED_BY(mutex_); - /// Allocator stats for small allocs + // Allocator stats for small allocs AllocatorStats stats_ GUARDED_BY(mutex_); - /// Visitors + // Visitors std::vector alloc_visitors_ GUARDED_BY(mutex_); std::vector free_visitors_ GUARDED_BY(mutex_); }; @@ -217,6 +220,9 @@ class MklCPUAllocator : public VisitableAllocator { VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; sub_allocator_ = new MklSubAllocator(); + + // SubAllocator is owned by BFCAllocator, so we do not need to deallocate + // it in MklSmallSizeAllocator. small_size_allocator_ = new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, @@ -264,8 +270,11 @@ class MklCPUAllocator : public VisitableAllocator { stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; stats->max_bytes_in_use = l_stats.max_bytes_in_use + s_stats.max_bytes_in_use; - stats->max_alloc_size = std::max(l_stats.max_alloc_size, - s_stats.max_alloc_size); + + // Since small-size allocations go to MklSmallSizeAllocator, + // max_alloc_size from large_size_allocator would be the maximum + // size allocated by MklCPUAllocator. 
+ stats->max_alloc_size = l_stats.max_alloc_size; } void ClearStats() override { @@ -308,13 +317,13 @@ class MklCPUAllocator : public VisitableAllocator { TF_CHECK_OK(s); // way to assert with an error message } - /// Do we allow growth in BFC Allocator + // Do we allow growth in BFC Allocator static const bool kAllowGrowth = true; - /// Name + // Name static constexpr const char* kName = "mklcpu"; - /// The alignment that we need for the allocations + // The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; VisitableAllocator* large_size_allocator_; // owned by this class @@ -322,8 +331,8 @@ class MklCPUAllocator : public VisitableAllocator { SubAllocator* sub_allocator_; // not owned by this class - /// Size in bytes that defines the upper-bound for "small" allocations. - /// Any allocation below this threshold is "small" allocation. + // Size in bytes that defines the upper-bound for "small" allocations. + // Any allocation below this threshold is "small" allocation. static constexpr const size_t kSmallAllocationsThreshold = 4096; }; -- cgit v1.2.3 From 352d2a0a2a099ae830855c94a30f9ea657556aef Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Wed, 5 Sep 2018 16:35:38 -0700 Subject: Addressing review comments --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 553f07020e..200ca57a9a 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -133,7 +133,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { private: // Increment statistics for the allocator handling small allocations. 
- inline void IncrementStats(size_t alloc_size) GUARDED_BY(mutex_) { + inline void + IncrementStats(size_t alloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, @@ -143,7 +144,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { } // Decrement statistics for the allocator handling small allocations. - inline void DecrementStats(size_t dealloc_size) GUARDED_BY(mutex_) { + inline void + DecrementStats(size_t dealloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { stats_.bytes_in_use -= dealloc_size; } -- cgit v1.2.3 From d41f5ffb9cdc1c047db2f7b8a71ef24d39d12fb0 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 4 Jul 2018 09:04:57 +0800 Subject: [Bazel/MSVC] Enable jpeg SIMD for MSVC - Add config/msvc.h when building nasm on Windows - Update Windows SIMD for libjpeg-turbo 2.0.0 - Add missing source files --- third_party/jpeg/jpeg.BUILD | 139 +++++++++++++++++++++++++++++++++++++++++++- third_party/nasm.BUILD | 5 +- 2 files changed, 141 insertions(+), 3 deletions(-) diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 5edf4f8120..1b9b9bf2f5 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -11,8 +11,8 @@ libjpegturbo_nocopts = "-[W]error" WIN_COPTS = [ "/Ox", - "/w14711", # function 'function' selected for inline expansion - "/w14710", # 'function' : function not inlined + "-DWITH_SIMD", + "-wd4996", ] libjpegturbo_copts = select({ @@ -127,6 +127,7 @@ cc_library( ":armeabi-v7a": [":simd_armv7a"], ":arm64-v8a": [":simd_armv8a"], ":linux_ppc64le": [":simd_altivec"], + ":windows": [":simd_win_x86_64"], "//conditions:default": [":simd_none"], }), ) @@ -350,6 +351,140 @@ cc_library( nocopts = libjpegturbo_nocopts, ) +cc_library( + name = "simd_win_x86_64", + srcs = [ + "jchuff.h", + "jconfig.h", + "jconfigint.h", + "jdct.h", + "jerror.h", + "jinclude.h", + "jmorecfg.h", + "jpegint.h", + "jpeglib.h", + 
"jsimd.h", + "jsimddct.h", + "simd/jsimd.h", + "simd/x86_64/jsimd.c", + "simd/x86_64/jccolor-avx2.obj", + "simd/x86_64/jccolor-sse2.obj", + "simd/x86_64/jcgray-avx2.obj", + "simd/x86_64/jcgray-sse2.obj", + "simd/x86_64/jchuff-sse2.obj", + "simd/x86_64/jcphuff-sse2.obj", + "simd/x86_64/jcsample-avx2.obj", + "simd/x86_64/jcsample-sse2.obj", + "simd/x86_64/jdcolor-avx2.obj", + "simd/x86_64/jdcolor-sse2.obj", + "simd/x86_64/jdmerge-avx2.obj", + "simd/x86_64/jdmerge-sse2.obj", + "simd/x86_64/jdsample-avx2.obj", + "simd/x86_64/jdsample-sse2.obj", + "simd/x86_64/jfdctflt-sse.obj", + "simd/x86_64/jfdctfst-sse2.obj", + "simd/x86_64/jfdctint-avx2.obj", + "simd/x86_64/jfdctint-sse2.obj", + "simd/x86_64/jidctflt-sse2.obj", + "simd/x86_64/jidctfst-sse2.obj", + "simd/x86_64/jidctint-avx2.obj", + "simd/x86_64/jidctint-sse2.obj", + "simd/x86_64/jidctred-sse2.obj", + "simd/x86_64/jquantf-sse2.obj", + "simd/x86_64/jquanti-avx2.obj", + "simd/x86_64/jquanti-sse2.obj", + "simd/x86_64/jsimdcpu.obj", + ], + copts = libjpegturbo_copts, +) + +genrule( + name = "simd_win_x86_64_assemble", + srcs = [ + "jconfig.h", + "jconfigint.h", + "simd/x86_64/jccolext-avx2.asm", + "simd/x86_64/jccolext-sse2.asm", + "simd/x86_64/jccolor-avx2.asm", + "simd/x86_64/jccolor-sse2.asm", + "simd/x86_64/jcgray-avx2.asm", + "simd/x86_64/jcgray-sse2.asm", + "simd/x86_64/jcgryext-avx2.asm", + "simd/x86_64/jcgryext-sse2.asm", + "simd/x86_64/jchuff-sse2.asm", + "simd/x86_64/jcphuff-sse2.asm", + "simd/x86_64/jcsample-avx2.asm", + "simd/x86_64/jcsample-sse2.asm", + "simd/x86_64/jdcolext-avx2.asm", + "simd/x86_64/jdcolext-sse2.asm", + "simd/x86_64/jdcolor-avx2.asm", + "simd/x86_64/jdcolor-sse2.asm", + "simd/x86_64/jdmerge-avx2.asm", + "simd/x86_64/jdmerge-sse2.asm", + "simd/x86_64/jdmrgext-avx2.asm", + "simd/x86_64/jdmrgext-sse2.asm", + "simd/x86_64/jdsample-avx2.asm", + "simd/x86_64/jdsample-sse2.asm", + "simd/x86_64/jfdctflt-sse.asm", + "simd/x86_64/jfdctfst-sse2.asm", + "simd/x86_64/jfdctint-avx2.asm", + 
"simd/x86_64/jfdctint-sse2.asm", + "simd/x86_64/jidctflt-sse2.asm", + "simd/x86_64/jidctfst-sse2.asm", + "simd/x86_64/jidctint-avx2.asm", + "simd/x86_64/jidctint-sse2.asm", + "simd/x86_64/jidctred-sse2.asm", + "simd/x86_64/jquantf-sse2.asm", + "simd/x86_64/jquanti-avx2.asm", + "simd/x86_64/jquanti-sse2.asm", + "simd/x86_64/jsimdcpu.asm", + "simd/nasm/jcolsamp.inc", + "simd/nasm/jdct.inc", + "simd/nasm/jpeg_nbits_table.inc", + "simd/nasm/jsimdcfg.inc", + "simd/nasm/jsimdcfg.inc.h", + "simd/nasm/jsimdext.inc", + ], + outs = [ + "simd/x86_64/jccolor-avx2.obj", + "simd/x86_64/jccolor-sse2.obj", + "simd/x86_64/jcgray-avx2.obj", + "simd/x86_64/jcgray-sse2.obj", + "simd/x86_64/jchuff-sse2.obj", + "simd/x86_64/jcphuff-sse2.obj", + "simd/x86_64/jcsample-avx2.obj", + "simd/x86_64/jcsample-sse2.obj", + "simd/x86_64/jdcolor-avx2.obj", + "simd/x86_64/jdcolor-sse2.obj", + "simd/x86_64/jdmerge-avx2.obj", + "simd/x86_64/jdmerge-sse2.obj", + "simd/x86_64/jdsample-avx2.obj", + "simd/x86_64/jdsample-sse2.obj", + "simd/x86_64/jfdctflt-sse.obj", + "simd/x86_64/jfdctfst-sse2.obj", + "simd/x86_64/jfdctint-avx2.obj", + "simd/x86_64/jfdctint-sse2.obj", + "simd/x86_64/jidctflt-sse2.obj", + "simd/x86_64/jidctfst-sse2.obj", + "simd/x86_64/jidctint-avx2.obj", + "simd/x86_64/jidctint-sse2.obj", + "simd/x86_64/jidctred-sse2.obj", + "simd/x86_64/jquantf-sse2.obj", + "simd/x86_64/jquanti-avx2.obj", + "simd/x86_64/jquanti-sse2.obj", + "simd/x86_64/jsimdcpu.obj", + ], + cmd = "for out in $(OUTS); do\n" + + " $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" + + " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + + " -I $$(dirname $(location simd/nasm/jdct.inc))/" + + " -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" + + " -o $$out" + + " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" + + "done", + tools = ["@nasm"], +) + cc_library( name = "simd_none", srcs = [ diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD index 
2b877883b9..d746a65e7e 100644 --- a/third_party/nasm.BUILD +++ b/third_party/nasm.BUILD @@ -133,7 +133,10 @@ cc_binary( "x86/regs.c", "x86/regs.h", "x86/regvals.c", - ], + ] + select({ + ":windows": ["config/msvc.h"], + "//conditions:default": [], + }), includes = [ "asm", "include", -- cgit v1.2.3 From 8859ee06cc0cba03d05ce9677b05ff1993c34b03 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Thu, 6 Sep 2018 22:45:25 +0800 Subject: TST: add more test cases --- .../python/kernel_tests/broadcast_to_ops_test.py | 30 ++++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 282a619094..8bcf27466c 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -82,8 +82,8 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) - def testGradient(self): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) v = array_ops.broadcast_to(x, [2, 4, 3]) out = 2 * v with self.test_session(): @@ -91,9 +91,29 @@ class BroadcastToTest(test_util.TensorFlowTestCase): out, out.get_shape()) self.assertLess(err, 1e-4) - def testGradientForScalar(self): - x = constant_op.constant(1, dtype=dtypes.float32) - v = array_ops.broadcast_to(x, [2, 4, 3]) + def testGradientWithSameRank(self): + x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)), + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 5, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithIncreasingRank(self): + x = constant_op.constant([[1], [2]], + 
dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 2, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithBroadcastAllDimensions(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 4, 6]) out = 2 * v with self.test_session(): err = gradient_checker.compute_gradient_error(x, x.get_shape(), -- cgit v1.2.3 From fc662e10661d44e5f00d3a93e0f0be867244880d Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Thu, 6 Sep 2018 11:58:00 -0700 Subject: Fixing clang formatting issue --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 48 +++++++++++----------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index b80d507774..49f6695330 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -24,10 +24,10 @@ limitations under the License. 
#include #include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/platform/mutex.h" #ifndef INTEL_MKL_DNN_ONLY @@ -56,8 +56,8 @@ class MklSubAllocator : public SubAllocator { class MklSmallSizeAllocator : public VisitableAllocator { public: MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, - const string& name) : sub_allocator_(sub_allocator), - name_(name) { + const string& name) + : sub_allocator_(sub_allocator), name_(name) { stats_.bytes_limit = total_memory; } ~MklSmallSizeAllocator() override {} @@ -133,19 +133,19 @@ class MklSmallSizeAllocator : public VisitableAllocator { private: // Increment statistics for the allocator handling small allocations. - inline void - IncrementStats(size_t alloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + inline void IncrementStats(size_t alloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; - stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, - stats_.bytes_in_use); - stats_.max_alloc_size = std::max(alloc_size, - static_cast(stats_.max_alloc_size)); + stats_.max_bytes_in_use = + std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = + std::max(alloc_size, static_cast(stats_.max_alloc_size)); } // Decrement statistics for the allocator handling small allocations. 
- inline void - DecrementStats(size_t dealloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + inline void DecrementStats(size_t dealloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { stats_.bytes_in_use -= dealloc_size; } @@ -225,10 +225,10 @@ class MklCPUAllocator : public VisitableAllocator { // SubAllocator is owned by BFCAllocator, so we do not need to deallocate // it in MklSmallSizeAllocator. - small_size_allocator_ = new MklSmallSizeAllocator(sub_allocator_, - max_mem_bytes, kName); - large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, - kAllowGrowth, kName); + small_size_allocator_ = + new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); + large_size_allocator_ = + new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator // From: http://software.intel.com/en-us/node/528565 @@ -247,9 +247,9 @@ class MklCPUAllocator : public VisitableAllocator { // otherwise call large-size allocator (BFC). We found that BFC allocator // does not deliver good performance for small allocations when // inter_op_parallelism_threads is high. - return (num_bytes < kSmallAllocationsThreshold) ? - small_size_allocator_->AllocateRaw(alignment, num_bytes) : - large_size_allocator_->AllocateRaw(alignment, num_bytes); + return (num_bytes < kSmallAllocationsThreshold) + ? small_size_allocator_->AllocateRaw(alignment, num_bytes) + : large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { @@ -270,8 +270,8 @@ class MklCPUAllocator : public VisitableAllocator { // Combine statistics from small-size and large-size allocator. 
stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs; stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; - stats->max_bytes_in_use = l_stats.max_bytes_in_use + - s_stats.max_bytes_in_use; + stats->max_bytes_in_use = + l_stats.max_bytes_in_use + s_stats.max_bytes_in_use; // Since small-size allocations go to MklSmallSizeAllocator, // max_alloc_size from large_size_allocator would be the maximum @@ -311,14 +311,14 @@ class MklCPUAllocator : public VisitableAllocator { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } static inline void* ReallocHook(void* ptr, size_t size) { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } // Do we allow growth in BFC Allocator @@ -330,7 +330,7 @@ class MklCPUAllocator : public VisitableAllocator { // The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; - VisitableAllocator* large_size_allocator_; // owned by this class + VisitableAllocator* large_size_allocator_; // owned by this class MklSmallSizeAllocator* small_size_allocator_; // owned by this class. SubAllocator* sub_allocator_; // not owned by this class @@ -338,7 +338,7 @@ class MklCPUAllocator : public VisitableAllocator { // Size in bytes that defines the upper-bound for "small" allocations. // Any allocation below this threshold is "small" allocation. 
static constexpr const size_t kSmallAllocationsThreshold = 4096; - + // Prevent copying and assignment TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator); }; -- cgit v1.2.3 From 380abf51677b180face81953ddf63676074d4de2 Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Thu, 6 Sep 2018 13:49:48 -0700 Subject: Fixing clang format error - v2 --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 49f6695330..df9c3a686c 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -138,9 +138,9 @@ class MklSmallSizeAllocator : public VisitableAllocator { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; stats_.max_bytes_in_use = - std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); stats_.max_alloc_size = - std::max(alloc_size, static_cast(stats_.max_alloc_size)); + std::max(alloc_size, static_cast(stats_.max_alloc_size)); } // Decrement statistics for the allocator handling small allocations. @@ -226,9 +226,9 @@ class MklCPUAllocator : public VisitableAllocator { // SubAllocator is owned by BFCAllocator, so we do not need to deallocate // it in MklSmallSizeAllocator. 
small_size_allocator_ = - new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); + new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); large_size_allocator_ = - new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); + new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator // From: http://software.intel.com/en-us/node/528565 @@ -248,8 +248,8 @@ class MklCPUAllocator : public VisitableAllocator { // does not deliver good performance for small allocations when // inter_op_parallelism_threads is high. return (num_bytes < kSmallAllocationsThreshold) - ? small_size_allocator_->AllocateRaw(alignment, num_bytes) - : large_size_allocator_->AllocateRaw(alignment, num_bytes); + ? small_size_allocator_->AllocateRaw(alignment, num_bytes) + : large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { -- cgit v1.2.3 From dce54446805ca6be5b4ecd7d5226f2a80a0e9aa1 Mon Sep 17 00:00:00 2001 From: "Yan Facai (颜发才)" Date: Fri, 7 Sep 2018 07:44:43 +0800 Subject: TST: make scalar test cpu-only --- tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 8bcf27466c..bd2339f31d 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker @@ -83,12 +84,15 @@ class BroadcastToTest(test_util.TensorFlowTestCase): 
self.assertAllEqual(shape, v_np.shape) def testGradientForScalar(self): - x = constant_op.constant(1, dtype=dtypes.float32) - v = array_ops.broadcast_to(x, [2, 4, 3]) - out = 2 * v - with self.test_session(): - err = gradient_checker.compute_gradient_error(x, x.get_shape(), - out, out.get_shape()) + # TODO(alextp): There is a bug with broadcast_to on GPU from scalars, + # hence we make this test cpu-only. + with ops.device("cpu:0"): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) self.assertLess(err, 1e-4) def testGradientWithSameRank(self): -- cgit v1.2.3 From 2007f9752e116c46cb82c08a54f5c5e711a7c59d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 7 Sep 2018 03:44:47 +0000 Subject: Fix NoneType error in tf.nn.depthwise_conv2d with unknown shape This fix tries to address the issue raised in 22110 where tf.nn.depthwise_conv2d throws out NoneType error when the input shape is unknown. This fix fixes 22110. Signed-off-by: Yong Tang --- tensorflow/python/ops/nn_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index ef9afd9e8e..2526e6fee2 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -510,7 +510,7 @@ class _WithSpaceToBatch(object): # Recover channel information for output shape if channels are not last. 
if self.data_format is not None and self.data_format.startswith("NC"): - if not result_converted.shape[1].value: + if not result_converted.shape[1].value and filter is not None: output_shape = result_converted.shape.as_list() output_shape[1] = filter.shape[-1] result_converted.set_shape(output_shape) -- cgit v1.2.3 From 991ba4b385fb57fabd9947de0e2006db8b32e54f Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 7 Sep 2018 03:46:17 +0000 Subject: Add test case for tf.nn.depthwise_conv2d with unknown input shape Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/depthwise_conv_op_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 58845552db..0c049bd8ab 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -205,6 +205,14 @@ class DepthwiseConv2DTest(test.TestCase): use_gpu=True, grouped_conv=True) + def testDepthwiseConv2DWithUnknownShape(self): + # GitHub issue 22110. 
+ with self.test_session(use_gpu=True): + x = array_ops.placeholder(dtypes.float32) + f = np.ones([1, 1, 1, 1], np.float32) + v = nn_impl.depthwise_conv2d(x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW") + self.assertAllEqual(np.ones([1, 1, 1, 1], np.float32), v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)})) + def testDepthwiseConv2DFormat(self): if not test.is_gpu_available(): return -- cgit v1.2.3 From 173c26a684938b06785e19e68c7ea9b86f5ab34c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 7 Sep 2018 03:57:16 +0000 Subject: Pylint fix Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/depthwise_conv_op_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 0c049bd8ab..59674eb3a1 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -210,8 +210,11 @@ class DepthwiseConv2DTest(test.TestCase): with self.test_session(use_gpu=True): x = array_ops.placeholder(dtypes.float32) f = np.ones([1, 1, 1, 1], np.float32) - v = nn_impl.depthwise_conv2d(x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW") - self.assertAllEqual(np.ones([1, 1, 1, 1], np.float32), v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)})) + v = nn_impl.depthwise_conv2d( + x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW") + self.assertAllEqual( + np.ones([1, 1, 1, 1], np.float32), + v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)})) def testDepthwiseConv2DFormat(self): if not test.is_gpu_available(): -- cgit v1.2.3 From f4fc839fb279522d139622e6a52c14021318326d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 7 Sep 2018 04:01:10 +0000 Subject: Only enable test if gpu is available (NCHW does not have the CPU implementation) Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/depthwise_conv_op_test.py | 2 
++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 59674eb3a1..5741f2ec64 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -207,6 +207,8 @@ class DepthwiseConv2DTest(test.TestCase): def testDepthwiseConv2DWithUnknownShape(self): # GitHub issue 22110. + if not test.is_gpu_available(): + return with self.test_session(use_gpu=True): x = array_ops.placeholder(dtypes.float32) f = np.ones([1, 1, 1, 1], np.float32) -- cgit v1.2.3 From a11cb4cb1500f35266667d9f72b0a0534f2d1581 Mon Sep 17 00:00:00 2001 From: BY Shen Date: Fri, 7 Sep 2018 22:20:37 +0800 Subject: Fix a bug in TF_LITE_ENSURE_OK. --- tensorflow/contrib/lite/context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b23183b743..58977b5c47 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -148,7 +148,7 @@ void TfLiteIntArrayFree(TfLiteIntArray* v); #define TF_LITE_ENSURE_OK(context, status) \ do { \ if ((status) != kTfLiteOk) { \ - return status; \ + return kTfLiteError; \ } \ } while (0) -- cgit v1.2.3 From aec495d6acdbdfac97ce91dd0782eb88e307c055 Mon Sep 17 00:00:00 2001 From: pengwa Date: Sat, 8 Sep 2018 11:20:23 +0800 Subject: add more ValueError description in dynamic_rnn document --- tensorflow/python/ops/rnn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 5c00d929bf..4f3d8c2318 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -709,6 +709,10 @@ def _dynamic_rnn_loop(cell, Raises: ValueError: If the input depth cannot be inferred via shape inference from the inputs. + ValueError: If time is not the same for all the elements in the + input. 
+ ValueError: If batch_size is not the same for all the elements + in the input. """ state = initial_state assert isinstance(parallel_iterations, int), "parallel_iterations must be int" -- cgit v1.2.3 From bfead6061a6f10c5a3e5d05f8a946443fb9a3218 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 8 Sep 2018 02:01:31 -0700 Subject: compat: Update forward compatibility horizon to 2018-09-08 PiperOrigin-RevId: 212097666 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 7a3fc27592..ca72cbac1a 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 7) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 8) @tf_export("compat.forward_compatible") -- cgit v1.2.3 From f40c960fff788b6770b9b4015734e54604f7481b Mon Sep 17 00:00:00 2001 From: Jonathan Homer Date: Sat, 8 Sep 2018 13:52:04 +0100 Subject: Changed PWD to pwd for bash examples Shell command PWD should be lowercase pwd for it to work correctly. Obvious typo corrected. --- tensorflow/tools/dockerfiles/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md index d64db35afb..5996573cf1 100644 --- a/tensorflow/tools/dockerfiles/README.md +++ b/tensorflow/tools/dockerfiles/README.md @@ -34,13 +34,13 @@ documentation](https://docs.docker.com/engine/reference/run/). # User permissions (-u) are required if you use (-v). 
# CPU-based images -$ docker run -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf +$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf # GPU-based images (set up nvidia-docker2 first) -$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf +$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf # Images with Jupyter run on port 8888, and needs a volume for notebooks -$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(PWD):/notebooks -it tf +$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf ``` These images do not come with the TensorFlow source code -- but the development -- cgit v1.2.3 From 40037223b33fcdf178509ba5ece4ba33425c4627 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Sat, 8 Sep 2018 09:19:22 -0700 Subject: Automated rollback of commit 0065d3389a63a529469dc71e950c66da2ebdbc24 PiperOrigin-RevId: 212119629 --- tensorflow/contrib/lite/experimental/writer/BUILD | 66 ++++ .../lite/experimental/writer/enum_mapping.h | 116 +++++++ .../experimental/writer/option_writer_generator.cc | 370 +++++++++++++++++++++ .../contrib/lite/experimental/writer/writer.cc | 41 +++ .../contrib/lite/experimental/writer/writer_lib.cc | 281 ++++++++++++++++ .../contrib/lite/experimental/writer/writer_lib.h | 126 +++++++ .../lite/experimental/writer/writer_lib_test.cc | 62 ++++ tensorflow/contrib/lite/schema/BUILD | 14 + third_party/flatbuffers/BUILD.bazel | 1 + third_party/flatbuffers/build_defs.bzl | 19 +- 10 files changed, 1088 insertions(+), 8 deletions(-) create mode 100644 tensorflow/contrib/lite/experimental/writer/BUILD create mode 100644 tensorflow/contrib/lite/experimental/writer/enum_mapping.h create mode 100644 tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc create mode 100644 tensorflow/contrib/lite/experimental/writer/writer.cc create mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib.cc create mode 100644 
tensorflow/contrib/lite/experimental/writer/writer_lib.h create mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD new file mode 100644 index 0000000000..82d39c00ab --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/BUILD @@ -0,0 +1,66 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_binary( + name = "option_writer_generator", + srcs = ["option_writer_generator.cc"], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + "@flatbuffers", + ], +) + +cc_library( + name = "writer_lib", + srcs = [ + "enum_mapping.h", + "writer_lib.cc", + ], + hdrs = [ + "writer_lib.h", + ], + data = [ + ":option_writer_gen", + ], + textual_hdrs = ["option_writer_generated.h"], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + ], +) + +cc_binary( + name = "writer", + srcs = ["writer.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "writer_lib_test", + size = "small", + srcs = ["writer_lib_test.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +genrule( + name = "option_writer_gen", + outs = ["option_writer_generated.h"], + cmd = "$(location :option_writer_generator) $(@)", + tools = [":option_writer_generator"], +) diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h new file mode 100644 index 
0000000000..8bc464fd71 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +// TODO(aselle): Ideally extract this from the schema. + +namespace tflite { + +inline ActivationFunctionType TfLiteActivationToSchemaActivation( + TfLiteFusedActivation act) { + switch (act) { + case kTfLiteActNone: + return ActivationFunctionType_NONE; + case kTfLiteActRelu: + return ActivationFunctionType_RELU; + case kTfLiteActRelu1: + return ActivationFunctionType_RELU_N1_TO_1; + case kTfLiteActRelu6: + return ActivationFunctionType_RELU6; + case kTfLiteActTanh: + return ActivationFunctionType_TANH; + case kTfLiteActSignBit: + return ActivationFunctionType_SIGN_BIT; + case kTfLiteActSigmoid: + return ActivationFunctionType_NONE; // TODO(aselle): Add to schema + } + return ActivationFunctionType_NONE; +} + +inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingUnknown: + return Padding_SAME; // TODO(aselle): Consider an error. 
+ case kTfLitePaddingSame: + return Padding_SAME; + case kTfLitePaddingValid: + return Padding_VALID; + } + return Padding_SAME; // TODO(aselle): Consider an error. +} + +inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { + switch (type) { + // case kTfLiteNoType: return TensorType_NONE; + case kTfLiteNoType: + return TensorType_FLOAT32; // TODO(aselle): Consider an error. + case kTfLiteFloat32: + return TensorType_FLOAT32; + case kTfLiteInt32: + return TensorType_INT32; + case kTfLiteUInt8: + return TensorType_UINT8; + case kTfLiteInt64: + return TensorType_INT64; + case kTfLiteString: + return TensorType_STRING; + case kTfLiteBool: + return TensorType_BOOL; + case kTfLiteInt16: + return TensorType_INT16; + case kTfLiteComplex64: + return TensorType_COMPLEX64; + } + // TODO(aselle): consider an error +} + +inline FullyConnectedOptionsWeightsFormat +FullyConnectedOptionsWeightsFormatToSchema( + TfLiteFullyConnectedWeightsFormat format) { + switch (format) { + case kTfLiteFullyConnectedWeightsFormatDefault: + return FullyConnectedOptionsWeightsFormat_DEFAULT; + case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8: + return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + } +} + +inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) { + switch (type) { + case kTfLiteLSTMFullKernel: + return LSTMKernelType_FULL; + case kTfLiteLSTMBasicKernel: + return LSTMKernelType_BASIC; + } +} + +inline LSHProjectionType LSHProjectionTypeToSchema( + TfLiteLSHProjectionType type) { + switch (type) { + case kTfLiteLshProjectionUnknown: + return LSHProjectionType_UNKNOWN; + case kTfLiteLshProjectionSparse: + return LSHProjectionType_SPARSE; + case kTfLiteLshProjectionDense: + return LSHProjectionType_DENSE; + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc 
b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc new file mode 100644 index 0000000000..e6d5a776b3 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc @@ -0,0 +1,370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include "flatbuffers/minireflect.h" // flatbuffers +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +namespace tflite { +namespace { +// This is generated by grepping +// cat third_party/tensorflow/contrib/lite/builtin_op_data.h +//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" +static const char* param_structs[] = {"TfLiteConvParams", + "TfLitePoolParams", + "TfLiteDepthwiseConvParams", + "TfLiteSVDFParams", + "TfLiteRNNParams", + "TfLiteSequenceRNNParams", + "TfLiteFullyConnectedParams", + "TfLiteLSHProjectionParams", + "TfLiteSoftmaxParams", + "TfLiteConcatenationParams", + "TfLiteAddParams", + "TfLiteSpaceToBatchNDParams", + "TfLiteBatchToSpaceNDParams", + "TfLiteMulParams", + "TfLiteSubParams", + "TfLiteDivParams", + "TfLiteL2NormParams", + "TfLiteLocalResponseNormParams", + "TfLiteLSTMParams", + "TfLiteResizeBilinearParams", + "TfLitePadParams", + "TfLitePadV2Params", + "TfLiteReshapeParams", + "TfLiteSkipGramParams", + "TfLiteSpaceToDepthParams", + 
"TfLiteCastParams", + "TfLiteEmbeddingLookupSparseParams", + "TfLiteGatherParams", + "TfLiteTransposeParams", + "TfLiteReducerParams", + "TfLiteSplitParams", + "TfLiteSqueezeParams", + "TfLiteStridedSliceParams", + "TfLiteArgMaxParams", + "TfLiteArgMinParams", + "TfLiteTransposeConvParams", + "TfLiteSparseToDenseParams", + "TfLiteShapeParams", + "TfLiteFakeQuantParams", + "TfLitePackParams", + "TfLiteOneHotParams", + nullptr}; +} // namespace + +// Get rid of all underscores and make everything lower case to make name +// matching work for stuff like 3D vs 3d or RNN vs Rnn. +std::string ToCollapsed(const std::string& in) { + const char* s = in.c_str(); + bool first = true; + std::string out; + while (*s != '\0') { + if (*s == '_') { + first = true; + } else if (first) { + out.push_back(tolower(*s)); + first = false; + } else { + out.push_back(tolower(*s)); + } + s++; + } + return out; +} + +// A collection of information about builtin ops. +class OpOptionData { + public: + OpOptionData() { + BuildOpList(); + BuildOptionToTypeFunctionMap(); + BuildOpToOptionMap(); + } + + // A list of builtin operations + const std::vector& ops() const { return ops_; } + // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions') + const std::unordered_map& op_to_option() { + return op_to_option_; + } + // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions' + const std::unordered_map& option_to_struct() { + return option_to_struct_; + } + // Maps from option to a flatbuffer type function that describes that option. 
+ const std::unordered_map& + option_to_type_function() { + return option_to_type_function_; + } + + private: + void BuildOpList() { + for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr; + ++curr) { + if (strlen(*curr) != 0) ops_.push_back(*curr); + } + } + + void BuildOptionToTypeFunctionMap() { + auto d = tflite::BuiltinOptionsTypeTable(); + for (int i = 0; i < d->num_elems; i++) { + flatbuffers::TypeCode code = d->type_codes[i]; + if (code.sequence_ref != -1) { + option_to_type_function_.insert( + std::make_pair(d->names[i], d->type_refs[code.sequence_ref])); + } + } + } + + void BuildOpToOptionMap() { + // Manually specified mappings between ops and options + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_MIN"] = "ReducerOptions"; + op_to_option_["REDUCE_ANY"] = "ReducerOptions"; + op_to_option_["UNPACK"] = ""; + op_to_option_["SUM"] = "ReducerOptions"; + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_PROD"] = "ReducerOptions"; + op_to_option_["MEAN"] = "ReducerOptions"; + op_to_option_["L2_POOL_2D"] = "Pool2DOptions"; + op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; + op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; + op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + // Manually specified mappings between ops and options (none) + op_to_option_["EMBEDDING_LOOKUP"] = + ""; // TODO(aselle): maybe something else. + op_to_option_["FLOOR"] = ""; + op_to_option_["HASHTABLE_LOOKUP"] = + ""; // TODO(aselle): maybe something else. 
+ op_to_option_["LOGISTIC"] = ""; + op_to_option_["RELU"] = ""; + op_to_option_["RELU_N1_TO_1"] = ""; + op_to_option_["RELU6"] = ""; + op_to_option_["TANH"] = ""; + op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. + op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. + op_to_option_["PRELU"] = ""; + op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["SIN"] = ""; + op_to_option_["LOG"] = ""; + op_to_option_["SQRT"] = ""; + op_to_option_["RSQRT"] = ""; + + // TODO(aselle): These are undesirable hacks. Consider changing C structs + option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; + option_to_struct_["Conv2DOptions"] = "TfLiteConvParams"; + option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; + option_to_struct_["LocalResponseNormalizationOptions"] = + "TfLiteLocalResponseNormParams"; + // Now for every op, try to find an option. + bool fatal = false; + for (auto op_name : ops_) { + bool found_option = false; + auto d = tflite::BuiltinOptionsTypeTable(); + std::string collapsed_option_name_guess = + ToCollapsed(op_name) + "options"; + // O(n^2) but not that big of n. 
+ for (int i = 0; i < d->num_elems; i++) { + std::string option_name = d->names[i]; + std::string collapsed_option_name = ToCollapsed(option_name); + if (collapsed_option_name_guess == collapsed_option_name) { + op_to_option_.insert(std::make_pair(op_name, option_name)); + found_option = true; + break; + } + } + auto it = op_to_option_.find(op_name); + if (it == op_to_option_.end()) { + std::cerr << "Didn't find option for " << op_name << std::endl; + fatal = true; + } else if (!it->second.empty()) { + std::string option_name = it->second; + + if (option_to_struct_.find(option_name) == option_to_struct_.end()) { + bool param_struct_found = false; + std::string params_guess = std::string("TfLite") + option_name; + size_t start = params_guess.find("Options"); + size_t len = strlen("Options"); + params_guess.replace(start, len, "Params"); + for (auto* param = param_structs; *param != nullptr; param++) { + if (*param == params_guess) { + param_struct_found = true; + break; + } + } + if (!param_struct_found) { + std::cerr << "Failed to get param struct for option " << option_name + << std::endl; + fatal = true; + } else { + option_to_struct_.insert(std::make_pair(option_name, params_guess)); + } + } + } + } + } + + private: + std::vector ops_; + std::unordered_map op_to_option_; + std::unordered_map option_to_struct_; + std::unordered_map + option_to_type_function_; +}; + +void GenerateImportForOp(FILE* fp, const std::string& op_name, + const std::string& option_name, + const std::string& option_type, + const flatbuffers::TypeTable* options, + const std::string& struct_name) { + // Skip tricky ones for now + if (struct_name == "TfLiteResizeBilinearParams") return; + if (struct_name == "TfLiteSqueezeParams") return; + if (struct_name == "TfLiteEmbeddingLookupSparseParams") return; + if (struct_name == "TfLiteReshapeParams") return; + + fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); + fprintf(fp, + " const auto* params = 
reinterpret_cast(builtin_op_data);\n", + struct_name.c_str()); + + for (size_t i = 0; i < options->num_elems; i++) { + std::string elem_name = options->names[i]; + // TODO(aselle): Irregular naming in builtins + if (elem_name == "fused_activation_function") + elem_name = "activation"; + else if (elem_name == "stride_w") + elem_name = "stride_width"; + else if (elem_name == "stride_h") + elem_name = "stride_height"; + else if (elem_name == "dilation_h_factor") + elem_name = "dilation_height_factor"; + else if (elem_name == "dilation_w_factor") + elem_name = "dilation_width_factor"; + else if (elem_name == "new_shape") + elem_name = "shape"; + + flatbuffers::TypeCode code = options->type_codes[i]; + auto contained_type = code.sequence_ref != -1 + ? options->type_refs[code.sequence_ref] + : nullptr; + std::string mapper = ""; + if (contained_type == TensorTypeTypeTable) { + mapper = "TfLiteTypeToSchemaType"; + } else if (contained_type == ActivationFunctionTypeTypeTable) { + mapper = "TfLiteActivationToSchemaActivation"; + } else if (contained_type == PaddingTypeTable) { + mapper = "TfLitePaddingToSchemaPadding"; + } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) { + mapper = "FullyConnectedOptionsWeightsFormatToSchema"; + } else if (contained_type == LSTMKernelTypeTypeTable) { + mapper = "LSTMKernelTypeToSchema"; + } else if (contained_type == LSHProjectionTypeTypeTable) { + mapper = "LSHProjectionTypeToSchema"; + } + + fprintf(fp, + " auto val%zu = " + "%s(params->%s);\n", + i, mapper.c_str(), elem_name.c_str()); + } + fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str()); + for (size_t i = 0; i < options->num_elems; i++) { + fprintf(fp, ", val%zu", i); + } + fprintf(fp, ").Union();\n"); + fprintf(fp, " return std::make_pair(%s, union_type);\n", + option_type.c_str()); + fprintf(fp, " }\n break;\n"); +} + +void GenerateImport(OpOptionData* option, FILE* fp) { + std::unordered_set ignores; + ignores.insert("CONCAT_EMBEDDINGS"); 
+ ignores.insert("CALL"); + + // Allow any op that doesn't have an options struct to be blocked + // together + for (const auto& op_name : option->ops()) { + auto option_it = option->op_to_option().find(op_name); + if (!option_it->second.empty() && ignores.find(op_name) == ignores.end()) + continue; + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + } + fprintf(fp, + " return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset());\n break;\n"); + + // Iterate over each ops + for (const auto& op_name : option->ops()) { + if (ignores.find(op_name) != ignores.end()) continue; + // Get to the option and struct names, continuing if not found. + auto option_it = option->op_to_option().find(op_name); + if (option_it->second.empty()) continue; + std::string option_name = option_it->second; + std::string option_type = "BuiltinOptions_" + option_name; + auto option_func_it = option->option_to_type_function().find(option_name); + if (option_func_it == option->option_to_type_function().end()) continue; + auto struct_name_it = option->option_to_struct().find(option_name); + if (struct_name_it == option->option_to_struct().end()) { + // If no C struct, then it better have no arguments. + auto type_info = option_func_it->second(); + if (type_info->num_elems != 0) { + // We have non-zero arguments in the schema, this means there + // should be a struct. + fprintf(stderr, + "Op %s uses option struct %s which has no builtin struct\n", + op_name.c_str(), option_name.c_str()); + exit(1); + } + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());", + option_type.c_str(), option_name.c_str()); + } else { + // If C struct, then we need to assign all properties + auto struct_name = struct_name_it->second; + GenerateImportForOp(fp, op_name, option_name, option_type, + option_func_it->second(), struct_name); + } + } + // TODO(aselle): Handle unhandled cases more gracefully. 
+ fprintf(fp, + "default: return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset());\n break;\n"); +} + +} // namespace tflite + +int main(int argc, char* argv[]) { + tflite::OpOptionData option; + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + FILE* fp = fopen(argv[1], "w"); + tflite::GenerateImport(&option, fp); + fclose(fp); +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc new file mode 100644 index 0000000000..20ede214fb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Just does a read/write loop of tflite file format using the interpreter as +// an intermediate. 
+// +// Usage: +// writer + +#include + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]); + return 1; + } + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(argv[1]); + std::unique_ptr interpreter; + tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; + tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); + tflite::InterpreterWriter writer(interpreter.get()); + writer.Write(argv[2]); + + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc new file mode 100644 index 0000000000..52b17faf82 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc @@ -0,0 +1,281 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +template +using Offset = flatbuffers::Offset; +template +using Vector = flatbuffers::Vector; +using FlatBufferBuilder = flatbuffers::FlatBufferBuilder; + +std::pair> CreateBuiltinUnion( + FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) { + switch (op) { +#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h" + } + return std::make_pair(BuiltinOptions_NONE, Offset()); +} + +template +Offset> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb, + const T_INPUT& v) { + std::vector inputs(v.begin(), v.end()); + return fbb->template CreateVector(inputs); +} + +Offset>> InterpreterWriter::ExportOperators( + FlatBufferBuilder* fbb) { + std::vector> operators; + + std::vector operator_to_opcode; + // TODO(aselle): Augment this once we put execution plan in schema. 
+ operator_to_opcode.resize(interpreter_->nodes_size(), -1); + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteRegistration* registration = &node_and_registration->second; + if (!registration->custom_name) { + operator_to_opcode[op_index] = + GetOpCodeForBuiltin(registration->builtin_code); + } else { + operator_to_opcode[op_index] = + GetOpCodeForCustom(registration->custom_name); + } + } + // second pass serialize operators + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteNode& node = node_and_registration->first; + const TfLiteRegistration& registration = node_and_registration->second; + Offset builtin_options; + BuiltinOptions builtin_options_type = BuiltinOptions_NONE; + // Custom data + // TODO(aselle): Custom options format is not known by default. Just assume + // for now. 
+ auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS; + Offset> custom_options = 0; + + if (!registration.custom_name) { + // builtin + auto builtin_options_and_type = CreateBuiltinUnion( + fbb, static_cast(registration.builtin_code), + node.builtin_data); + builtin_options = builtin_options_and_type.second; + builtin_options_type = builtin_options_and_type.first; + } else { + auto custom_writer = custom_op_to_writer_.find(registration.custom_name); + if (custom_writer != custom_op_to_writer_.end() && + custom_writer->second) { + // delegate to custom writer if it exists + custom_writer->second(fbb, interpreter_, op_index, &custom_options, + &custom_options_format); + } else { + // use the custom data as fact + custom_options = fbb->CreateVector( + reinterpret_cast(node.custom_initial_data), + node.custom_initial_data_size); + } + } + + int opcode_index = operator_to_opcode[op_index]; + std::vector written_inputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs)); + std::vector written_outputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs)); + auto inputs = ExportVector(fbb, written_inputs); + auto outputs = ExportVector(fbb, written_outputs); + operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs, + builtin_options_type, builtin_options, + custom_options, custom_options_format)); + } + + return fbb->template CreateVector>(operators); +} + +Offset>> InterpreterWriter::ExportTensors( + FlatBufferBuilder* fbb) { + tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); + + std::vector> tensors; + + // Make a map from tensor index to whether the tensor is a temporary. 
+ std::vector tensor_is_temporary(interpreter_->tensors_size(), false); + for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + for (auto tensor_index : + TfLiteIntArrayView(node_and_registration->first.temporaries)) + tensor_is_temporary[tensor_index] = true; + } + + // Now we need to remap all used tensor indices + int curr_output_index = 0; + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + tensor_index++) { + if (!tensor_is_temporary[tensor_index]) { + tensor_to_written_tensor_[tensor_index] = curr_output_index++; + } + } + + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + ++tensor_index) { + // Skip temporaries. + if (tensor_is_temporary[tensor_index]) continue; + + if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { + // We only need to convert non temporaries + if (tensor->allocation_type != kTfLiteArenaRw && + tensor->allocation_type != kTfLiteMmapRo && + tensor->allocation_type != kTfLiteArenaRwPersistent) + continue; + // Allocate a buffer index + int buffer_index = 0; // This is null + if (tensor->allocation_type == kTfLiteMmapRo) { + buffer_index = buffers_.size(); + buffers_.push_back(std::make_pair( + reinterpret_cast(tensor->data.raw), tensor->bytes)); + } + // Primitive type. + TensorType type = TfLiteTypeToSchemaType(tensor->type); + // Handle quantization + const Offset> null_array; + Offset> scale_array; + Offset> zero_point_array; + if (tensor->params.scale != 0.f) { + // We have quantization, make a single arugment array (multi channel + // quant needs updating here). 
+ scale_array = fbb->CreateVector({tensor->params.scale}); + zero_point_array = + fbb->CreateVector({tensor->params.zero_point}); + } + Offset quantization_params = + CreateQuantizationParameters(*fbb, null_array, null_array, + scale_array, zero_point_array); + // Shape + TfLiteIntArrayView shape_view(tensor->dims); + std::vector shape = + std::vector(shape_view.begin(), shape_view.end()); + + tensors.push_back(CreateTensor(*fbb, ExportVector(fbb, shape), + type, buffer_index, + fbb->CreateString(tensor->name), + quantization_params, tensor->is_variable)); + } + } + return fbb->template CreateVector>(tensors); +} + +Offset>> InterpreterWriter::ExportBuffers( + FlatBufferBuilder* fbb) { + std::vector> buffer_vector; + for (auto buffer : buffers_) { + auto data_offset = fbb->CreateVector(buffer.first, buffer.second); + buffer_vector.push_back(CreateBuffer(*fbb, data_offset)); + } + return fbb->template CreateVector>(buffer_vector); +} + +Offset>> InterpreterWriter::CreateOpCodeTable( + FlatBufferBuilder* fbb) { + std::vector> codes; + for (auto it : opcodes_) { + const char* custom_name = it.custom.empty() ? 
nullptr : it.custom.c_str(); + codes.push_back(CreateOperatorCodeDirect( + *fbb, static_cast(it.builtin), custom_name)); + } + return fbb->template CreateVector>(codes); +} + +template +std::vector InterpreterWriter::RemapTensorIndicesToWritten( + const T& input) { + std::vector output; + output.reserve(input.size()); + for (int x : input) { + output.push_back(tensor_to_written_tensor_[x]); + } + return output; +} + +TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr* out, + size_t* size) { + if (!out || !size) return kTfLiteError; + FlatBufferBuilder builder(/*initial_size=*/10240); + + std::vector> subgraphs_as_vector; + { // subgraph specific stuff + auto tensors = ExportTensors(&builder); + std::vector written_inputs = + RemapTensorIndicesToWritten(interpreter_->inputs()); + std::vector written_outputs = + RemapTensorIndicesToWritten(interpreter_->outputs()); + auto inputs = ExportVector(&builder, written_inputs); + auto outputs = ExportVector(&builder, written_outputs); + + auto ops = ExportOperators(&builder); + subgraphs_as_vector.push_back( + CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0)); + } + Offset>> buffers = ExportBuffers(&builder); + + auto description = builder.CreateString("Exported from Interpreter."); + + auto op_codes = CreateOpCodeTable(&builder); + auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, + builder.CreateVector(subgraphs_as_vector), + description, buffers); + ::tflite::FinishModelBuffer(builder, model); + const uint8_t* buffer = builder.GetBufferPointer(); + *size = builder.GetSize(); + (*out).reset(new uint8_t[*size]); + memcpy(out->get(), buffer, *size); + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::Write(const std::string& filename) { + std::unique_ptr buffer; + size_t size; + TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size)); + + FILE* fp = fopen(filename.c_str(), "wb"); + if (!fp) return kTfLiteError; + + if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError; 
+ if (fclose(fp)) return kTfLiteError; + + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::RegisterCustomWriter( + const std::string& custom_name, CustomWriter custom_writer) { + if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) { + return kTfLiteError; + } + custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer)); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h new file mode 100644 index 0000000000..a98108b496 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter. +// +// Usage: +// From command line: +// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer +// -- foo.tflite foo.out.tflite +// +// From C++ +// std::unique_ptr interpreter; +// // Build Interpreter however +// // ... 
+// InterpreterWriter(interpreter.get()).Write("output.tflite"); +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +// Handles writing TensorFlow Lite running interpreter to a serialized TF lite +// file format. +class InterpreterWriter { + public: + typedef flatbuffers::Offset (*CustomWriter)( + flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter, + int node_index, + flatbuffers::Offset>* output_options, + CustomOptionsFormat* custom_options_format); + + // Construct an interpreter writer for the specified `interpreter`. Then, + // a uses .Write() or .GetBuffer(...) to extract the data. + explicit InterpreterWriter(Interpreter* interpreter) + : interpreter_(interpreter) { + buffers_.push_back(std::make_pair(nullptr, 0)); + } + + // Get a buffer and size of a serialized flatbuffer. + TfLiteStatus GetBuffer(std::unique_ptr* out, size_t* size); + // Write the serialized flatbuffer to the prescribed `filename`. + TfLiteStatus Write(const std::string& filename); + // Registers a custom writer for a custom op. The customization allows the + // caller to change the custom data. 
+ TfLiteStatus RegisterCustomWriter(const std::string& custom_name, + CustomWriter custom_writer); + + private: + template + using Offset = flatbuffers::Offset; + template + Offset> ExportVector( + flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); + Offset>> ExportTensors( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> ExportOperators( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> CreateOpCodeTable( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> ExportBuffers( + flatbuffers::FlatBufferBuilder* fbb); + + template + std::vector RemapTensorIndicesToWritten(const T& input); + + int GetOpCodeForBuiltin(int builtin_op_index) { + // auto it = builtin_op_to_opcode_.find(builtin_op_index); + std::pair result = + builtin_op_to_opcode_.insert( + std::make_pair(builtin_op_index, opcodes_.size())); + if (result.second) { + opcodes_.push_back({builtin_op_index, ""}); + } + return result.first->second; + } + + int GetOpCodeForCustom(const std::string& custom_name) { + std::pair result = + custom_op_to_opcode_.insert( + std::make_pair(custom_name, opcodes_.size())); + if (result.second) { + opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); + } + return result.first->second; + } + + // The interpreter we are writing + Interpreter* interpreter_; + // Keep track of byte buffers + std::vector> buffers_; + // List of op codes and mappings from builtin or custom op to opcode + struct OpCode { + int builtin; + std::string custom; + }; + // For every tensor index in the interpreter, the index in the written. + // This is different due to temporary tensors not being written. 
+ std::vector tensor_to_written_tensor_; + // List of used opcodes + std::vector opcodes_; + std::unordered_map builtin_op_to_opcode_; + std::unordered_map custom_op_to_opcode_; + std::unordered_map custom_op_to_writer_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc new file mode 100644 index 0000000000..49194a76c8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +// Make an interpreter that has no tensors and no nodes +// TODO(b/113731921): add more tests. 
+TEST(Writer, BasicTest) { + Interpreter interpreter; + interpreter.AddTensors(3); + float foo[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, + TfLiteQuantizationParams()); + interpreter.SetTensorParametersReadOnly( + 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(), + reinterpret_cast(foo), sizeof(foo)); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, + TfLiteQuantizationParams()); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({2}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, + reinterpret_cast(builtin_data), reg); + + InterpreterWriter writer(&interpreter); + writer.Write("/tmp/test.tflite"); + std::unique_ptr model = + FlatBufferModel::BuildFromFile("/tmp/test.tflite"); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr new_interpreter; + builder(&new_interpreter); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 28a7e50003..55bf2c48b9 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -56,6 +56,20 @@ flatbuffer_cc_library( srcs = ["schema.fbs"], ) +# Generic schema for inference on device (but with reflections makes bigger). 
+flatbuffer_cc_library( + name = "schema_fbs_with_reflection", + srcs = ["schema.fbs"], + flatc_args = [ + "--reflect-types", + "--reflect-names", + "--no-union-value-namespacing", + "--gen-object-api", + ], + gen_reflections = True, + out_prefix = "reflection/", +) + # Schema test to make sure we don't introduce backward incompatible changes # to schemas. cc_test( diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 9d233a30d6..934c0d9650 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -142,6 +142,7 @@ filegroup( srcs = [ "include/flatbuffers/base.h", "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/minireflect.h", "include/flatbuffers/stl_emulation.h", "include/flatbuffers/util.h", ], diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl index 2f25156668..235b44f7cf 100644 --- a/third_party/flatbuffers/build_defs.bzl +++ b/third_party/flatbuffers/build_defs.bzl @@ -92,14 +92,17 @@ def flatbuffer_library_public( cmd = reflection_genrule_cmd, message = "Generating flatbuffer reflection binary for %s:" % (name), ) - native.Fileset( - name = reflection_name, - out = "%s_out" % reflection_name, - entries = [ - native.FilesetEntry(files = reflection_outs), - ], - visibility = reflection_visiblity, - ) + # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer + # Have to comment this since FilesetEntry is not supported in bazel + # skylark. + # native.Fileset( + # name = reflection_name, + # out = "%s_out" % reflection_name, + # entries = [ + # native.FilesetEntry(files = reflection_outs), + # ], + # visibility = reflection_visiblity, + # ) def flatbuffer_cc_library( name, -- cgit v1.2.3 From f04f67f58fc6a5823fc4a78bd068c76f69d9fdd2 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sat, 8 Sep 2018 09:20:47 -0700 Subject: Sorting filenames in makefile lists alphabetically. 
PiperOrigin-RevId: 212119678 --- .../contrib/makefile/proto_text_cc_files.txt | 114 ++--- .../contrib/makefile/proto_text_pb_cc_files.txt | 74 +-- .../contrib/makefile/proto_text_pb_h_files.txt | 73 +-- tensorflow/contrib/makefile/tf_op_files.txt | 522 ++++++++++----------- tensorflow/contrib/makefile/tf_pb_text_files.txt | 56 +-- tensorflow/contrib/makefile/tf_proto_files.txt | 76 +-- 6 files changed, 458 insertions(+), 457 deletions(-) diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 7d26429f9c..b5c781ad76 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -1,62 +1,62 @@ -tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc -tensorflow/tools/proto_text/gen_proto_text_functions.cc tensorflow/core/framework/resource_handle.cc +tensorflow/core/lib/core/arena.cc +tensorflow/core/lib/core/coding.cc +tensorflow/core/lib/core/status.cc +tensorflow/core/lib/core/stringpiece.cc +tensorflow/core/lib/core/threadpool.cc +tensorflow/core/lib/hash/crc32c.cc +tensorflow/core/lib/hash/crc32c_accelerate.cc +tensorflow/core/lib/hash/hash.cc +tensorflow/core/lib/histogram/histogram.cc +tensorflow/core/lib/io/block.cc +tensorflow/core/lib/io/block_builder.cc +tensorflow/core/lib/io/buffered_inputstream.cc +tensorflow/core/lib/io/compression.cc +tensorflow/core/lib/io/format.cc +tensorflow/core/lib/io/inputbuffer.cc +tensorflow/core/lib/io/inputstream_interface.cc +tensorflow/core/lib/io/iterator.cc +tensorflow/core/lib/io/path.cc +tensorflow/core/lib/io/random_inputstream.cc +tensorflow/core/lib/io/record_reader.cc +tensorflow/core/lib/io/record_writer.cc +tensorflow/core/lib/io/table.cc +tensorflow/core/lib/io/table_builder.cc +tensorflow/core/lib/io/two_level_iterator.cc +tensorflow/core/lib/io/zlib_compression_options.cc +tensorflow/core/lib/io/zlib_inputstream.cc +tensorflow/core/lib/io/zlib_outputbuffer.cc 
+tensorflow/core/lib/random/distribution_sampler.cc +tensorflow/core/lib/random/random.cc +tensorflow/core/lib/random/simple_philox.cc +tensorflow/core/lib/random/weighted_picker.cc +tensorflow/core/lib/strings/numbers.cc +tensorflow/core/lib/strings/ordered_code.cc +tensorflow/core/lib/strings/proto_text_util.cc +tensorflow/core/lib/strings/scanner.cc +tensorflow/core/lib/strings/str_util.cc +tensorflow/core/lib/strings/strcat.cc +tensorflow/core/lib/strings/stringprintf.cc +tensorflow/core/lib/wav/wav_io.cc +tensorflow/core/platform/cpu_info.cc +tensorflow/core/platform/default/logging.cc +tensorflow/core/platform/default/mutex.cc tensorflow/core/platform/default/protobuf.cc -tensorflow/core/platform/tracing.cc -tensorflow/core/platform/tensor_coding.cc -tensorflow/core/platform/protobuf_util.cc -tensorflow/core/platform/posix/posix_file_system.cc -tensorflow/core/platform/posix/port.cc -tensorflow/core/platform/posix/error.cc -tensorflow/core/platform/posix/env.cc -tensorflow/core/platform/posix/load_library.cc -tensorflow/core/platform/posix/env_time.cc -tensorflow/core/platform/file_system.cc -tensorflow/core/platform/file_system_helper.cc +tensorflow/core/platform/default/tracing.cc +tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc tensorflow/core/platform/env_time.cc +tensorflow/core/platform/file_system.cc +tensorflow/core/platform/file_system_helper.cc +tensorflow/core/platform/posix/env.cc +tensorflow/core/platform/posix/env_time.cc +tensorflow/core/platform/posix/error.cc +tensorflow/core/platform/posix/load_library.cc +tensorflow/core/platform/posix/port.cc +tensorflow/core/platform/posix/posix_file_system.cc +tensorflow/core/platform/protobuf_util.cc tensorflow/core/platform/setround.cc -tensorflow/core/platform/denormal.cc -tensorflow/core/platform/default/tracing.cc -tensorflow/core/platform/default/mutex.cc -tensorflow/core/platform/default/logging.cc -tensorflow/core/platform/cpu_info.cc -tensorflow/core/lib/wav/wav_io.cc 
-tensorflow/core/lib/strings/stringprintf.cc -tensorflow/core/lib/strings/strcat.cc -tensorflow/core/lib/strings/str_util.cc -tensorflow/core/lib/strings/scanner.cc -tensorflow/core/lib/strings/proto_text_util.cc -tensorflow/core/lib/strings/ordered_code.cc -tensorflow/core/lib/strings/numbers.cc -tensorflow/core/lib/random/weighted_picker.cc -tensorflow/core/lib/random/simple_philox.cc -tensorflow/core/lib/random/random.cc -tensorflow/core/lib/random/distribution_sampler.cc -tensorflow/core/lib/io/zlib_outputbuffer.cc -tensorflow/core/lib/io/zlib_inputstream.cc -tensorflow/core/lib/io/zlib_compression_options.cc -tensorflow/core/lib/io/two_level_iterator.cc -tensorflow/core/lib/io/table_builder.cc -tensorflow/core/lib/io/table.cc -tensorflow/core/lib/io/record_writer.cc -tensorflow/core/lib/io/record_reader.cc -tensorflow/core/lib/io/random_inputstream.cc -tensorflow/core/lib/io/path.cc -tensorflow/core/lib/io/iterator.cc -tensorflow/core/lib/io/inputstream_interface.cc -tensorflow/core/lib/io/inputbuffer.cc -tensorflow/core/lib/io/format.cc -tensorflow/core/lib/io/compression.cc -tensorflow/core/lib/io/buffered_inputstream.cc -tensorflow/core/lib/io/block_builder.cc -tensorflow/core/lib/io/block.cc -tensorflow/core/lib/histogram/histogram.cc -tensorflow/core/lib/hash/hash.cc -tensorflow/core/lib/hash/crc32c.cc -tensorflow/core/lib/hash/crc32c_accelerate.cc -tensorflow/core/lib/core/threadpool.cc -tensorflow/core/lib/core/stringpiece.cc -tensorflow/core/lib/core/status.cc -tensorflow/core/lib/core/coding.cc -tensorflow/core/lib/core/arena.cc +tensorflow/core/platform/tensor_coding.cc +tensorflow/core/platform/tracing.cc +tensorflow/tools/proto_text/gen_proto_text_functions.cc +tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 938c4a53ab..0d8df93d11 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ 
b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -1,41 +1,41 @@ -tensorflow/core/util/test_log.pb.cc -tensorflow/core/util/saved_tensor_slice.pb.cc -tensorflow/core/util/memmapped_file_system.pb.cc -tensorflow/core/util/event.pb.cc -tensorflow/core/protobuf/tensorflow_server.pb.cc -tensorflow/core/protobuf/saver.pb.cc -tensorflow/core/protobuf/queue_runner.pb.cc -tensorflow/core/protobuf/named_tensor.pb.cc -tensorflow/core/protobuf/meta_graph.pb.cc +tensorflow/core/example/example.pb.cc +tensorflow/core/example/feature.pb.cc +tensorflow/core/framework/allocation_description.pb.cc +tensorflow/core/framework/api_def.pb.cc +tensorflow/core/framework/attr_value.pb.cc +tensorflow/core/framework/cost_graph.pb.cc +tensorflow/core/framework/device_attributes.pb.cc +tensorflow/core/framework/function.pb.cc +tensorflow/core/framework/graph.pb.cc +tensorflow/core/framework/graph_transfer_info.pb.cc +tensorflow/core/framework/kernel_def.pb.cc +tensorflow/core/framework/log_memory.pb.cc +tensorflow/core/framework/node_def.pb.cc +tensorflow/core/framework/op_def.pb.cc +tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc +tensorflow/core/framework/resource_handle.pb.cc +tensorflow/core/framework/step_stats.pb.cc +tensorflow/core/framework/summary.pb.cc +tensorflow/core/framework/tensor.pb.cc +tensorflow/core/framework/tensor_description.pb.cc +tensorflow/core/framework/tensor_shape.pb.cc +tensorflow/core/framework/tensor_slice.pb.cc +tensorflow/core/framework/types.pb.cc +tensorflow/core/framework/variable.pb.cc +tensorflow/core/framework/versions.pb.cc +tensorflow/core/grappler/costs/op_performance_data.pb.cc +tensorflow/core/lib/core/error_codes.pb.cc tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc -tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/debug.pb.cc tensorflow/core/protobuf/device_properties.pb.cc -tensorflow/core/lib/core/error_codes.pb.cc -tensorflow/core/framework/versions.pb.cc 
-tensorflow/core/framework/variable.pb.cc -tensorflow/core/framework/types.pb.cc -tensorflow/core/framework/tensor_slice.pb.cc -tensorflow/core/framework/tensor_shape.pb.cc -tensorflow/core/framework/tensor_description.pb.cc -tensorflow/core/framework/tensor.pb.cc -tensorflow/core/framework/summary.pb.cc -tensorflow/core/framework/step_stats.pb.cc -tensorflow/core/framework/resource_handle.pb.cc -tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc -tensorflow/core/framework/api_def.pb.cc -tensorflow/core/framework/op_def.pb.cc -tensorflow/core/framework/node_def.pb.cc -tensorflow/core/framework/log_memory.pb.cc -tensorflow/core/framework/kernel_def.pb.cc -tensorflow/core/framework/graph_transfer_info.pb.cc -tensorflow/core/framework/graph.pb.cc -tensorflow/core/framework/function.pb.cc -tensorflow/core/framework/device_attributes.pb.cc -tensorflow/core/framework/cost_graph.pb.cc -tensorflow/core/framework/attr_value.pb.cc -tensorflow/core/framework/allocation_description.pb.cc -tensorflow/core/example/feature.pb.cc -tensorflow/core/example/example.pb.cc -tensorflow/core/grappler/costs/op_performance_data.pb.cc +tensorflow/core/protobuf/meta_graph.pb.cc +tensorflow/core/protobuf/named_tensor.pb.cc +tensorflow/core/protobuf/queue_runner.pb.cc +tensorflow/core/protobuf/rewriter_config.pb.cc +tensorflow/core/protobuf/saver.pb.cc +tensorflow/core/protobuf/tensorflow_server.pb.cc +tensorflow/core/util/event.pb.cc +tensorflow/core/util/memmapped_file_system.pb.cc +tensorflow/core/util/saved_tensor_slice.pb.cc +tensorflow/core/util/test_log.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index aa91b2f954..d982df9319 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -1,42 +1,43 @@ -tensorflow/core/util/test_log.pb.h -tensorflow/core/util/saved_tensor_slice.pb.h 
-tensorflow/core/util/memmapped_file_system.pb.h -tensorflow/core/util/event.pb.h -tensorflow/core/protobuf/tensorflow_server.pb.h -tensorflow/core/protobuf/saver.pb.h -tensorflow/core/protobuf/queue_runner.pb.h -tensorflow/core/protobuf/named_tensor.pb.h -tensorflow/core/protobuf/meta_graph.pb.h +tensorflow/core/example/example.pb.h +tensorflow/core/example/feature.pb.h +tensorflow/core/framework/allocation_description.pb.h +tensorflow/core/framework/api_def.pb.h +tensorflow/core/framework/attr_value.pb.h +tensorflow/core/framework/cost_graph.pb.h +tensorflow/core/framework/device_attributes.pb.h +tensorflow/core/framework/function.pb.h +tensorflow/core/framework/graph.pb.h +tensorflow/core/framework/graph_transfer_info.pb.h +tensorflow/core/framework/kernel_def.pb.h +tensorflow/core/framework/log_memory.pb.h +tensorflow/core/framework/node_def.pb.h +tensorflow/core/framework/op_def.pb.h +tensorflow/core/framework/remote_fused_graph_execute_info.pb.h +tensorflow/core/framework/resource_handle.pb.h +tensorflow/core/framework/step_stats.pb.h +tensorflow/core/framework/summary.pb.h +tensorflow/core/framework/tensor.pb.h +tensorflow/core/framework/tensor_description.pb.h +tensorflow/core/framework/tensor_shape.pb.h +tensorflow/core/framework/tensor_slice.pb.h +tensorflow/core/framework/types.pb.h +tensorflow/core/framework/variable.pb.h +tensorflow/core/framework/versions.pb.h +tensorflow/core/grappler/costs/op_performance_data.pb.h +tensorflow/core/lib/core/error_codes.pb.h tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/device_properties.pb.h +tensorflow/core/protobuf/meta_graph.pb.h +tensorflow/core/protobuf/named_tensor.pb.h +tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/rewriter_config.pb.h +tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h -tensorflow/core/lib/core/error_codes.pb.h -tensorflow/core/framework/versions.pb.h 
-tensorflow/core/framework/variable.pb.h -tensorflow/core/framework/types.pb.h -tensorflow/core/framework/tensor_slice.pb.h -tensorflow/core/framework/tensor_shape.pb.h -tensorflow/core/framework/tensor_description.pb.h -tensorflow/core/framework/tensor.pb.h -tensorflow/core/framework/summary.pb.h -tensorflow/core/framework/step_stats.pb.h -tensorflow/core/framework/resource_handle.pb.h -tensorflow/core/framework/remote_fused_graph_execute_info.pb.h -tensorflow/core/framework/api_def.pb.h -tensorflow/core/framework/op_def.pb.h -tensorflow/core/framework/node_def.pb.h -tensorflow/core/framework/log_memory.pb.h -tensorflow/core/framework/kernel_def.pb.h -tensorflow/core/framework/graph_transfer_info.pb.h -tensorflow/core/framework/graph.pb.h -tensorflow/core/framework/function.pb.h -tensorflow/core/framework/device_attributes.pb.h -tensorflow/core/framework/cost_graph.pb.h -tensorflow/core/framework/attr_value.pb.h -tensorflow/core/framework/allocation_description.pb.h -tensorflow/core/example/feature.pb.h -tensorflow/core/example/example.pb.h -tensorflow/core/grappler/costs/op_performance_data.pb.h +tensorflow/core/protobuf/tensorflow_server.pb.h +tensorflow/core/util/event.pb.h +tensorflow/core/util/memmapped_file_system.pb.h +tensorflow/core/util/saved_tensor_slice.pb.h +tensorflow/core/util/test_log.pb.h + diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 66a3315700..676620e544 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -4,218 +4,19 @@ tensorflow/contrib/boosted_trees/ops/quantile_ops.cc tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc tensorflow/contrib/boosted_trees/ops/training_ops.cc -tensorflow/core/kernels/xent_op.cc -tensorflow/core/kernels/where_op.cc -tensorflow/core/kernels/variable_ops.cc -tensorflow/core/kernels/unpack_op.cc 
-tensorflow/core/kernels/unique_op.cc -tensorflow/core/kernels/transpose_op.cc -tensorflow/core/kernels/transpose_functor_cpu.cc -tensorflow/core/kernels/training_op_helpers.cc -tensorflow/core/kernels/training_ops.cc -tensorflow/core/kernels/topk_op.cc -tensorflow/core/kernels/tile_functor_cpu.cc -tensorflow/core/kernels/tile_ops.cc -tensorflow/core/kernels/tile_ops_cpu_impl_1.cc -tensorflow/core/kernels/tile_ops_cpu_impl_2.cc -tensorflow/core/kernels/tile_ops_cpu_impl_3.cc -tensorflow/core/kernels/tile_ops_cpu_impl_4.cc -tensorflow/core/kernels/tile_ops_cpu_impl_5.cc -tensorflow/core/kernels/tile_ops_cpu_impl_6.cc -tensorflow/core/kernels/tile_ops_cpu_impl_7.cc -tensorflow/core/kernels/tensor_array_ops.cc -tensorflow/core/kernels/tensor_array.cc -tensorflow/core/kernels/strided_slice_op_inst_7.cc -tensorflow/core/kernels/strided_slice_op_inst_6.cc -tensorflow/core/kernels/strided_slice_op_inst_5.cc -tensorflow/core/kernels/strided_slice_op_inst_4.cc -tensorflow/core/kernels/strided_slice_op_inst_3.cc -tensorflow/core/kernels/strided_slice_op_inst_2.cc -tensorflow/core/kernels/strided_slice_op_inst_1.cc -tensorflow/core/kernels/strided_slice_op_inst_0.cc -tensorflow/core/kernels/strided_slice_op.cc -tensorflow/core/kernels/stack_ops.cc -tensorflow/core/kernels/split_op.cc -tensorflow/core/kernels/split_v_op.cc -tensorflow/core/kernels/split_lib_cpu.cc -tensorflow/core/kernels/spectrogram_op.cc -tensorflow/core/kernels/spectrogram.cc -tensorflow/core/kernels/sparse_to_dense_op.cc -tensorflow/core/kernels/sparse_matmul_op.cc -tensorflow/core/kernels/sparse_fill_empty_rows_op.cc -tensorflow/core/kernels/sparse_reshape_op.c -tensorflow/core/kernels/segment_reduction_ops.cc -tensorflow/core/kernels/softsign_op.cc -tensorflow/core/kernels/softplus_op.cc -tensorflow/core/kernels/softmax_op.cc -tensorflow/core/kernels/slice_op_cpu_impl_1.cc -tensorflow/core/kernels/slice_op_cpu_impl_2.cc -tensorflow/core/kernels/slice_op_cpu_impl_3.cc 
-tensorflow/core/kernels/slice_op_cpu_impl_4.cc -tensorflow/core/kernels/slice_op_cpu_impl_5.cc -tensorflow/core/kernels/slice_op_cpu_impl_6.cc -tensorflow/core/kernels/slice_op_cpu_impl_7.cc -tensorflow/core/kernels/slice_op.cc -tensorflow/core/kernels/shape_ops.cc -tensorflow/core/kernels/session_ops.cc -tensorflow/core/kernels/sequence_ops.cc -tensorflow/core/kernels/sendrecv_ops.cc -tensorflow/core/kernels/scatter_op.cc -tensorflow/core/kernels/scatter_functor.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc -tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc -tensorflow/core/kernels/scatter_nd_op.cc -tensorflow/core/kernels/save_restore_tensor.cc -tensorflow/core/kernels/save_restore_v2_ops.cc -tensorflow/core/kernels/save_op.cc -tensorflow/core/kernels/string_join_op.cc -tensorflow/core/kernels/reverse_sequence_op.cc -tensorflow/core/kernels/reverse_op.cc -tensorflow/core/kernels/restore_op.cc -tensorflow/core/kernels/resize_nearest_neighbor_op.cc -tensorflow/core/kernels/resize_bilinear_op.cc -tensorflow/core/kernels/reshape_util.cc -tensorflow/core/kernels/reshape_op.cc -tensorflow/core/kernels/relu_op.cc -tensorflow/core/kernels/reduction_ops_sum.cc -tensorflow/core/kernels/reduction_ops_prod.cc -tensorflow/core/kernels/reduction_ops_min.cc -tensorflow/core/kernels/reduction_ops_mean.cc -tensorflow/core/kernels/reduction_ops_max.cc -tensorflow/core/kernels/reduction_ops_common.cc -tensorflow/core/kernels/reduction_ops_any.cc -tensorflow/core/kernels/reduction_ops_all.cc -tensorflow/core/kernels/roll_op.cc -tensorflow/core/kernels/queue_op.cc -tensorflow/core/kernels/queue_ops.cc -tensorflow/core/kernels/queue_base.cc 
-tensorflow/core/kernels/pooling_ops_common.cc -tensorflow/core/kernels/padding_fifo_queue_op.cc -tensorflow/core/kernels/padding_fifo_queue.cc -tensorflow/core/kernels/pad_op.cc -tensorflow/core/kernels/pack_op.cc -tensorflow/core/kernels/ops_util.cc -tensorflow/core/kernels/one_hot_op.cc -tensorflow/core/kernels/non_max_suppression_op.cc -tensorflow/core/kernels/no_op.cc -tensorflow/core/kernels/mirror_pad_op.cc -tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc -tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc -tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc -tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc -tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc -tensorflow/core/kernels/mfcc_op.cc -tensorflow/core/kernels/mfcc_mel_filterbank.cc -tensorflow/core/kernels/mfcc_dct.cc -tensorflow/core/kernels/mfcc.cc -tensorflow/core/kernels/maxpooling_op.cc -tensorflow/core/kernels/matmul_op.cc -tensorflow/core/kernels/lrn_op.cc -tensorflow/core/kernels/logging_ops.cc -tensorflow/core/kernels/initializable_lookup_table.c -tensorflow/core/kernels/lookup_table_init_op.cc -tensorflow/core/kernels/lookup_table_op.cc -tensorflow/core/kernels/lookup_util.cc -tensorflow/core/kernels/inplace_ops.cc -tensorflow/core/kernels/in_topk_op.cc -tensorflow/core/kernels/immutable_constant_op.cc -tensorflow/core/kernels/identity_op.cc -tensorflow/core/kernels/identity_n_op.cc -tensorflow/core/kernels/gather_op.cc -tensorflow/core/kernels/gather_functor.cc -tensorflow/core/kernels/gather_nd_op.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc -tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc -tensorflow/core/kernels/fused_batch_norm_op.cc 
-tensorflow/core/kernels/function_ops.cc -tensorflow/core/kernels/fill_functor.cc -tensorflow/core/kernels/fifo_queue.cc -tensorflow/core/kernels/fifo_queue_op.cc -tensorflow/core/kernels/fake_quant_ops.cc -tensorflow/core/kernels/example_parsing_ops.cc -tensorflow/core/kernels/encode_wav_op.cc -tensorflow/core/kernels/dynamic_stitch_op.cc -tensorflow/core/kernels/dynamic_partition_op.cc -tensorflow/core/kernels/decode_bmp_op.cc -tensorflow/core/kernels/depthtospace_op.cc -tensorflow/core/kernels/data_format_ops.cc -tensorflow/core/kernels/spacetodepth_op.cc -tensorflow/core/kernels/dense_update_functor.cc -tensorflow/core/kernels/dense_update_ops.cc -tensorflow/core/kernels/deep_conv2d.cc -tensorflow/core/kernels/decode_wav_op.cc -tensorflow/core/kernels/xsmm_conv2d.cc -tensorflow/core/kernels/cwise_ops_common.cc -tensorflow/core/kernels/cwise_op_tanh.cc -tensorflow/core/kernels/cwise_op_pow.cc -tensorflow/core/kernels/cwise_op_sub.cc -tensorflow/core/kernels/cwise_op_squared_difference.cc -tensorflow/core/kernels/cwise_op_square.cc -tensorflow/core/kernels/cwise_op_sqrt.cc -tensorflow/core/kernels/cwise_op_sigmoid.cc -tensorflow/core/kernels/cwise_op_sign.cc -tensorflow/core/kernels/cwise_op_select.cc -tensorflow/core/kernels/cwise_op_round.cc -tensorflow/core/kernels/cwise_op_rsqrt.cc -tensorflow/core/kernels/cwise_op_reciprocal.cc -tensorflow/core/kernels/cwise_op_neg.cc -tensorflow/core/kernels/cwise_op_mul_2.cc -tensorflow/core/kernels/cwise_op_mul_1.cc -tensorflow/core/kernels/cwise_op_minimum.cc -tensorflow/core/kernels/cwise_op_maximum.cc -tensorflow/core/kernels/cwise_op_logical_not.cc -tensorflow/core/kernels/cwise_op_logical_and.cc -tensorflow/core/kernels/cwise_op_logical_or.cc -tensorflow/core/kernels/cwise_op_log.cc -tensorflow/core/kernels/cwise_op_less.cc -tensorflow/core/kernels/cwise_op_less_equal.cc -tensorflow/core/kernels/cwise_op_isnan.cc -tensorflow/core/kernels/cwise_op_isfinite.cc -tensorflow/core/kernels/cwise_op_invert.cc 
-tensorflow/core/kernels/cwise_op_greater_equal.cc -tensorflow/core/kernels/cwise_op_greater.cc -tensorflow/core/kernels/cwise_op_floor_div.cc -tensorflow/core/kernels/cwise_op_floor_mod.cc -tensorflow/core/kernels/cwise_op_floor.cc -tensorflow/core/kernels/cwise_op_exp.cc -tensorflow/core/kernels/cwise_op_equal_to_2.cc -tensorflow/core/kernels/cwise_op_equal_to_1.cc -tensorflow/core/kernels/cwise_op_not_equal_to_2.cc -tensorflow/core/kernels/cwise_op_not_equal_to_1.cc -tensorflow/core/kernels/cwise_op_div.cc -tensorflow/core/kernels/cwise_op_bitwise_xor.cc -tensorflow/core/kernels/cwise_op_bitwise_or.cc -tensorflow/core/kernels/cwise_op_bitwise_and.cc -tensorflow/core/kernels/cwise_op_left_shift.cc -tensorflow/core/kernels/cwise_op_right_shift.cc -tensorflow/core/kernels/cwise_op_add_2.cc -tensorflow/core/kernels/cwise_op_add_1.cc -tensorflow/core/kernels/cwise_op_abs.cc -tensorflow/core/kernels/ctc_decoder_ops.cc -tensorflow/core/kernels/crop_and_resize_op.cc -tensorflow/core/kernels/conv_ops_using_gemm.cc -tensorflow/core/kernels/conv_ops_fused.cc -tensorflow/core/kernels/conv_ops.cc -tensorflow/core/kernels/conv_grad_filter_ops.cc -tensorflow/core/kernels/conv_grad_input_ops.cc -tensorflow/core/kernels/conv_grad_ops.cc -tensorflow/core/kernels/control_flow_ops.cc -tensorflow/core/kernels/constant_op.cc -tensorflow/core/kernels/concat_op.cc -tensorflow/core/kernels/concat_lib_cpu.cc -tensorflow/core/kernels/check_numerics_op.cc +tensorflow/core/kernels/aggregate_ops.cc +tensorflow/core/kernels/argmax_op.cc +tensorflow/core/kernels/avgpooling_op.cc +tensorflow/core/kernels/batch_matmul_op_real.cc +tensorflow/core/kernels/batch_norm_op.cc +tensorflow/core/kernels/batchtospace_op.cc +tensorflow/core/kernels/bcast_ops.cc +tensorflow/core/kernels/bias_op.cc +tensorflow/core/kernels/boosted_trees/prediction_ops.cc +tensorflow/core/kernels/boosted_trees/resource_ops.cc +tensorflow/core/kernels/boosted_trees/resources.cc 
+tensorflow/core/kernels/boosted_trees/stats_ops.cc +tensorflow/core/kernels/boosted_trees/training_ops.cc tensorflow/core/kernels/cast_op.cc tensorflow/core/kernels/cast_op_impl_bfloat.cc tensorflow/core/kernels/cast_op_impl_bool.cc @@ -232,20 +33,130 @@ tensorflow/core/kernels/cast_op_impl_uint16.cc tensorflow/core/kernels/cast_op_impl_uint32.cc tensorflow/core/kernels/cast_op_impl_uint64.cc tensorflow/core/kernels/cast_op_impl_uint8.cc -tensorflow/core/kernels/boosted_trees/prediction_ops.cc -tensorflow/core/kernels/boosted_trees/resource_ops.cc -tensorflow/core/kernels/boosted_trees/resources.cc -tensorflow/core/kernels/boosted_trees/stats_ops.cc -tensorflow/core/kernels/boosted_trees/training_ops.cc -tensorflow/core/kernels/bias_op.cc -tensorflow/core/kernels/bcast_ops.cc -tensorflow/core/kernels/batch_norm_op.cc -tensorflow/core/kernels/avgpooling_op.cc -tensorflow/core/kernels/argmax_op.cc -tensorflow/core/kernels/aggregate_ops.cc +tensorflow/core/kernels/check_numerics_op.cc +tensorflow/core/kernels/concat_lib_cpu.cc +tensorflow/core/kernels/concat_op.cc +tensorflow/core/kernels/constant_op.cc +tensorflow/core/kernels/control_flow_ops.cc +tensorflow/core/kernels/conv_grad_filter_ops.cc +tensorflow/core/kernels/conv_grad_input_ops.cc +tensorflow/core/kernels/conv_grad_ops.cc +tensorflow/core/kernels/conv_ops.cc +tensorflow/core/kernels/conv_ops_fused.cc +tensorflow/core/kernels/conv_ops_using_gemm.cc +tensorflow/core/kernels/crop_and_resize_op.cc +tensorflow/core/kernels/ctc_decoder_ops.cc +tensorflow/core/kernels/cwise_op_abs.cc +tensorflow/core/kernels/cwise_op_add_1.cc +tensorflow/core/kernels/cwise_op_add_2.cc +tensorflow/core/kernels/cwise_op_bitwise_and.cc +tensorflow/core/kernels/cwise_op_bitwise_or.cc +tensorflow/core/kernels/cwise_op_bitwise_xor.cc +tensorflow/core/kernels/cwise_op_div.cc +tensorflow/core/kernels/cwise_op_equal_to_1.cc +tensorflow/core/kernels/cwise_op_equal_to_2.cc +tensorflow/core/kernels/cwise_op_exp.cc 
+tensorflow/core/kernels/cwise_op_floor.cc +tensorflow/core/kernels/cwise_op_floor_div.cc +tensorflow/core/kernels/cwise_op_floor_mod.cc +tensorflow/core/kernels/cwise_op_greater.cc +tensorflow/core/kernels/cwise_op_greater_equal.cc +tensorflow/core/kernels/cwise_op_invert.cc +tensorflow/core/kernels/cwise_op_isfinite.cc +tensorflow/core/kernels/cwise_op_isnan.cc +tensorflow/core/kernels/cwise_op_left_shift.cc +tensorflow/core/kernels/cwise_op_less.cc +tensorflow/core/kernels/cwise_op_less_equal.cc +tensorflow/core/kernels/cwise_op_log.cc +tensorflow/core/kernels/cwise_op_logical_and.cc +tensorflow/core/kernels/cwise_op_logical_not.cc +tensorflow/core/kernels/cwise_op_logical_or.cc +tensorflow/core/kernels/cwise_op_maximum.cc +tensorflow/core/kernels/cwise_op_minimum.cc +tensorflow/core/kernels/cwise_op_mul_1.cc +tensorflow/core/kernels/cwise_op_mul_2.cc +tensorflow/core/kernels/cwise_op_neg.cc +tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +tensorflow/core/kernels/cwise_op_not_equal_to_2.cc +tensorflow/core/kernels/cwise_op_pow.cc +tensorflow/core/kernels/cwise_op_reciprocal.cc +tensorflow/core/kernels/cwise_op_right_shift.cc +tensorflow/core/kernels/cwise_op_round.cc +tensorflow/core/kernels/cwise_op_rsqrt.cc +tensorflow/core/kernels/cwise_op_select.cc +tensorflow/core/kernels/cwise_op_sigmoid.cc +tensorflow/core/kernels/cwise_op_sign.cc +tensorflow/core/kernels/cwise_op_sqrt.cc +tensorflow/core/kernels/cwise_op_square.cc +tensorflow/core/kernels/cwise_op_squared_difference.cc +tensorflow/core/kernels/cwise_op_sub.cc +tensorflow/core/kernels/cwise_op_tanh.cc +tensorflow/core/kernels/cwise_ops_common.cc +tensorflow/core/kernels/data_format_ops.cc +tensorflow/core/kernels/decode_bmp_op.cc +tensorflow/core/kernels/decode_proto_op.cc +tensorflow/core/kernels/decode_wav_op.cc +tensorflow/core/kernels/deep_conv2d.cc +tensorflow/core/kernels/dense_update_functor.cc +tensorflow/core/kernels/dense_update_ops.cc +tensorflow/core/kernels/depthtospace_op.cc 
tensorflow/core/kernels/depthwise_conv_op.cc tensorflow/core/kernels/dequantize_op.cc +tensorflow/core/kernels/dynamic_partition_op.cc +tensorflow/core/kernels/dynamic_stitch_op.cc +tensorflow/core/kernels/encode_proto_op.cc +tensorflow/core/kernels/encode_wav_op.cc +tensorflow/core/kernels/example_parsing_ops.cc +tensorflow/core/kernels/fake_quant_ops.cc +tensorflow/core/kernels/fifo_queue.cc +tensorflow/core/kernels/fifo_queue_op.cc +tensorflow/core/kernels/fill_functor.cc +tensorflow/core/kernels/function_ops.cc +tensorflow/core/kernels/fused_batch_norm_op.cc +tensorflow/core/kernels/gather_functor.cc +tensorflow/core/kernels/gather_nd_op.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc +tensorflow/core/kernels/gather_op.cc +tensorflow/core/kernels/identity_n_op.cc +tensorflow/core/kernels/identity_op.cc +tensorflow/core/kernels/immutable_constant_op.cc +tensorflow/core/kernels/in_topk_op.cc +tensorflow/core/kernels/initializable_lookup_table.c +tensorflow/core/kernels/inplace_ops.cc +tensorflow/core/kernels/logging_ops.cc +tensorflow/core/kernels/lookup_table_init_op.cc +tensorflow/core/kernels/lookup_table_op.cc +tensorflow/core/kernels/lookup_util.cc +tensorflow/core/kernels/lrn_op.cc +tensorflow/core/kernels/matmul_op.cc +tensorflow/core/kernels/maxpooling_op.cc tensorflow/core/kernels/meta_support.cc +tensorflow/core/kernels/mfcc.cc +tensorflow/core/kernels/mfcc_dct.cc +tensorflow/core/kernels/mfcc_mel_filterbank.cc +tensorflow/core/kernels/mfcc_op.cc +tensorflow/core/kernels/mirror_pad_op.cc +tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc 
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc +tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc +tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc +tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc +tensorflow/core/kernels/no_op.cc +tensorflow/core/kernels/non_max_suppression_op.cc +tensorflow/core/kernels/one_hot_op.cc +tensorflow/core/kernels/ops_util.cc +tensorflow/core/kernels/pack_op.cc +tensorflow/core/kernels/pad_op.cc +tensorflow/core/kernels/padding_fifo_queue.cc +tensorflow/core/kernels/padding_fifo_queue_op.cc +tensorflow/core/kernels/pooling_ops_common.cc tensorflow/core/kernels/population_count_op.cc tensorflow/core/kernels/quantization_utils.cc tensorflow/core/kernels/quantize_down_and_shrink_range.cc @@ -262,46 +173,135 @@ tensorflow/core/kernels/quantized_mul_op.cc tensorflow/core/kernels/quantized_pooling_ops.cc tensorflow/core/kernels/quantized_reshape_op.cc tensorflow/core/kernels/quantized_resize_bilinear_op.cc -tensorflow/core/kernels/requantization_range_op.cc -tensorflow/core/kernels/requantize.cc +tensorflow/core/kernels/queue_base.cc +tensorflow/core/kernels/queue_op.cc +tensorflow/core/kernels/queue_ops.cc +tensorflow/core/kernels/random_op.cc +tensorflow/core/kernels/reduction_ops_all.cc +tensorflow/core/kernels/reduction_ops_any.cc +tensorflow/core/kernels/reduction_ops_common.cc +tensorflow/core/kernels/reduction_ops_max.cc +tensorflow/core/kernels/reduction_ops_mean.cc +tensorflow/core/kernels/reduction_ops_min.cc +tensorflow/core/kernels/reduction_ops_prod.cc +tensorflow/core/kernels/reduction_ops_sum.cc +tensorflow/core/kernels/relu_op.cc tensorflow/core/kernels/remote_fused_graph_execute_op.cc tensorflow/core/kernels/remote_fused_graph_execute_utils.cc -tensorflow/core/kernels/batch_matmul_op_real.cc -tensorflow/core/kernels/random_op.cc -tensorflow/core/ops/training_ops.cc -tensorflow/core/ops/string_ops.cc -tensorflow/core/ops/state_ops.cc -tensorflow/core/ops/sparse_ops.cc -tensorflow/core/ops/sendrecv_ops.cc 
-tensorflow/core/ops/script_ops.cc -tensorflow/core/ops/remote_fused_graph_ops.cc -tensorflow/core/ops/random_ops.cc -tensorflow/core/ops/random_grad.cc -tensorflow/core/ops/parsing_ops.cc -tensorflow/core/ops/no_op.cc -tensorflow/core/ops/nn_ops.cc -tensorflow/core/ops/nn_grad.cc -tensorflow/core/ops/manip_ops.cc -tensorflow/core/ops/math_ops.cc -tensorflow/core/ops/math_grad.cc -tensorflow/core/ops/logging_ops.cc -tensorflow/core/ops/linalg_ops.cc -tensorflow/core/ops/io_ops.cc -tensorflow/core/ops/image_ops.cc -tensorflow/core/ops/functional_ops.cc -tensorflow/core/ops/functional_grad.cc -tensorflow/core/ops/function_ops.cc -tensorflow/core/ops/data_flow_ops.cc -tensorflow/core/ops/ctc_ops.cc -tensorflow/core/ops/control_flow_ops.cc -tensorflow/core/ops/candidate_sampling_ops.cc -tensorflow/core/ops/boosted_trees_ops.cc -tensorflow/core/ops/array_ops.cc -tensorflow/core/ops/array_grad.cc +tensorflow/core/kernels/requantization_range_op.cc +tensorflow/core/kernels/requantize.cc +tensorflow/core/kernels/reshape_op.cc +tensorflow/core/kernels/reshape_util.cc +tensorflow/core/kernels/resize_bilinear_op.cc +tensorflow/core/kernels/resize_nearest_neighbor_op.cc +tensorflow/core/kernels/restore_op.cc +tensorflow/core/kernels/reverse_op.cc +tensorflow/core/kernels/reverse_sequence_op.cc +tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/save_op.cc +tensorflow/core/kernels/save_restore_tensor.cc +tensorflow/core/kernels/save_restore_v2_ops.cc +tensorflow/core/kernels/scatter_functor.cc +tensorflow/core/kernels/scatter_nd_op.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc +tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc 
+tensorflow/core/kernels/scatter_op.cc +tensorflow/core/kernels/segment_reduction_ops.cc +tensorflow/core/kernels/segment_reduction_ops.cc +tensorflow/core/kernels/sendrecv_ops.cc +tensorflow/core/kernels/sequence_ops.cc +tensorflow/core/kernels/session_ops.cc +tensorflow/core/kernels/shape_ops.cc +tensorflow/core/kernels/slice_op.cc +tensorflow/core/kernels/slice_op_cpu_impl_1.cc +tensorflow/core/kernels/slice_op_cpu_impl_2.cc +tensorflow/core/kernels/slice_op_cpu_impl_3.cc +tensorflow/core/kernels/slice_op_cpu_impl_4.cc +tensorflow/core/kernels/slice_op_cpu_impl_5.cc +tensorflow/core/kernels/slice_op_cpu_impl_6.cc +tensorflow/core/kernels/slice_op_cpu_impl_7.cc +tensorflow/core/kernels/softmax_op.cc +tensorflow/core/kernels/softplus_op.cc +tensorflow/core/kernels/softsign_op.cc tensorflow/core/kernels/spacetobatch_functor.cc tensorflow/core/kernels/spacetobatch_op.cc -tensorflow/core/kernels/batchtospace_op.cc -tensorflow/core/kernels/segment_reduction_ops.cc +tensorflow/core/kernels/spacetodepth_op.cc +tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +tensorflow/core/kernels/sparse_matmul_op.cc +tensorflow/core/kernels/sparse_reshape_op.c +tensorflow/core/kernels/sparse_to_dense_op.cc +tensorflow/core/kernels/spectrogram.cc +tensorflow/core/kernels/spectrogram_op.cc +tensorflow/core/kernels/split_lib_cpu.cc +tensorflow/core/kernels/split_op.cc +tensorflow/core/kernels/split_v_op.cc +tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/strided_slice_op.cc +tensorflow/core/kernels/strided_slice_op_inst_0.cc +tensorflow/core/kernels/strided_slice_op_inst_1.cc +tensorflow/core/kernels/strided_slice_op_inst_2.cc +tensorflow/core/kernels/strided_slice_op_inst_3.cc +tensorflow/core/kernels/strided_slice_op_inst_4.cc +tensorflow/core/kernels/strided_slice_op_inst_5.cc +tensorflow/core/kernels/strided_slice_op_inst_6.cc +tensorflow/core/kernels/strided_slice_op_inst_7.cc +tensorflow/core/kernels/string_join_op.cc +tensorflow/core/kernels/tensor_array.cc 
+tensorflow/core/kernels/tensor_array_ops.cc +tensorflow/core/kernels/tile_functor_cpu.cc +tensorflow/core/kernels/tile_ops.cc +tensorflow/core/kernels/tile_ops_cpu_impl_1.cc +tensorflow/core/kernels/tile_ops_cpu_impl_2.cc +tensorflow/core/kernels/tile_ops_cpu_impl_3.cc +tensorflow/core/kernels/tile_ops_cpu_impl_4.cc +tensorflow/core/kernels/tile_ops_cpu_impl_5.cc +tensorflow/core/kernels/tile_ops_cpu_impl_6.cc +tensorflow/core/kernels/tile_ops_cpu_impl_7.cc +tensorflow/core/kernels/topk_op.cc +tensorflow/core/kernels/training_op_helpers.cc +tensorflow/core/kernels/training_ops.cc +tensorflow/core/kernels/transpose_functor_cpu.cc +tensorflow/core/kernels/transpose_op.cc +tensorflow/core/kernels/unique_op.cc +tensorflow/core/kernels/unpack_op.cc +tensorflow/core/kernels/variable_ops.cc +tensorflow/core/kernels/where_op.cc +tensorflow/core/kernels/xent_op.cc +tensorflow/core/kernels/xsmm_conv2d.cc +tensorflow/core/ops/array_grad.cc +tensorflow/core/ops/array_ops.cc tensorflow/core/ops/audio_ops.cc -tensorflow/core/kernels/decode_proto_op.cc -tensorflow/core/kernels/encode_proto_op.cc +tensorflow/core/ops/boosted_trees_ops.cc +tensorflow/core/ops/candidate_sampling_ops.cc +tensorflow/core/ops/control_flow_ops.cc +tensorflow/core/ops/ctc_ops.cc +tensorflow/core/ops/data_flow_ops.cc +tensorflow/core/ops/function_ops.cc +tensorflow/core/ops/functional_grad.cc +tensorflow/core/ops/functional_ops.cc +tensorflow/core/ops/image_ops.cc +tensorflow/core/ops/io_ops.cc +tensorflow/core/ops/linalg_ops.cc +tensorflow/core/ops/logging_ops.cc +tensorflow/core/ops/manip_ops.cc +tensorflow/core/ops/math_grad.cc +tensorflow/core/ops/math_ops.cc +tensorflow/core/ops/nn_grad.cc +tensorflow/core/ops/nn_ops.cc +tensorflow/core/ops/no_op.cc +tensorflow/core/ops/parsing_ops.cc +tensorflow/core/ops/random_grad.cc +tensorflow/core/ops/random_ops.cc +tensorflow/core/ops/remote_fused_graph_ops.cc +tensorflow/core/ops/script_ops.cc +tensorflow/core/ops/sendrecv_ops.cc 
+tensorflow/core/ops/sparse_ops.cc +tensorflow/core/ops/state_ops.cc +tensorflow/core/ops/string_ops.cc +tensorflow/core/ops/training_ops.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index b5431df2eb..f94d70db90 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -1,33 +1,33 @@ -tensorflow/core/util/saved_tensor_slice.pb_text.cc -tensorflow/core/util/memmapped_file_system.pb_text.cc -tensorflow/core/protobuf/saver.pb_text.cc +tensorflow/core/example/example.pb_text.cc +tensorflow/core/example/feature.pb_text.cc +tensorflow/core/framework/allocation_description.pb_text.cc +tensorflow/core/framework/api_def.pb_text.cc +tensorflow/core/framework/attr_value.pb_text.cc +tensorflow/core/framework/cost_graph.pb_text.cc +tensorflow/core/framework/device_attributes.pb_text.cc +tensorflow/core/framework/function.pb_text.cc +tensorflow/core/framework/graph.pb_text.cc +tensorflow/core/framework/graph_transfer_info.pb_text.cc +tensorflow/core/framework/kernel_def.pb_text.cc +tensorflow/core/framework/log_memory.pb_text.cc +tensorflow/core/framework/node_def.pb_text.cc +tensorflow/core/framework/op_def.pb_text.cc +tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc +tensorflow/core/framework/resource_handle.pb_text.cc +tensorflow/core/framework/step_stats.pb_text.cc +tensorflow/core/framework/summary.pb_text.cc +tensorflow/core/framework/tensor.pb_text.cc +tensorflow/core/framework/tensor_description.pb_text.cc +tensorflow/core/framework/tensor_shape.pb_text.cc +tensorflow/core/framework/tensor_slice.pb_text.cc +tensorflow/core/framework/types.pb_text.cc +tensorflow/core/framework/versions.pb_text.cc +tensorflow/core/lib/core/error_codes.pb_text.cc tensorflow/core/protobuf/cluster.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc 
+tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc -tensorflow/core/lib/core/error_codes.pb_text.cc -tensorflow/core/framework/versions.pb_text.cc -tensorflow/core/framework/types.pb_text.cc -tensorflow/core/framework/tensor_slice.pb_text.cc -tensorflow/core/framework/tensor_shape.pb_text.cc -tensorflow/core/framework/tensor_description.pb_text.cc -tensorflow/core/framework/tensor.pb_text.cc -tensorflow/core/framework/summary.pb_text.cc -tensorflow/core/framework/step_stats.pb_text.cc -tensorflow/core/framework/resource_handle.pb_text.cc -tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc -tensorflow/core/framework/api_def.pb_text.cc -tensorflow/core/framework/op_def.pb_text.cc -tensorflow/core/framework/node_def.pb_text.cc -tensorflow/core/framework/log_memory.pb_text.cc -tensorflow/core/framework/kernel_def.pb_text.cc -tensorflow/core/framework/graph_transfer_info.pb_text.cc -tensorflow/core/framework/graph.pb_text.cc -tensorflow/core/framework/function.pb_text.cc -tensorflow/core/framework/device_attributes.pb_text.cc -tensorflow/core/framework/cost_graph.pb_text.cc -tensorflow/core/framework/attr_value.pb_text.cc -tensorflow/core/framework/allocation_description.pb_text.cc -tensorflow/core/example/feature.pb_text.cc -tensorflow/core/example/example.pb_text.cc +tensorflow/core/util/memmapped_file_system.pb_text.cc +tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 1f254692d7..8bec3e3e01 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -2,47 +2,47 @@ tensorflow/contrib/boosted_trees/proto/learner.proto tensorflow/contrib/boosted_trees/proto/quantiles.proto tensorflow/contrib/boosted_trees/proto/split_info.proto tensorflow/contrib/boosted_trees/proto/tree_config.proto -tensorflow/core/util/test_log.proto 
-tensorflow/core/util/saved_tensor_slice.proto -tensorflow/core/util/memmapped_file_system.proto -tensorflow/core/util/event.proto -tensorflow/core/protobuf/tensorflow_server.proto -tensorflow/core/protobuf/saver.proto -tensorflow/core/protobuf/queue_runner.proto -tensorflow/core/protobuf/named_tensor.proto -tensorflow/core/protobuf/meta_graph.proto +tensorflow/core/example/example.proto +tensorflow/core/example/feature.proto +tensorflow/core/framework/allocation_description.proto +tensorflow/core/framework/api_def.proto +tensorflow/core/framework/attr_value.proto +tensorflow/core/framework/cost_graph.proto +tensorflow/core/framework/device_attributes.proto +tensorflow/core/framework/function.proto +tensorflow/core/framework/graph.proto +tensorflow/core/framework/graph_transfer_info.proto +tensorflow/core/framework/kernel_def.proto +tensorflow/core/framework/log_memory.proto +tensorflow/core/framework/node_def.proto +tensorflow/core/framework/op_def.proto +tensorflow/core/framework/reader_base.proto +tensorflow/core/framework/remote_fused_graph_execute_info.proto +tensorflow/core/framework/resource_handle.proto +tensorflow/core/framework/step_stats.proto +tensorflow/core/framework/summary.proto +tensorflow/core/framework/tensor.proto +tensorflow/core/framework/tensor_description.proto +tensorflow/core/framework/tensor_shape.proto +tensorflow/core/framework/tensor_slice.proto +tensorflow/core/framework/types.proto +tensorflow/core/framework/variable.proto +tensorflow/core/framework/versions.proto +tensorflow/core/grappler/costs/op_performance_data.proto +tensorflow/core/kernels/boosted_trees/boosted_trees.proto +tensorflow/core/lib/core/error_codes.proto tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/device_properties.proto +tensorflow/core/protobuf/meta_graph.proto +tensorflow/core/protobuf/named_tensor.proto +tensorflow/core/protobuf/queue_runner.proto 
tensorflow/core/protobuf/rewriter_config.proto +tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/tensor_bundle.proto -tensorflow/core/lib/core/error_codes.proto -tensorflow/core/kernels/boosted_trees/boosted_trees.proto -tensorflow/core/framework/versions.proto -tensorflow/core/framework/variable.proto -tensorflow/core/framework/types.proto -tensorflow/core/framework/tensor_slice.proto -tensorflow/core/framework/tensor_shape.proto -tensorflow/core/framework/tensor_description.proto -tensorflow/core/framework/tensor.proto -tensorflow/core/framework/summary.proto -tensorflow/core/framework/step_stats.proto -tensorflow/core/framework/resource_handle.proto -tensorflow/core/framework/remote_fused_graph_execute_info.proto -tensorflow/core/framework/reader_base.proto -tensorflow/core/framework/api_def.proto -tensorflow/core/framework/op_def.proto -tensorflow/core/framework/node_def.proto -tensorflow/core/framework/log_memory.proto -tensorflow/core/framework/kernel_def.proto -tensorflow/core/framework/graph_transfer_info.proto -tensorflow/core/framework/graph.proto -tensorflow/core/framework/function.proto -tensorflow/core/framework/device_attributes.proto -tensorflow/core/framework/cost_graph.proto -tensorflow/core/framework/attr_value.proto -tensorflow/core/framework/allocation_description.proto -tensorflow/core/example/feature.proto -tensorflow/core/example/example.proto -tensorflow/core/grappler/costs/op_performance_data.proto +tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/util/event.proto +tensorflow/core/util/memmapped_file_system.proto +tensorflow/core/util/saved_tensor_slice.proto +tensorflow/core/util/test_log.proto -- cgit v1.2.3 From 4136bd49d92c80de3c6ae03ffdb2524b36e96fa8 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sat, 8 Sep 2018 09:22:58 -0700 Subject: [tf.data] Refactoring of optimization tests. 
PiperOrigin-RevId: 212119773 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 15 --- .../python/kernel_tests/map_dataset_op_test.py | 2 +- .../data/python/kernel_tests/optimization/BUILD | 35 ++++- .../optimization/assert_next_dataset_op_test.py | 64 +++++++++ .../optimization/optimize_dataset_op_test.py | 108 ++++++++++++++++ .../kernel_tests/optimize_dataset_op_test.py | 143 --------------------- 6 files changed, 204 insertions(+), 163 deletions(-) create mode 100644 tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py create mode 100644 tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py delete mode 100644 tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index b9320e5fef..6f0111a2bd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -285,21 +285,6 @@ py_test( ], ) -py_test( - name = "optimize_dataset_op_test", - size = "small", - srcs = ["optimize_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - py_test( name = "parsing_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dc9d56dd53..55c9ac68dd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark): end = time.time() chained_deltas.append(end - start) - fused_dataset = dataset = 
dataset.apply( + fused_dataset = dataset.apply( batching.map_and_batch( math_ops.matmul, num_parallel_calls=num_calls, diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b299e0736f..459bdf66f3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -6,6 +6,34 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") +py_test( + name = "assert_next_dataset_op_test", + size = "medium", + srcs = ["assert_next_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "map_vectorization_test", size = "small", @@ -46,16 +74,15 @@ py_test( ) py_test( - name = "latency_all_edges_test", + name = "optimize_dataset_op_test", size = "small", - srcs = ["latency_all_edges_test.py"], + srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git 
a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py new file mode 100644 index 0000000000..bd7b50b902 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class AssertNextDatasetTest(test.TestCase): + + def testAssertNext(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertNextInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = 
iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertNextShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py new file mode 100644 index 0000000000..909da5aee0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -0,0 +1,108 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testOptimizationDefault(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationEmpty(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationFusion(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + 
with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationStatefulFunction(self): + dataset = dataset_ops.Dataset.range(10).map( + lambda _: random_ops.random_uniform([])).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(get_next) + + def testOptimizationLargeInputFromTensor(self): + input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) + dataset = dataset_ops.Dataset.from_tensors(input_t).apply( + optimization.optimize()) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) + sess.run(get_next) + + def testOptimizationLargeInputFromTensorSlices(self): + input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) + dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply( + optimization.optimize()) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py deleted file mode 100644 index 089717156c..0000000000 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.data.python.ops import optimization -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.platform import test - - -class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): - - def testAssertSuffix(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - - def testAssertSuffixInvalid(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead."): - sess.run(get_next) - - def testAssertSuffixShort(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - 
optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted next 2 transformations but encountered only 1."): - sess.run(get_next) - - def testOptimizationDefault(self): - dataset = dataset_ops.Dataset.range(10).apply( - optimization.assert_next( - ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( - optimization.optimize()) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOptimizationEmpty(self): - dataset = dataset_ops.Dataset.range(10).apply( - optimization.assert_next( - ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( - optimization.optimize([])) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOptimizationFusion(self): - dataset = dataset_ops.Dataset.range(10).apply( - optimization.assert_next( - ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( - optimization.optimize(["map_and_batch_fusion"])) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOptimizationStatefulFunction(self): - dataset = dataset_ops.Dataset.range(10).map( - lambda _: random_ops.random_uniform([])).batch(10).apply( - optimization.optimize(["map_and_batch_fusion"])) - iterator = dataset.make_one_shot_iterator() - get_next = 
iterator.get_next() - - with self.test_session() as sess: - sess.run(get_next) - - def testOptimizationLargeInputFromTensor(self): - input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) - dataset = dataset_ops.Dataset.from_tensors(input_t).apply( - optimization.optimize()) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) - sess.run(get_next) - - def testOptimizationLargeInputFromTensorSlices(self): - input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) - dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply( - optimization.optimize()) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) - sess.run(get_next) - - -if __name__ == "__main__": - test.main() -- cgit v1.2.3 From a6bb25c05c15e39d04baf6dac30200db367e1ef2 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Sat, 8 Sep 2018 09:23:24 -0700 Subject: Make scheduling and rematerialization HLO passes. Now that HloSchedule is a field on the HLO module, scheduling can be done as an HLO pass. Similarly, rematerialization which requires a schedule can also be a pass which just gets the schedule from the module. Also as a clean up, hoist calls to CopyInsertion out of rematerialization. 
PiperOrigin-RevId: 212119795 --- tensorflow/compiler/xla/service/BUILD | 24 +- .../compiler/xla/service/buffer_assignment.cc | 1 - .../compiler/xla/service/buffer_assignment_test.cc | 2 +- tensorflow/compiler/xla/service/cpu/BUILD | 2 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- .../compiler/xla/service/gpu/gpu_hlo_schedule.cc | 2 +- .../compiler/xla/service/hlo_memory_scheduler.cc | 603 +++++++++++++++++++++ .../compiler/xla/service/hlo_memory_scheduler.h | 123 +++++ .../xla/service/hlo_memory_scheduler_test.cc | 432 +++++++++++++++ .../compiler/xla/service/hlo_ordering_test.cc | 1 - .../compiler/xla/service/hlo_rematerialization.cc | 88 +-- .../compiler/xla/service/hlo_rematerialization.h | 83 ++- .../xla/service/hlo_rematerialization_test.cc | 75 ++- .../compiler/xla/service/hlo_schedule_test.cc | 2 +- tensorflow/compiler/xla/service/hlo_scheduling.cc | 585 -------------------- tensorflow/compiler/xla/service/hlo_scheduling.h | 91 ---- .../compiler/xla/service/hlo_scheduling_test.cc | 420 -------------- tensorflow/compiler/xla/service/hlo_verifier.cc | 5 + 19 files changed, 1272 insertions(+), 1271 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_memory_scheduler.cc create mode 100644 tensorflow/compiler/xla/service/hlo_memory_scheduler.h create mode 100644 tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc delete mode 100644 tensorflow/compiler/xla/service/hlo_scheduling.cc delete mode 100644 tensorflow/compiler/xla/service/hlo_scheduling.h delete mode 100644 tensorflow/compiler/xla/service/hlo_scheduling_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e784663ff6..6ace6d3271 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1012,8 +1012,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", 
":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1041,8 +1041,8 @@ tf_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1088,8 +1088,8 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1185,9 +1185,9 @@ tf_cc_test( ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1199,13 +1199,14 @@ tf_cc_test( ) cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = ["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1219,15 +1220,15 @@ cc_library( ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -2394,12 +2395,11 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", 
"//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f0af57626..65fa951afe 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 5a231c173d..c30abd1d3e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -30,11 +30,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2368ac8c6a..039cbbff6c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -122,7 +122,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7b6075994..18fc144efe 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,12 +77,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6791e15ee0..569381f5b0 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,9 +813,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index ea9376e101..02a0d028c1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,9 +21,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc new file mode 100644 index 0000000000..c7ec88d450 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -0,0 +1,603 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace { + +using ::tensorflow::strings::HumanReadableNumBytes; + +// Class implementing a list scheduler of HLO instructions which produces a +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. As a counterexample, consider the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to schedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). +// +class ListScheduler { + public: + // Construct and return a memory-minimizing sequence of HLO instructions + // containing the given HLO computation. 
+ static StatusOr Run( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); + return scheduler.CreateSchedule(); + } + + // Returns whether the memory used by the given HLO should be ignored by the + // scheduling heuristic. + static bool IgnoreInstruction(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kParameter || + instruction.opcode() == HloOpcode::kConstant; + } + + private: + // The scheduling priority of an instruction is first the number of bytes + // freed by scheduling the instruction, and second (tie-breaker) by the number + // of users. This is represented as a std::pair containing these two values + // (first element is the bytes freed). std::pair provides the necessary + // comparison operators. + using Priority = std::pair; + + ListScheduler(const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) + : computation_(computation), + points_to_analysis_(points_to_analysis), + size_function_(size_function), + memory_by_computation_(memory_by_computation) { + // Create a map containing the LogicalBuffer uses for each HLO + // instruction. An HLO instruction "uses" a LogicalBuffer if the + // LogicalBuffer is in an operand of the instruction as indicated by + // points-to analysis. 
+ for (auto* instruction : computation.instructions()) { + tensorflow::gtl::FlatSet instr_uses; + for (auto* operand : instruction->operands()) { + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + instr_uses.insert(buffers.begin(), buffers.end()); + }); + } + buffer_uses_[instruction] = std::vector( + instr_uses.begin(), instr_uses.end()); + } + + // Create map containing the number of unscheduled uses (hlo instructions) + // of each logical buffer. + for (auto* instruction : computation.instructions()) { + for (auto* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { + unscheduled_use_count_[buffer] = 0; + } + } + for (auto* instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { + ++unscheduled_use_count_[buffer]; + } + } + + // Buffers live out of the computation have an implicit use at the end of + // the computation. + for (const LogicalBuffer* live_out_buffer : + points_to_analysis.GetPointsToSet(computation.root_instruction()) + .CreateFlattenedSet()) { + ++unscheduled_use_count_[live_out_buffer]; + } + } + + // Returns whether the memory used by the given buffer should be ignored by + // the scheduling heuristic. + static bool IgnoreBuffer(const LogicalBuffer& buffer) { + return IgnoreInstruction(*buffer.instruction()); + } + + // An entry in the worklist used by CreateSchedule. Corresponds to one + // HloInstruction, plus some cached metadata, saved for the purposes of making + // BytesFreedIfScheduled fast. + struct ReadyListEntry { + const HloInstruction* instruction; + + // The total size of all buffers defined by this instruction. + int64 bytes_defined; + + // For each buffer B used by this instruction, we keep a pair (B, U), where + // U is the number of uses of B that have not yet been scheduled. 
This pair + // is a pointer into the unscheduled_use_count_ map, so it gets updated for + // free when we update counts in the map. + std::vector*> + used_buffer_unscheduled_use_counts; + }; + + // Creates a ReadyListEntry for the given instruction. + ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry entry; + entry.instruction = instruction; + + entry.bytes_defined = 0; + for (auto* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (!IgnoreBuffer(*buffer)) { + entry.bytes_defined += size_function_(*buffer); + } + } + + for (auto* buffer : buffer_uses_.at(instruction)) { + if (IgnoreBuffer(*buffer)) { + continue; + } + auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); + CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); + entry.used_buffer_unscheduled_use_counts.push_back( + &*unscheduled_use_count_it); + } + return entry; + } + + // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). + int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { + int64 freed_bytes = 0; + for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { + auto buffer = kv->first; + auto use_count = kv->second; + if (use_count == 1) { + freed_bytes += size_function_(*buffer); + } + } + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. 
+ int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + } + + // Constructs the scheduling priority of the given instruction. + Priority GetPriority(const ReadyListEntry& entry) { + return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; + } + + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; + + // Populate the ready list with instructions which have no operands or + // control predecessors. + tensorflow::gtl::FlatMap + unscheduled_pred_count; + for (auto* instruction : computation_.instructions()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + + // Use a multimap to sort ReadyListEntry according to their priority. + std::multimap ready_queue; + + // Map of ready instructions to their iterators in ready_queue. + tensorflow::gtl::FlatMap::iterator> + ready_instructions; + + auto add_to_ready_queue = [&](HloInstruction* inst) { + auto entry = MakeReadyListEntry(inst); + auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions[inst] = it; + }; + + for (auto* instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. 
+ if (unscheduled_pred_count.count(instruction) == 0) { + add_to_ready_queue(instruction); + } + } + + while (!ready_queue.empty()) { + // Remove the selected instruction from the ready list and add it to the + // schedule. + auto best_it = ready_queue.end(); + --best_it; + const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; + ready_queue.erase(best_it); + ready_instructions.erase(best); + schedule.push_back(best); + scheduled_instructions_.insert(best); + + bool adjust_ready_queue = false; + // Update the unscheduled uses of the logical buffers. + for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { + int64& count = unscheduled_use_count_[buffer]; + CHECK_GT(count, 0); + --count; + if (count == 1) { + adjust_ready_queue = true; + } + } + + // Add new instructions to ready list. + auto update_pred_count = [&](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + add_to_ready_queue(inst); + } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); + } + // The unscheduled use count for a buffer has changed to 1, so the + // priorities of some ready instructions may go up. We update them in the + // ready queue, so that they can appear earlier. 
+ if (adjust_ready_queue) { + for (HloInstruction* operand : best->operands()) { + for (HloInstruction* operand_user : operand->users()) { + auto ready_instructions_it = ready_instructions.find(operand_user); + if (ready_instructions_it == ready_instructions.end()) { + continue; + } + auto ready_queue_it = ready_instructions_it->second; + auto& entry = ready_queue_it->second; + Priority new_priority = GetPriority(entry); + if (new_priority == ready_queue_it->first) { + continue; + } + // Create a new entry in ready_queue, then update + // ready_instructions[operand_user] to refer to the new entry. + ready_instructions_it->second = + ready_queue.emplace(new_priority, std::move(entry)); + // Remove the old entry in ready_queue. + ready_queue.erase(ready_queue_it); + } + } + } + } + CHECK_EQ(schedule.size(), computation_.instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); + + return schedule; + } + + const HloComputation& computation_; + const TuplePointsToAnalysis& points_to_analysis_; + const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. + const tensorflow::gtl::FlatMap& + memory_by_computation_; + + // A map containing the LogicalBuffers that each instruction uses. + tensorflow::gtl::FlatMap> + buffer_uses_; + + // A map containing the count of unscheduled HLOs which using a particular + // LogicalBuffer. We rely on iterator stability in this map, and that the map + // entries are std::pair's. + std::unordered_map unscheduled_use_count_; + + // Set of instructions which have been scheduled. 
+ tensorflow::gtl::FlatSet scheduled_instructions_; +}; + +int64 SumLogicalBufferSizes( + const TuplePointsToAnalysis::BufferDefinitionVector& buffers, + const LogicalBuffer::SizeFunction& size_function) { + int64 size = 0; + for (const LogicalBuffer* buffer : buffers) { + size += size_function(*buffer); + } + return size; +} + +StatusOr ScheduleComputationHelper( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + +StatusOr DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // These variables are a hack to prevent overflows. + int64 cumulative_total_size = 0; + int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); + tensorflow::gtl::FlatMap extra_users; + tensorflow::gtl::FlatMap total_sizes; + for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + if (ListScheduler::IgnoreInstruction(*hlo)) { + extra_users[hlo] = 0; + total_sizes[hlo] = 0; + continue; + } + // This ordering is based on DFS post-order, with a heuristic to decide + // which operand to visit first. The heuristic is based on 'extra_users', + // which is simply users-1 for each instruction. By subtracting 1, we're + // saying that instructions with no users or a single user don't count; + // instructions with lots of fan-out will be visited earlier. + extra_users[hlo] = hlo->users().empty() ? 
0 : hlo->users().size() - 1; + int64 logical_buffer_size = SumLogicalBufferSizes( + points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + total_sizes[hlo] = logical_buffer_size; + cumulative_total_size += logical_buffer_size; + tensorflow::gtl::FlatSet unique_operands( + hlo->operands().begin(), hlo->operands().end()); + for (const HloInstruction* operand : unique_operands) { + extra_users[hlo] += extra_users[operand]; + total_sizes[hlo] += total_sizes[operand]; + } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // Same for total_hlos: it prevents overflows on very large and branchy + // models, where the number of paths is exponential to the number of nodes. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". + total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); + extra_users[hlo] = std::min(extra_users[hlo], total_hlos); + } + CHECK_EQ(extra_users.size(), computation.instruction_count()); + CHECK_EQ(total_sizes.size(), computation.instruction_count()); + + // Construct a total order based on DFS post-order, visiting operands in + // decreasing cumulative extra user order, and next by cumulative size, with a + // tiebreaker by name for determinism. 
+ HloInstructionSequence sequence; + FunctionVisitor visitor([&sequence](HloInstruction* hlo) { + sequence.push_back(hlo); + return Status::OK(); + }); + TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); + })); + CHECK_EQ(sequence.size(), computation.instruction_count()); + return sequence; +} // namespace xla + +StatusOr ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); +} + +StatusOr PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return HloInstructionSequence(computation.MakeInstructionPostOrder()); +} + +StatusOr DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. 
+ TF_ASSIGN_OR_RETURN( + HloInstructionSequence list_sequence, + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN(const int64 list_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, list_sequence, points_to_analysis, + size_function, &memory_by_computation)); + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); + + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); + TF_ASSIGN_OR_RETURN(const int64 dfs_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, dfs_sequence, points_to_analysis, + size_function, &memory_by_computation)); + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); + + TF_ASSIGN_OR_RETURN( + HloInstructionSequence post_order_sequence, + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN(const int64 post_order_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, post_order_sequence, points_to_analysis, + size_function, &memory_by_computation)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + + if (min_memory == list_memory) { + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); + return list_sequence; + } else if (min_memory == dfs_memory) { + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); + return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; + } +} + +StatusOr ScheduleModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { + 
HloSchedule schedule(&module); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(&module)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, + ScheduleComputationHelper( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + HeapSimulator::MinimumMemoryForComputation( + *computation, computation_sequence, *points_to_analysis, + size_function, &memory_by_computation) + .ValueOrDie(); + schedule.set_sequence(computation, std::move(computation_sequence)); + } + } + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); +} + +StatusOr ScheduleComputation( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function) { + CHECK(!computation.IsFusionComputation()); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(computation.parent())); + tensorflow::gtl::FlatMap empty_map; + return ScheduleComputationHelper(computation, *points_to_analysis, + size_function, nullptr, empty_map); +} + +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h 
b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h new file mode 100644 index 0000000000..5e02868eba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -0,0 +1,123 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. 
+typedef std::function( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> + MemorySchedulerAlgorithm; + +// List scheduler +StatusOr ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// DFS-order scheduler +StatusOr DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// The default scheduling algorithm. Runs the list, DFS, and post-order +// schedulers, and chooses whichever returns the lowest min-memory, +// not accounting for fragmentation. +StatusOr DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Returns an HloSchedule which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. +StatusOr ScheduleModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + +// Computes the schedule for a single computation. +// Currently only used by the GPU backend. 
+StatusOr ScheduleComputation( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); + +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloPassInterface { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModule::has_schedule will return false. +class HloDescheduler : public HloPassInterface { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc new file mode 100644 index 0000000000..1b9e9bfc77 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. 
+ const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + + // Verify that all instructions are in the sequence. + const std::vector& sequence = + module->schedule().sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); + + SequentialHloOrdering ordering(module->schedule()); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. 
+ HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : sequence) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. 
+ SequentialHloOrdering ordering(schedule); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %ListAccountsForSubcomputations () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto 
cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. 
+ auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. + EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. 
+ auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + absl::Span({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. 
+ EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. 
+ EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the 
sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 6b6005e7a5..00970bcda3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 0a0a6a323e..bd6dd79b67 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -27,15 +27,14 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -1194,51 +1193,12 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run(HloModule* module, - HloSchedule* schedule, - int64 memory_limit_bytes, - RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The schedule is constructed entirely by this method. 
- TF_RET_CHECK(schedule->empty()); - +StatusOr HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial schedule of HLO instructions. - TF_ASSIGN_OR_RETURN(*schedule, - ScheduleModule(*module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(*schedule); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - - // RemoveUnnecessaryCopies only considers interference when determining - // whether it is legal to remove a copy. However, copies in the graph may be - // necessary for other reason such as preventing a constant from being live - // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. - // TODO(b/80249101): Break copy insertion into several passes and run each - // one once in the regular HLO pipeline. - TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); - - // The passes above can add and remove copies, update the schedule to - // account for these transformations. Newly added instructions will be - // placed ASAP in the schedule. 
- TF_RETURN_IF_ERROR(schedule->Update()); - - TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(*schedule), module)); - } - + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1254,7 +1214,7 @@ StatusOr HloRematerialization::Run(HloModule* module, }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -1263,13 +1223,14 @@ StatusOr HloRematerialization::Run(HloModule* module, // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, schedule](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory( - node.computation(), - schedule->sequence(node.computation()).instructions())); + ComputePeakMemory(node.computation(), + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1287,9 +1248,10 @@ StatusOr HloRematerialization::Run(HloModule* module, // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), schedule, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. 
@@ -1298,7 +1260,7 @@ StatusOr HloRematerialization::Run(HloModule* module, // After DCE, the module sequence may include instructions which no longer // exist. - TF_RETURN_IF_ERROR(schedule->Update()); + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1315,32 +1277,22 @@ StatusOr HloRematerialization::Run(HloModule* module, << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { + if (current_peak_memory > memory_limit_bytes_) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use below %s (%d bytes) by rematerialization; " "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 
fa0414b472..e2aaf18b3e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,17 +17,23 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloPassInterface { public: using ShapeSizeFunction = std::function; @@ -38,10 +44,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -49,51 +52,27 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. 
// - // hlo_module: HLO module to rematerialize instructions in. - // - // schedule: Should point to an empty HloSchedule. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. - // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - HloSchedule* schedule, RematerializationSizes* sizes, - CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Value are set during + // Run(). Can be nullptr. 
+ HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. schedule should be an empty HloSchedule. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, HloSchedule* schedule, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation @@ -121,6 +100,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. 
std::unique_ptr call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83cb113bfb..4b611fe450 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase { } StatusOr RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module, - HloSchedule* schedule) { + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - schedule, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module.get())); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. 
- EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(schedule.sequence(computation) + EXPECT_EQ(module->schedule() + .sequence(computation) .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module.get())); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module.get())); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module.get())); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - HloSchedule schedule(module.get()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module.get())); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &schedule)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); EXPECT_TRUE(changed); // The rng should not have been rematerialized. 
EXPECT_EQ(count_rngs(entry_computation), 1); @@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module.get())); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module.get())); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index eb52582bb5..1424569ac1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -22,10 +22,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc deleted file mode 100644 index 9bfb0af96c..0000000000 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ /dev/null @@ -1,585 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/heap_simulator.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace { - -using ::tensorflow::strings::HumanReadableNumBytes; - -// Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage by preferring to schedule the node that -// frees bigger buffer and defines smaller outputs. -// -// Note that list scheduler is a greedy algorithm which cannot guarantee a -// global optimal solution. As a counterexample, considering the following -// graph: -// -// +--> B ===> C -------+ -// A -> | | -// | v -// +--> D ---> F=======>G -// | ^ -// | | -// +--> E -----+ -// -// --> : Buffer with size 1 -// ==> : Buffer with size 2 -// -// The list scheduler will always try to defer scheduling B in a greedy way -// since its output buffer is bigger than input. The sequence it creates will -// be: -// A D E F B C G -// , which has a maximum memory usage of 6 (B is alive while F is executing). -// -// An optimal way to shedule the previous graph is: -// A B C D E F G -// , which has a maximum memory usage of 5 (when F is executing). -// -class ListScheduler { - public: - // Construct and return a memory-minimizing sequence of HLO instructions - // containing the given HLO computation. 
- static StatusOr Run( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - ListScheduler scheduler(computation, points_to_analysis, size_function, - memory_by_computation); - return scheduler.CreateSchedule(); - } - - // Returns whether the memory used by the given HLO should be ignored by the - // scheduling heuristic. - static bool IgnoreInstruction(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kParameter || - instruction.opcode() == HloOpcode::kConstant; - } - - private: - // The scheduling priority of an instruction is first the number of bytes - // freed by scheduling the instruction, and second (tie-breaker) by the number - // of users. This is represented as a std::pair containing these two values - // (first element is the bytes freed). std::pair provides the necessary - // comparison operators. - using Priority = std::pair; - - ListScheduler(const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) - : computation_(computation), - points_to_analysis_(points_to_analysis), - size_function_(size_function), - memory_by_computation_(memory_by_computation) { - // Create a map containing the LogicalBuffer uses for each HLO - // instruction. An HLO instruction "uses" a LogicalBuffer if the - // LogicalBuffer is in an operand of the instruction as indicated by - // points-to analysis. 
- for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet instr_uses; - for (auto* operand : instruction->operands()) { - points_to_analysis.GetPointsToSet(operand).ForEachElement( - [&](const ShapeIndex& /*index*/, - const PointsToSet::BufferList& buffers) { - instr_uses.insert(buffers.begin(), buffers.end()); - }); - } - buffer_uses_[instruction] = std::vector( - instr_uses.begin(), instr_uses.end()); - } - - // Create map containing the number of unscheduled uses (hlo instructions) - // of each logical buffer. - for (auto* instruction : computation.instructions()) { - for (auto* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { - unscheduled_use_count_[buffer] = 0; - } - } - for (auto* instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { - ++unscheduled_use_count_[buffer]; - } - } - - // Buffers live out of the computation have an implicit use at the end of - // the computation. - for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) - .CreateFlattenedSet()) { - ++unscheduled_use_count_[live_out_buffer]; - } - } - - // Returns whether the memory used by the given buffer should be ignored by - // the scheduling heuristic. - static bool IgnoreBuffer(const LogicalBuffer& buffer) { - return IgnoreInstruction(*buffer.instruction()); - } - - // An entry in the worklist used by CreateSchedule. Corresponds to one - // HloInstruction, plus some cached metadata, saved for the purposes of making - // BytesFreedIfScheduled fast. - struct ReadyListEntry { - const HloInstruction* instruction; - - // The total size of all buffers defined by this instruction. - int64 bytes_defined; - - // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. 
This pair - // is a pointer into the unscheduled_use_count_ map, so it gets updated for - // free when we update counts in the map. - std::vector*> - used_buffer_unscheduled_use_counts; - }; - - // Creates a ReadyListEntry for the given instruction. - ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { - ReadyListEntry entry; - entry.instruction = instruction; - - entry.bytes_defined = 0; - for (auto* buffer : - points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { - if (!IgnoreBuffer(*buffer)) { - entry.bytes_defined += size_function_(*buffer); - } - } - - for (auto* buffer : buffer_uses_.at(instruction)) { - if (IgnoreBuffer(*buffer)) { - continue; - } - auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); - CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); - entry.used_buffer_unscheduled_use_counts.push_back( - &*unscheduled_use_count_it); - } - return entry; - } - - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). - int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { - int64 freed_bytes = 0; - for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { - auto buffer = kv->first; - auto use_count = kv->second; - if (use_count == 1) { - freed_bytes += size_function_(*buffer); - } - } - // We only count the memory usage of the largest subcomputation, instead of - // adding them all, because subcomputations won't execute in parallel. 
- int64 max_subcomputation_bytes = 0; - for (const auto* c : entry.instruction->called_computations()) { - auto it = memory_by_computation_.find(c); - if (it != memory_by_computation_.end()) { - int64 subcomputation_bytes = it->second; - if (subcomputation_bytes > max_subcomputation_bytes) { - max_subcomputation_bytes = subcomputation_bytes; - } - } - } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; - } - - // Constructs the scheduling priority of the given instruction. - Priority GetPriority(const ReadyListEntry& entry) { - return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; - } - - HloInstructionSequence CreateSchedule() { - HloInstructionSequence schedule; - - // Populate the ready list with instructions which have no operands or - // control predecessors. - tensorflow::gtl::FlatMap - unscheduled_pred_count; - for (auto* instruction : computation_.instructions()) { - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { - unscheduled_pred_count[user]++; - } - for (const HloInstruction* succ : instruction->control_successors()) { - unscheduled_pred_count[succ]++; - } - } - - // Use a multimap to sort ReadyListEntry according to their priority. - std::multimap ready_queue; - - // Map of ready instructions to their iterators in ready_queue. - tensorflow::gtl::FlatMap::iterator> - ready_instructions; - - auto add_to_ready_queue = [&](HloInstruction* inst) { - auto entry = MakeReadyListEntry(inst); - auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); - ready_instructions[inst] = it; - }; - - for (auto* instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. 
- if (unscheduled_pred_count.count(instruction) == 0) { - add_to_ready_queue(instruction); - } - } - - while (!ready_queue.empty()) { - // Remove the selected instruction from the ready list and add it to the - // schedule. - auto best_it = ready_queue.end(); - --best_it; - const HloInstruction* best = best_it->second.instruction; - VLOG(2) << "Schedule instruction: " << best->ToShortString() - << " Bytes freed: " << best_it->first.first; - ready_queue.erase(best_it); - ready_instructions.erase(best); - schedule.push_back(best); - scheduled_instructions_.insert(best); - - bool adjust_ready_queue = false; - // Update the unscheduled uses of the logical buffers. - for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - int64& count = unscheduled_use_count_[buffer]; - CHECK_GT(count, 0); - --count; - if (count == 1) { - adjust_ready_queue = true; - } - } - - // Add new instructions to ready list. - auto update_pred_count = [&](HloInstruction* inst) { - int64 pred_count = --unscheduled_pred_count.at(inst); - CHECK_GE(pred_count, 0); - if (pred_count == 0) { - add_to_ready_queue(inst); - } - }; - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (HloInstruction* user : best->users()) { - update_pred_count(user); - } - for (HloInstruction* succ : best->control_successors()) { - update_pred_count(succ); - } - // The unscheduled use count for a buffer has changed to 1, so the - // priorities of some ready instructions may go up. We update them in the - // ready queue, so that they can appear earlier. 
- if (adjust_ready_queue) { - for (HloInstruction* operand : best->operands()) { - for (HloInstruction* operand_user : operand->users()) { - auto ready_instructions_it = ready_instructions.find(operand_user); - if (ready_instructions_it == ready_instructions.end()) { - continue; - } - auto ready_queue_it = ready_instructions_it->second; - auto& entry = ready_queue_it->second; - Priority new_priority = GetPriority(entry); - if (new_priority == ready_queue_it->first) { - continue; - } - // Create a new entry in ready_queue, then update - // ready_instructions[operand_user] to refer to the new entry. - ready_instructions_it->second = - ready_queue.emplace(new_priority, std::move(entry)); - // Remove the old entry in ready_queue. - ready_queue.erase(ready_queue_it); - } - } - } - } - CHECK_EQ(schedule.size(), computation_.instruction_count()); - CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); - - return schedule; - } - - const HloComputation& computation_; - const TuplePointsToAnalysis& points_to_analysis_; - const LogicalBuffer::SizeFunction& size_function_; - // Computations are analyzed in post-order. When scheduling an instruction - // that includes subcomputations, such as a while loop, we use this map to - // look up the memory needed by subcomputations. - const tensorflow::gtl::FlatMap& - memory_by_computation_; - - // A map containing the LogicalBuffers that each instruction uses. - tensorflow::gtl::FlatMap> - buffer_uses_; - - // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map, and that the map - // entries are std::pair's. - std::unordered_map unscheduled_use_count_; - - // Set of instructions which have been scheduled. 
- tensorflow::gtl::FlatSet scheduled_instructions_; -}; - -int64 SumLogicalBufferSizes( - const TuplePointsToAnalysis::BufferDefinitionVector& buffers, - const LogicalBuffer::SizeFunction& size_function) { - int64 size = 0; - for (const LogicalBuffer* buffer : buffers) { - size += size_function(*buffer); - } - return size; -} - -StatusOr ScheduleComputationHelper( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm) { - return algorithm(computation, points_to_analysis, size_function, - memory_by_computation); - } - return DefaultMemoryScheduler(computation, points_to_analysis, size_function, - memory_by_computation); -} - -} // namespace - -StatusOr DFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - // These variables are a hack to prevent overflows. - int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); - tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { - if (ListScheduler::IgnoreInstruction(*hlo)) { - extra_users[hlo] = 0; - total_sizes[hlo] = 0; - continue; - } - // This ordering is based on DFS post-order, with a heuristic to decide - // which operand to visit first. The heuristic is based on 'extra_users', - // which is simply users-1 for each instruction. By subtracting 1, we're - // saying that instructions with no users or a single user don't count; - // instructions with lots of fan-out will be visited earlier. - extra_users[hlo] = hlo->users().empty() ? 
0 : hlo->users().size() - 1; - int64 logical_buffer_size = SumLogicalBufferSizes( - points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); - total_sizes[hlo] = logical_buffer_size; - cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet unique_operands( - hlo->operands().begin(), hlo->operands().end()); - for (const HloInstruction* operand : unique_operands) { - extra_users[hlo] += extra_users[operand]; - total_sizes[hlo] += total_sizes[operand]; - } - // total_sizes[hlo] transitively includes the sizes of all nodes that - // lead to it. But computation is a DAG, so we are double-counting nodes, - // which can lead to overflows for large programs. - // cumulative_total_size caps the size to prevent overflows. - // Same for total_hlos: it prevents overflows on very large and branchy - // models, where the number of paths is exponential to the number of nodes. - // NOTE(dimvar): this is quite ugly and should be changed. It's unclear - // why we care about transitive sizes; when scheduling a node, its input - // and output buffers should be all that matters, not its "history". - total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); - extra_users[hlo] = std::min(extra_users[hlo], total_hlos); - } - CHECK_EQ(extra_users.size(), computation.instruction_count()); - CHECK_EQ(total_sizes.size(), computation.instruction_count()); - - // Construct a total order based on DFS post-order, visiting operands in - // decreasing cumulative extra user order, and next by cumulative size, with a - // tiebreaker by name for determinism. 
- HloInstructionSequence sequence; - FunctionVisitor visitor([&sequence](HloInstruction* hlo) { - sequence.push_back(hlo); - return Status::OK(); - }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( - &visitor, [&extra_users, &total_sizes](const HloInstruction* a, - const HloInstruction* b) { - if (extra_users[a] != extra_users[b]) { - return extra_users[a] > extra_users[b]; - } - if (total_sizes[a] != total_sizes[b]) { - return total_sizes[a] > total_sizes[b]; - } - return a->name() < b->name(); - })); - CHECK_EQ(sequence.size(), computation.instruction_count()); - return sequence; -} // namespace xla - -StatusOr ListMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - return ListScheduler::Run(computation, points_to_analysis, size_function, - memory_by_computation); -} - -StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); -} - -StatusOr DefaultMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation) { - // We try a few schedulers and choose whichever returns a lower min-memory, - // not accounting for fragmentation. - // - List is a scheduler that uses greedy heuristics. - // - DFS visits HLOs in postorder, with a heuristic to decide the order of - // children. - // - Postorder does not use any heuristics. - // List wins for most of our benchmarks; postorder-based schedulers win for - // some RNNs. 
- TF_ASSIGN_OR_RETURN( - HloInstructionSequence list_sequence, - ListMemoryScheduler(computation, points_to_analysis, size_function, - memory_by_computation)); - TF_ASSIGN_OR_RETURN(const int64 list_memory, - HeapSimulator::MinimumMemoryForComputation( - computation, list_sequence, points_to_analysis, - size_function, &memory_by_computation)); - VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - - TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, - DFSMemoryScheduler(computation, points_to_analysis, - size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN(const int64 dfs_memory, - HeapSimulator::MinimumMemoryForComputation( - computation, dfs_sequence, points_to_analysis, - size_function, &memory_by_computation)); - VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - - TF_ASSIGN_OR_RETURN( - HloInstructionSequence post_order_sequence, - PostOrderMemoryScheduler(computation, points_to_analysis, size_function, - memory_by_computation)); - TF_ASSIGN_OR_RETURN(const int64 post_order_memory, - HeapSimulator::MinimumMemoryForComputation( - computation, post_order_sequence, points_to_analysis, - size_function, &memory_by_computation)); - VLOG(2) << "Min-memory post order sequence: " - << HumanReadableNumBytes(post_order_memory); - - auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); - - if (min_memory == list_memory) { - VLOG(2) << "Chose min-memory list sequence: " - << HumanReadableNumBytes(list_memory); - return list_sequence; - } else if (min_memory == dfs_memory) { - VLOG(2) << "Chose min-memory dfs sequence: " - << HumanReadableNumBytes(dfs_memory); - return dfs_sequence; - } else { - VLOG(2) << "Chose min-memory post_order sequence: " - << HumanReadableNumBytes(post_order_memory); - return post_order_sequence; - } -} - -StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { - 
HloSchedule schedule(&module); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); - tensorflow::gtl::FlatMap memory_by_computation; - for (const auto* computation : module.MakeComputationPostOrder()) { - if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, - ScheduleComputationHelper( - *computation, *points_to_analysis, size_function, - algorithm, memory_by_computation)); - memory_by_computation[computation] = - HeapSimulator::MinimumMemoryForComputation( - *computation, computation_sequence, *points_to_analysis, - size_function, &memory_by_computation) - .ValueOrDie(); - schedule.set_sequence(computation, std::move(computation_sequence)); - } - } - VLOG(1) << "Module schedule:\n" << schedule; - - TF_RETURN_IF_ERROR(schedule.Verify()); - - return std::move(schedule); -} - -StatusOr ScheduleComputation( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { - CHECK(!computation.IsFusionComputation()); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); - tensorflow::gtl::FlatMap empty_map; - return ScheduleComputationHelper(computation, *points_to_analysis, - size_function, nullptr, empty_map); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h deleted file mode 100644 index 54e32340ba..0000000000 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ - -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// A memory scheduler computes an execution sequence for the HLO instructions in -// 'computation' that minimizes peak memory, given a points-to analysis result -// that describes buffer aliasing, together with a target-specific size function -// that maps a tensor's logical size to its padded size. 
-typedef std::function( - const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&, - const tensorflow::gtl::FlatMap&)> - MemorySchedulerAlgorithm; - -// List scheduler -StatusOr ListMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// DFS-order scheduler -StatusOr DFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// Naive Post Order scheduler -StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// The default scheduling algorithm. Runs both the list scheduler -// and the DFS scheduler, and chooses whichever returns a lower min-memory, -// not accounting for fragmentation. -StatusOr DefaultMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// Returns an HloSchedule which seeks to minimize the memory required for -// the computation. size_function is the function returning the number of bytes -// required for a LogicalBuffer. -StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); - -// Computes the schedule for a single computation. -// Currently only used by the GPU backend. 
-StatusOr ScheduleComputation( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc deleted file mode 100644 index 6afe51997e..0000000000 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" - -#include -#include - -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/xla/service/heap_simulator.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace xla { -namespace { - -class HloSchedulingTest : public HloTestBase {}; - -TEST_F(HloSchedulingTest, LastUseScheduledFirst) { - // Tests scheduling of the following HLO code: - // - // %ab = abs(%param) - // %exp = exp(%param) - // %add = add(%ab, %exp) - // %negate = negate(%exp) - // %sub = subtract(%add, %negate) - // - // %add should be scheduled before %negate because %add is the last (and only) - // use of %ab. Scheduling %add first then frees up %ab's buffer. 
- const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); - auto ab = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); - auto sub = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - const std::vector& sequence = - schedule.sequence(module->entry_computation()).instructions(); - EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); - - // The first instruction should be the parameter and the last the root "sub". 
- EXPECT_EQ(param, sequence.front()); - EXPECT_EQ(sub, sequence.back()); - - SequentialHloOrdering ordering(schedule); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - -TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { - const char* module_str = R"( -HloModule test_aliasing_module - -ENTRY root { - param = s32[1000] parameter(0) - p0 = s32[1000] copy(param) - p1 = s32[1000] copy(param) - t = (s32[1000], s32[1000]) tuple(p0, p1) - a = s32[1000] get-tuple-element(t), index=0 - b = s32[1000] get-tuple-element(t), index=1 - c = s32[1000] add(a, b) - d = s32[1000] add(c, b) - e = s32[1000] add(c, c) - f = s32[1000] add(e, e) - ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. - const std::vector& sequence = - schedule.sequence(module->entry_computation()).instructions(); - EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); - - std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : sequence) { - instructions_by_name[instruction->name()] = instruction; - } - - // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); - EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); - - // Instructions "d" and "e" will both be schedulable at the same time, but - // instruction "d" allows us to free the buffer of "p1", so the list scheduler - // should prefer it. 
- SequentialHloOrdering ordering(schedule); - EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), - instructions_by_name.at("e"))); -} - -TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { - // %WhileCond (cond_param: f32[4]) -> pred[] { - // %cond_param = f32[4]{0} parameter(0) - // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) - // ROOT %not-equal-to = pred[] not-equal-to( - // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) - // } - // %WhileBody (body_param: f32[4]) -> f32[4] { - // %body_param = f32[4]{0} parameter(0) - // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // ROOT %subtract = f32[4]{0} subtract( - // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) - // } - // %ListAccountsForSubcomputations () -> f32[2,4] { - // %constant.3 = f32[2,4]{1,0} constant( - // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) - // %transpose = f32[2,4]{1,0} transpose( - // f32[2,4]{1,0} %constant.3), dimensions={0,1} - // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), - // condition=%WhileCond, - // body=%WhileBody - // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} - // ROOT %add = f32[2,4]{1,0} add( - // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) - // } - - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto 
cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - // transpose(matrix) + bcast(while) - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - HloInstruction* while_loop = - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - // Creates 32 bytes and frees 16 - HloInstruction* bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); - - HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); - // Creates 32 bytes - HloInstruction* transpose = builder.AddInstruction( - HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); - - // Creates 32 bytes and frees 64 - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. 
- auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - schedule.sequence(entry_computation).size()); - SequentialHloOrdering ordering(schedule); - // This schedule is an example of List's greedy heuristics being suboptimal. - // The while_loop is more expensive than transpose, so it would have been - // better to schedule it first, instead of during the busy time. - EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); - EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The output buffer is aliased, - // so we don't double count. - EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - -TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { - auto builder = HloComputation::Builder(TestName()); - const auto TUPLE_SIZE = 1; - const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); - - // Wrap lit in abs because constants are considered free by - // IgnoreInstruction, and it skews the accounting. 
- auto lit = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); - auto abs_const = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); - - auto abs_abs1 = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); - auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( - absl::Span({abs_abs1}))); - auto tuple_elm = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); - - auto abs_abs2 = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); - - builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, - tuple_elm, abs_abs2)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); - - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); - SequentialHloOrdering ordering(schedule); - // tuple allocates the tuple buffer and doesn't free anything. - // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. - // abs_abs2 should be scheduled before tuple by List. 
- EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); -} - -TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { - const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); - HloComputation::Builder builder(TestName()); - - auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); - auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); - auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); - auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); - - auto tuple_elm = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); - - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); - - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - - auto module = CreateNewModule(); - auto* computation = module->AddEntryComputation(builder.Build()); - - auto fusion = computation->CreateFusionInstruction( - {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); - - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); - SequentialHloOrdering ordering(schedule); - // fusion allocates memory for the tuple elements and doesn't free anything, - // so it's more expensive than exp. 
- EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); -} - -TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the 
sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. Cond is the largest one. - // The output buffer of the while is aliased. - EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 069586a738..50f39cbcb5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1123,6 +1123,11 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + return false; } -- cgit v1.2.3 From 31c1d228b15d6bcda2d6bd2172605d3a5f7d2be8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 8 Sep 2018 10:36:41 -0700 Subject: Avoid directly constructing vector iterators from pointers; that isn't part of their public API. 
PiperOrigin-RevId: 212123326 --- tensorflow/compiler/xla/shape_tree.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 52c895e8d4..df610102b4 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -224,14 +224,13 @@ class ShapeTree { // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); - return iterator(&nodes_, typename std::vector::iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.begin() + (element - &nodes_[0]); + return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); - return iterator(&nodes_, - typename std::vector::const_iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); + return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. -- cgit v1.2.3 From 1bf545492596f1d3dbaf1485de500116a2d2a25b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 8 Sep 2018 13:43:29 -0700 Subject: Deprecating the contrib.ffmpeg Python functions. 
PiperOrigin-RevId: 212132419 --- tensorflow/contrib/ffmpeg/ffmpeg_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index b1b5126d9e..45a67acb5b 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -24,11 +24,13 @@ from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader +from tensorflow.python.util.deprecation import deprecated _ffmpeg_so = loader.load_op_library( resource_loader.get_path_to_datafile('ffmpeg.so')) +@deprecated('2018-09-04', 'This will be deleted and should not be used.') def decode_audio(contents, file_format=None, samples_per_second=None, channel_count=None, stream=None): """Create an op that decodes the contents of an audio file. @@ -69,6 +71,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None, ops.NotDifferentiable('DecodeAudio') +@deprecated('2018-09-04', 'This will be deleted and should not be used.') def encode_audio(audio, file_format=None, samples_per_second=None): """Creates an op that encodes an audio file using sampled audio from a tensor. @@ -95,6 +98,7 @@ def encode_audio(audio, file_format=None, samples_per_second=None): ops.NotDifferentiable('EncodeAudio') +@deprecated('2018-09-04', 'This will be deleted and should not be used.') def decode_video(contents): """Create an op that decodes the contents of a video file. -- cgit v1.2.3 From c50f1da063a7b6365542d923c4014e84515fe955 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 8 Sep 2018 23:43:35 +0000 Subject: Fix broken link in rnn_colorbot The README.md inside rnn_colorbot is broken, this fix fixes the link. 
Signed-off-by: Yong Tang --- tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md index fabd7b3e20..750bbc66f3 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md @@ -23,4 +23,4 @@ Attribution-ShareAlike License and is available at https://en.wikipedia.org/wiki/List_of_colors:_N-Z This example was adapted from - https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot + https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot -- cgit v1.2.3 From a3776a234f555213aafcf41f49a42a8a9448c4ac Mon Sep 17 00:00:00 2001 From: Tong Shen Date: Sun, 9 Sep 2018 01:37:02 -0700 Subject: Move control flow functionalization as a graph optimization pass, instead of a step in XlaCompiler. 
PiperOrigin-RevId: 212164482 --- tensorflow/compiler/jit/BUILD | 1 + .../jit/jit_compilation_pass_registration.cc | 12 ++ tensorflow/compiler/tf2xla/BUILD | 18 ++- tensorflow/compiler/tf2xla/functionalize_cond.cc | 10 +- .../compiler/tf2xla/functionalize_control_flow.cc | 133 +++++++++++++++++++++ .../compiler/tf2xla/functionalize_control_flow.h | 13 ++ ...functionalize_control_flow_pass_registration.cc | 25 ++++ tensorflow/compiler/tf2xla/functionalize_while.cc | 25 +++- tensorflow/compiler/tf2xla/graph_compiler.cc | 1 - tensorflow/compiler/tf2xla/tf2xla.cc | 8 ++ tensorflow/compiler/tf2xla/tf2xla_util.cc | 102 ++++++++++++++++ tensorflow/compiler/tf2xla/tf2xla_util.h | 62 ++++++++++ tensorflow/compiler/tf2xla/xla_compiler.cc | 13 +- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 17 --- tensorflow/core/framework/function.cc | 11 ++ tensorflow/core/framework/function.h | 4 + 16 files changed, 423 insertions(+), 32 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a989f15a1c..7d5db713f6 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,6 +265,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc..5dcf754969 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,6 +21,18 @@ limitations under the License. 
namespace tensorflow { +// PRE_PLACEMENT passes: + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes: REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3821dced63..b28ffaf8a4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -188,7 +189,6 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -285,6 +285,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -480,6 +481,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -507,11 +509,23 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) +cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + 
":functionalize_control_flow", + ], + alwayslink = 1, +) + cc_library( name = "functionalize_while", srcs = [ @@ -521,6 +535,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -531,6 +546,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 0911550f1f..55439e77a6 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -642,7 +643,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If"); + NodeDefBuilder builder(name(), "If", library); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -1252,6 +1253,13 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector switch_ids; std::vector merge_order; DFS(*graph_, nullptr, [&](Node* n) { + // Nodes marked with _xla_outside_compilation are skipped, because they need + // to be executed on host with regular TF executor, which does not support + // XlaIf/XlaWhile. 
+ if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) { + return; + } + if (IsSwitch(n)) { switch_ids.push_back(n->id()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5932be4e52..622767f68d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -68,4 +73,132 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map* canonicalized_name_to_new_name) { + // Convert the function to Graph. 
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + const FunctionDef& fdef = body->fdef; + + // If any node has associated functions, functionalize them first. + for (auto* n : body->graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, flr); + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + // If we already functionalized this function, skip it. + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + if (iter != canonicalized_name_to_new_name->end()) { + continue; + } + + string new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. + // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + body->graph, n, fld, associated_function, new_name)); + } + } + + // Functionalize the function body. 
+ if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *body->graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *body->graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); + + // Copy signature and ret from original FunctionDef. + *functionalized_fdef.mutable_signature() = fdef.signature(); + *functionalized_fdef.mutable_ret() = fdef.ret(); + functionalized_fdef.mutable_signature()->set_name(new_func_name); + + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. 
+ static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ + {"TPUCompile", "function"}, + {"XlaLaunch", "function"}, + }; + std::map canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; + } + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 55600f2a8b..f1cbcdf617 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -32,6 +33,18 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). +// +// Notice that control flow structure marked with _xla_outside_compilation are +// skipped, because they need to be executed on host with regular TF executor, +// which does not support XlaIf/XlaWhile. +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000..a10a9d0499 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 7f45e3bffa..f905c6a0fc 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -473,12 +475,21 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. + // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. + // TODO(b/114485797): current logic does not functionalize while loop in + // another loop cond. 
std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -510,7 +521,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -641,8 +652,14 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); + // Nodes marked with _xla_outside_compilation are skipped, because they need + // to be executed on host with regular TF executor, which does not support + // XlaIf/XlaWhile. + string name; + if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) { + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); + } // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index bc2e640559..fa25a230b0 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 7dbe3a0b58..b22d53805d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. + TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 211caf8736..d6f42bac86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -25,9 +25,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr) { + if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + return false; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr) { + std::vector results; + const string& op = node.type_string(); + if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + // This is a function call node. 
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + } else { + // Collect all function attrs for the node. + for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? 
node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. + NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index dcddef8418..41e70e0658 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -60,6 +61,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionCallNode = 0, + kFunctionAttr = 1, + }; + + // The node is a function call. 
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) + : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} + + // The function is an attr of the node. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, + const string& attr_name) + : type_(kFunctionAttr), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. 
+extern const char kXlaOutsideCompilationAttrName[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index dcb455779d..105f3b61d5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -743,18 +745,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. 
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 40ce9fb41c..42de6bacd6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1255,25 +1255,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); - } - - // Fix control edges for NoOp. - { - std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); - CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); - XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); } } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 26f32677af..d979353d2f 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } +string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { + tf_shared_lock l(mu_); + int index = 0; + string name = strings::StrCat(prefix, index); + while (function_defs_.find(name) != function_defs_.end()) { + ++index; + name = strings::StrCat(prefix, index); + } + return name; +} + const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if 
(ndef.op() != kGradientOp) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 03296a7761..e01eb7503d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); + // Generates new function name with the specified prefix that is unique + // across this library. + string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); + // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. static constexpr const char* const kArgOp = "_Arg"; -- cgit v1.2.3 From b4d89565fcd73b4f2c4d6aa1ff159006795674b5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 9 Sep 2018 02:01:16 -0700 Subject: compat: Update forward compatibility horizon to 2018-09-09 PiperOrigin-RevId: 212165415 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index ca72cbac1a..5c50be2367 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 8) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 9) @tf_export("compat.forward_compatible") -- cgit v1.2.3 From d31f360e1574553ed23b8d483512a2065ac426eb Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Sun, 9 Sep 2018 07:18:09 -0700 Subject: Automated rollback of commit 39b2fb7cfef489424fead18ec5174d8e8b2a9a1a PiperOrigin-RevId: 212177437 --- tensorflow/python/data/util/nest.py | 33 +++++++++++++++++++++++++++++---- tensorflow/python/util/util.i | 27 
--------------------------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 3a5d1f0adf..9d621fcd30 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -96,12 +96,37 @@ def _yield_value(iterable): yield value -# See the swig file (../../util/util.i) for documentation. -is_sequence = _pywrap_tensorflow.IsSequenceForData +def is_sequence(seq): + """Returns a true if `seq` is a Sequence or dict (except strings/lists). + NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, + which *does* treat a Python list as a sequence. For ergonomic + reasons, `tf.data` users would prefer to treat lists as + implicit `tf.Tensor` objects, and dicts as (nested) sequences. -# See the swig file (../../util/util.i) for documentation. -flatten = _pywrap_tensorflow.FlattenForData + Args: + seq: an input sequence. + + Returns: + True if the sequence is a not a string or list and is a + collections.Sequence. + """ + return _pywrap_tensorflow.IsSequenceForData(seq) + + +def flatten(nest): + """Returns a flat sequence from a given nested structure. + + If `nest` is not a sequence, this returns a single-element list: `[nest]`. + + Args: + nest: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + + Returns: + A Python list, the flattened version of the input. + """ + return _pywrap_tensorflow.FlattenForData(nest) def assert_same_structure(nest1, nest2, check_types=True): diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 104a615636..6d336ac39d 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -104,36 +104,9 @@ Raises: %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; -%feature("docstring") tensorflow::swig::IsSequenceForData -"""Returns a true if `seq` is a Sequence or dict (except strings/lists). 
- -NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, -which *does* treat a Python list as a sequence. For ergonomic -reasons, `tf.data` users would prefer to treat lists as -implicit `tf.Tensor` objects, and dicts as (nested) sequences. - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a not a string or list and is a - collections.Sequence. -""" %unignore tensorflow::swig::IsSequenceForData; %noexception tensorflow::swig::IsSequenceForData; -%feature("docstring") tensorflow::swig::FlattenForData -"""Returns a flat sequence from a given nested structure. - -If `nest` is not a sequence, this returns a single-element list: `[nest]`. - -Args: - nest: an arbitrarily nested structure or a scalar object. - Note, numpy arrays are considered scalars. - -Returns: - A Python list, the flattened version of the input. -""" %unignore tensorflow::swig::FlattenForData; %noexception tensorflow::swig::FlattenForData; -- cgit v1.2.3 From b40ace8f28315431e3435647ce39cc7b24c20bfd Mon Sep 17 00:00:00 2001 From: Tong Shen Date: Sun, 9 Sep 2018 09:50:03 -0700 Subject: Automated rollback of commit a3776a234f555213aafcf41f49a42a8a9448c4ac PiperOrigin-RevId: 212182923 --- tensorflow/compiler/jit/BUILD | 1 - .../jit/jit_compilation_pass_registration.cc | 12 -- tensorflow/compiler/tf2xla/BUILD | 18 +-- tensorflow/compiler/tf2xla/functionalize_cond.cc | 10 +- .../compiler/tf2xla/functionalize_control_flow.cc | 133 --------------------- .../compiler/tf2xla/functionalize_control_flow.h | 13 -- ...functionalize_control_flow_pass_registration.cc | 25 ---- tensorflow/compiler/tf2xla/functionalize_while.cc | 25 +--- tensorflow/compiler/tf2xla/graph_compiler.cc | 1 + tensorflow/compiler/tf2xla/tf2xla.cc | 8 -- tensorflow/compiler/tf2xla/tf2xla_util.cc | 102 ---------------- tensorflow/compiler/tf2xla/tf2xla_util.h | 62 ---------- tensorflow/compiler/tf2xla/xla_compiler.cc | 13 +- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 17 +++ 
tensorflow/core/framework/function.cc | 11 -- tensorflow/core/framework/function.h | 4 - 16 files changed, 32 insertions(+), 423 deletions(-) delete mode 100644 tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d5db713f6..a989f15a1c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,7 +265,6 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", - "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 5dcf754969..c37b6112cc 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,18 +21,6 @@ limitations under the License. namespace tensorflow { -// PRE_PLACEMENT passes: - -// from -// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc -// FunctionalizeControlFlowPass: 27 -// -// This pass looks at the graph and all associated FunctionDefs, and turns -// traditional control flow structure (Switch/Merge/etc.) into functional -// control flow structure (XlaIf/XlaWhile). Following passes must -// handle those FunctionDef correctly. 
- -// POST_REWRITE_FOR_EXEC passes: REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index b28ffaf8a4..3821dced63 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,7 +76,6 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -189,6 +188,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -285,7 +285,6 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -481,7 +480,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -509,23 +507,11 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) -cc_library( - name = "functionalize_control_flow_pass_registration", - srcs = [ - "functionalize_control_flow_pass_registration.cc", - ], - deps = [ - ":functionalize_control_flow", - ], - alwayslink = 1, -) - cc_library( name = "functionalize_while", srcs = [ @@ -535,7 +521,6 @@ cc_library( "functionalize_while.h", ], deps = [ - ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -546,7 +531,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 55439e77a6..0911550f1f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -643,7 +642,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDefBuilder builder(name(), "If"); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -1253,13 +1252,6 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector switch_ids; std::vector merge_order; DFS(*graph_, nullptr, [&](Node* n) { - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) { - return; - } - if (IsSwitch(n)) { switch_ids.push_back(n->id()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 622767f68d..5932be4e52 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,16 +31,11 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -73,132 +68,4 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } -Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map* canonicalized_name_to_new_name) { - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - const FunctionDef& fdef = body->fdef; - - // If any node has associated functions, functionalize them first. - for (auto* n : body->graph->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, flr); - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); - // If we already functionalized this function, skip it. 
- auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - if (iter != canonicalized_name_to_new_name->end()) { - continue; - } - - string new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - body->graph, n, fld, associated_function, new_name)); - } - } - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *body->graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *body->graph, fld); - } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); - - // Copy signature and ret from original FunctionDef. - *functionalized_fdef.mutable_signature() = fdef.signature(); - *functionalized_fdef.mutable_ret() = fdef.ret(); - functionalized_fdef.mutable_signature()->set_name(new_func_name); - - // Add rewritten FunctionDef into library. 
- if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } - - return ret_status; -} - -Status FunctionalizeControlFlowPass::Run( - const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, - options.flib_def); - } - std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime( - /*device_mgr=*/nullptr, options.session_options->env, - TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); - FunctionLibraryRuntime* flr = - pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - - // Find XLA compile ops and its corresponding FunctionDef. - static std::map* kNodeTypeToFunctionAttrMapping = - new std::map{ - {"TPUCompile", "function"}, - {"XlaLaunch", "function"}, - }; - std::map canonicalized_name_to_new_name; - for (Node* n : graph->nodes()) { - auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); - if (it == kNodeTypeToFunctionAttrMapping->end()) { - continue; - } - const string func_attr = it->second; - if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". 
Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, - options.flib_def); - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f1cbcdf617..55600f2a8b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -33,18 +32,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); -// This pass looks at the graph and all associated FunctionDefs, and turns -// traditional control flow structure (Switch/Merge/etc.) into functional -// control flow structure (XlaIf/XlaWhile). -// -// Notice that control flow structure marked with _xla_outside_compilation are -// skipped, because they need to be executed on host with regular TF executor, -// which does not support XlaIf/XlaWhile. 
-class FunctionalizeControlFlowPass : public GraphOptimizationPass { - public: - Status Run(const GraphOptimizationPassOptions& options) override; -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc deleted file mode 100644 index a10a9d0499..0000000000 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" - -namespace tensorflow { - -// This pass is required for some AOT backends and all JIT backends, so this -// file exists as a separate lib and will be linked to both AOT and JIT. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, - FunctionalizeControlFlowPass); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index f905c6a0fc..7f45e3bffa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -35,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -475,21 +473,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. Notice that we call - // FunctionalizeCond() on cond_graph and body_graph because we might have - // unfunctionalized "if" in cond_graph and body_graph. Functionalize them - // before they are encapsulated in FunctionDef. - // TODO(b/114485797): current logic does not functionalize while loop in - // another loop cond. + // Builds the condition and body functions. std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); - FixupSourceAndSinkEdges(cond_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); - FixupSourceAndSinkEdges(body_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -521,7 +510,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. 
NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library); + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -652,14 +641,8 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - string name; - if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) { - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - } + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index fa25a230b0..bc2e640559 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b22d53805d..7dbe3a0b58 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -341,13 +340,6 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); - - // Functionalize control flow. - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); - // After control flow functionalization, we might have more FunctionDef's - // (then/else branch, loop body). Add them to the graph. - TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); - *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index d6f42bac86..211caf8736 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -25,12 +25,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -78,8 +75,6 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace -const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; - Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -328,101 +323,4 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -// TODO(b/77601805): add tests for associated function related stuff. -bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr) { - if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { - return true; - } - - if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - return false; - } - - for (const auto& iter : node_def.attr()) { - if (iter.second.has_func()) { - return true; - } - } - - return false; -} - -std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr) { - std::vector results; - const string& op = node.type_string(); - if (flr->GetFunctionLibraryDefinition()->Contains(op)) { - // This is a function call node. 
- AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); - results.emplace_back(AssociatedFunctionInfo(op, attrs)); - } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - } else { - // Collect all function attrs for the node. - for (auto& iter : node.attrs()) { - if (iter.second.has_func()) { - VLOG(2) << "Found function attr for node " << node.name() << ": " - << iter.first << " = " << iter.second.func().name(); - results.emplace_back(AssociatedFunctionInfo( - iter.second.func().name(), iter.second.func().attr(), iter.first)); - } - } - } - return results; -} - -Status RewriteAssociatedFunction( - Graph* graph, Node* node, FunctionLibraryDefinition* fld, - const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name) { - switch (associated_function.type()) { - case AssociatedFunctionInfo::kFunctionCallNode: { - // Change this node to call the new function. - NodeDefBuilder builder(node->name(), rewritten_function_name, fld); - for (auto attr : node->attrs()) { - builder.Attr(attr.first, attr.second); - } - for (int i = 0; i < node->num_inputs(); i++) { - Node* input_node; - TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); - builder.Input(input_node->name(), i, node->input_type(i)); - } - builder.Device(node->assigned_device_name().empty() - ? 
node->requested_device() - : node->assigned_device_name()); - NodeDef node_def; - TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); - Status s; - Node* new_node = graph->AddNode(node_def, &s); - TF_RETURN_IF_ERROR(s); - for (auto edge : node->in_edges()) { - graph->AddEdge(edge->src(), edge->src_output(), new_node, - edge->dst_input()); - } - for (auto edge : node->out_edges()) { - graph->AddEdge(new_node, edge->src_output(), edge->dst(), - edge->dst_input()); - } - graph->RemoveNode(node); - break; - } - case AssociatedFunctionInfo::kFunctionAttr: { - // Change function attr to rewritten functions. - NameAttrList func; - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); - node->ClearAttr(associated_function.attr_name()); - func.set_name(rewritten_function_name); - node->AddAttr(associated_function.attr_name(), func); - break; - } - } - - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 41e70e0658..dcddef8418 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -20,7 +20,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -61,67 +60,6 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); -// Indicates how a FunctionDef is associated with a graph node (e.g. the node is -// a function call, or the node has function attrs). -class AssociatedFunctionInfo { - public: - enum AssociatedFunctionType { - kFunctionCallNode = 0, - kFunctionAttr = 1, - }; - - // The node is a function call. 
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) - : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} - - // The function is an attr of the node. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, - const string& attr_name) - : type_(kFunctionAttr), - func_name_(func_name), - attrs_(attrs), - attr_name_(attr_name) {} - - AssociatedFunctionType type() const { return type_; } - - const string& func_name() const { return func_name_; } - - const string& attr_name() const { return attr_name_; } - - const AttrValueMap& attrs() const { return attrs_; } - - private: - // Available for all instances. - AssociatedFunctionType type_; - string func_name_; - AttrValueMap attrs_; - - // Only available if the function is defined in an attr. - string attr_name_; -}; - -// Returns if the NodeDef has associated function. -bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr); - -// Gets functions associated with the node. Current cases: -// 1. For function call node, its function name; -// 2. For nodes like XlaWhile/XlaIf, all their function attributes. -std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr); - -// Changes associated functions for the node. Current cases: -// 1. For function call node, creates a new node with the new function name and -// remove the old node; -// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. -Status RewriteAssociatedFunction( - Graph* graph, Node* node, FunctionLibraryDefinition* fld, - const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name); - -// Attribute to mark nodes to be executed on host. 
-extern const char kXlaOutsideCompilationAttrName[]; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 105f3b61d5..dcb455779d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -149,9 +150,6 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); - VLOG(4) << "Function " << function.name() << " in flib_runtime_"; - } else { - VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -745,13 +743,18 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph, - flib_runtime_->GetFunctionLibraryDefinition()); + absl::StrCat("xla_compile_graph_", name), *graph); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); + // Converts Tensorflow's graph control-flow constructs into functional + // control-flow that can be compiled into XLA code. + TF_RETURN_IF_ERROR( + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); + // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. 
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 42de6bacd6..40ce9fb41c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1255,8 +1255,25 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); } } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d979353d2f..26f32677af 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,17 +1154,6 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } -string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { - tf_shared_lock l(mu_); - int index = 0; - string name = strings::StrCat(prefix, index); - while (function_defs_.find(name) != function_defs_.end()) { - ++index; - name = strings::StrCat(prefix, index); - } - return name; -} - const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if 
(ndef.op() != kGradientOp) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e01eb7503d..03296a7761 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,10 +358,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); - // Generates new function name with the specified prefix that is unique - // across this library. - string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); - // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. static constexpr const char* const kArgOp = "_Arg"; -- cgit v1.2.3 From 0b90eec6e16238198ffd0ff0011e0f6f33f4038d Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 9 Sep 2018 10:47:00 -0700 Subject: [XLA] Improve error message in HLO evaluator for illegal broadcast. PiperOrigin-RevId: 212185352 --- tensorflow/compiler/xla/service/hlo_evaluator.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index d0d955fea8..a2f683b690 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -940,8 +940,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. 
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( -- cgit v1.2.3 From 515a7f3ccb96b8f1224c4b93e942b81942c4e3d2 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 9 Sep 2018 11:12:57 -0700 Subject: Fix typo in error message in xla_op_kernel. PiperOrigin-RevId: 212186490 --- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 636cb71e21..c7baee27f9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -220,7 +220,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } *constant_literal = std::move(*computed.ValueOrDie()); -- cgit v1.2.3 From 542fb58cf5f66899479602c70659d59897249101 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 9 Sep 2018 18:42:36 +0000 Subject: Fix np.float -> np.floating change While running core_rnn_cell_test: ``` bazel test -s --verbose_failures --config=opt //tensorflow/contrib/rnn:core_rnn_cell_test ``` Noticed the following warning: ``` FutureWarning: Conversion of the second argument of issubdtype from `float` to 
`np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. ``` This fix fixes the above warning. Signed-off-by: Yong Tang --- tensorflow/python/framework/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4bece9e25e..cd23b3923e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1655,7 +1655,7 @@ class TensorFlowTestCase(googletest.TestCase): if any of the elements do not fall in the specified range. """ target = self._GetNdArray(target) - if not (np.issubdtype(target.dtype, np.float) or + if not (np.issubdtype(target.dtype, np.floating) or np.issubdtype(target.dtype, np.integer)): raise AssertionError( "The value of %s does not have an ordered numeric type, instead it " -- cgit v1.2.3 From 231f34e3d8634ae02dae00af89d0ceafb3ada588 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Sun, 9 Sep 2018 19:49:17 -0700 Subject: Add support for evaluate and predict in keras with TPUStrategy. Also add unittests and updated examples. 
PiperOrigin-RevId: 212207760 --- tensorflow/contrib/distribute/python/BUILD | 21 +- .../contrib/distribute/python/combinations.py | 4 + .../distribute/python/examples/keras_mnist.py | 1 - tensorflow/contrib/distribute/python/keras_test.py | 142 +++++---- .../keras/engine/distributed_training_utils.py | 8 + .../python/keras/engine/training_distributed.py | 342 ++++++++++++++++++--- 6 files changed, 409 insertions(+), 109 deletions(-) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index c524d8b394..87f76eaa94 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -708,19 +708,32 @@ cuda_py_test( ], ) -cuda_py_test( - name = "keras_test", +py_library( + name = "keras_test_lib", + testonly = 1, srcs = ["keras_test.py"], - additional_deps = [ - "//third_party/py/numpy", + deps = [ + ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", + "no_pip", "no_windows_gpu", "notsan", ], diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 2301ba9233..1133be6d0b 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -328,6 +328,10 @@ tpu_strategy = NamedDistribution( "TPU", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=5), required_tpu=True) +tpu_strategy_one_step = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy( + TPUClusterResolver(""), steps_per_run=1), + 
required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 0495134636..a84ef04196 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -63,7 +63,6 @@ def get_input_datasets(): # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() - eval_ds = eval_ds.shuffle(100) eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 3cee3e37a7..d46f0eb276 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -18,9 +18,12 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops @@ -91,6 +94,25 @@ def get_ds_test_input_fn(): return dataset +def batch_wrapper(dataset, batch_size, distribution): + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. 
+ if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy_one_step], + mode=['graph']) + + class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): @@ -175,7 +197,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._config.model_dir) -class TestWithDistributionStrategy(test.TestCase): +class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_validating_dataset_input_tensors_with_shape_mismatch(self): with self.cached_session(): @@ -215,7 +237,8 @@ class TestWithDistributionStrategy(test.TestCase): distributed_training_utils.validate_distributed_dataset_inputs( strategy, x, y) - def test_calling_model_on_same_dataset(self): + @combinations.generate(all_combinations()) + def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -224,15 +247,13 @@ class TestWithDistributionStrategy(test.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, 
distribution) # Call fit with validation data model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, @@ -241,6 +262,9 @@ class TestWithDistributionStrategy(test.TestCase): validation_data=dataset, validation_steps=2) model.predict(dataset, steps=2) + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work + # as clone_model's input_tensors argument only seems to accept list and not + # tuples or dict. def test_fit_with_tuple_and_dict_dataset_inputs(self): with self.cached_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -282,7 +306,8 @@ class TestWithDistributionStrategy(test.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - def test_fit_eval_and_predict_methods_on_dataset(self): + @combinations.generate(all_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -291,16 +316,13 @@ class TestWithDistributionStrategy(test.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, distribution) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) @@ -496,6 +518,8 @@ class TestWithDistributionStrategy(test.TestCase): class LossMaskingWithDistributionStrategyTest(test.TestCase): + # TODO(priyag): Enable all strategies for this test. 
Currently it does not + # work for TPU due to some invalid datatype. def test_masking(self): with self.cached_session(): np.random.seed(1337) @@ -519,24 +543,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): self.assertEqual(hist.history['loss'][0], 0) -class NormalizationLayerWithDistributionStrategyTest(test.TestCase): +class NormalizationLayerWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): - def test_batchnorm_correctness(self): + @combinations.generate(all_combinations()) + def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) model.add(norm) - strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', - '/device:GPU:0']) model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) dataset = dataset.repeat(100) - dataset = dataset.batch(32) + dataset = batch_wrapper(dataset, 32, distribution) model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) out = model.predict(dataset, steps=2) @@ -546,9 +571,11 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class CorrectnessWithDistributionStrategyTest(test.TestCase): +class CorrectnessWithDistributionStrategyTest(test.TestCase, + parameterized.TestCase): - def test_correctness(self): + @combinations.generate(all_combinations()) + def test_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') num_samples = 10000 @@ -557,43 +584,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): x_train = x_train.astype('float32') y_train = 
y_train.astype('float32') - model = keras.Sequential() - model.add(keras.layers.Dense(1, input_shape=(1,))) - - # With DistributionStrategy - dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_with = dataset_with.batch(32) - strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', - '/device:GPU:0']) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - distribute=strategy) - model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) - wts_with_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, - x_predict)) - predict_dataset_with = predict_dataset_with.batch(2) - predict_with_ds = model.predict(predict_dataset_with, steps=1) - predict_with_ds = np.reshape(predict_with_ds, (4, 1)) - - # Without DistributionStrategy - dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + def fit_and_predict(with_distribution=None): + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=with_distribution) + + batch_size = 64 + if with_distribution: + batch_size //= with_distribution.num_towers + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_without = dataset_without.batch(64) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5)) - model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) - wts_without_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( - x_predict, x_predict)) - predict_dataset_without = predict_dataset_without.batch(4) - predict_without_ds = model.predict(predict_dataset_without, steps=1) + train_dataset = 
batch_wrapper(train_dataset, batch_size, distribution) + # Running only 100 steps instead of the full dataset to keep test + # duration small. + model.fit(x=train_dataset, epochs=1, steps_per_epoch=100) + + weights = model.get_weights() + + x_predict = [[1.], [2.], [3.], [4.]] + predict_batch_size = 4 + if with_distribution: + predict_batch_size //= with_distribution.num_towers + predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, distribution) + predict_result = model.predict(predict_dataset, steps=1) + predict_result = np.reshape(predict_result, (4, 1)) + + return weights, predict_result + + wts_with_ds, predict_with_ds = fit_and_predict( + with_distribution=distribution) + wts_without_ds, predict_without_ds = fit_and_predict( + with_distribution=None) # Verify that the weights are the same within some limits of tolerance. np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) @@ -602,5 +629,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) +# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index c1c4970025..fa7228ed7b 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -287,3 +287,11 @@ def configure_and_create_session(distribution_strategy): session = session_module.Session(config=session_config) K.set_session(session) + + +def get_batch_dimension(iterator): + shapes = nest.flatten(iterator.output_shapes) + # Take the batch size from the first element, as it should be the same for + # all. 
+ dims = shapes[0].dims + return dims[0] if dims else None diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 939732cd67..b35903d3fe 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -27,10 +27,14 @@ from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as distribute_lib +# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication. + + def fit_loop( model, iterator, @@ -41,13 +45,13 @@ def fit_loop( initial_epoch=0, steps_per_epoch=None, validation_steps=None): - """fit function when using DistributionStrategy for training. + """Fit loop for training with DistributionStrategy. Arguments: model: Keras Model instance. iterator: Iterator for input data. epochs: Number of times to iterate over the data - verbose: Verbosity mode, 0, 1 or 2 + verbose: Integer, Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training val_iterator: Iterator for validation data. initial_epoch: Epoch at which to start training @@ -73,8 +77,8 @@ def fit_loop( model, iterator, epochs, verbose, callbacks, initial_epoch, steps_per_epoch) - clone_model_on_towers( - model, current_strategy, make_callback_model=True) + if not model._grouped_model: + clone_model_on_towers(model, current_strategy, make_callback_model=True) def _per_device_train_function(model): model._make_train_function() @@ -206,13 +210,13 @@ def _experimental_fit_loop( callbacks=None, initial_epoch=0, steps_per_epoch=None): - """fit function when using TPU DistributionStrategy for training. 
+ """Fit loop for training with TPU DistributionStrategy. Arguments: model: Keras Model instance. iterator: Iterator that returns inputs and targets epochs: Number of times to iterate over the data - verbose: Verbosity mode, 0, 1 or 2 + verbose: Integer, Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training initial_epoch: Epoch at which to start training (useful for resuming a previous training run) @@ -244,7 +248,9 @@ def _experimental_fit_loop( def step_fn(ctx, inputs, targets): """Clones the model and calls make_train_function.""" - # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes. + # TODO(priyag, sourabhbajaj): The model gets cloned every time + # fit/test/predict is called. We should look into caching this keyed on + # input shapes. clone_model_on_towers( model, current_strategy, @@ -258,19 +264,22 @@ def _experimental_fit_loop( (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, - grouped_updates, grouped_session_args, with_loss_tensor=True) + grouped_updates, grouped_session_args) combined_fn = K.Function( all_inputs, all_outputs, updates=all_updates, name='distributed_train_function', **all_session_args) - # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be - # something else for different outputs. out_labels = model.metrics_names or [] for label, output in zip(out_labels, combined_fn.outputs): - ctx.set_last_step_output(label, output, - aggregation=distribute_lib.get_loss_reduction()) + if label == 'loss': + aggregation = distribute_lib.get_loss_reduction() + else: + # We aggregate all other metrics using mean for now. This is temporary + # workaround until new metrics are in place. 
+ aggregation = variable_scope.VariableAggregation.MEAN + ctx.set_last_step_output(label, output, aggregation) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should @@ -324,10 +333,9 @@ def _experimental_fit_loop( callbacks.on_epoch_begin(epoch) epoch_logs = {} for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run): - # TODO(sourabhbajaj): Add the size parameter in batch_logs once callbacks - # are fixed as we need to replace size with a combination of steps_per_run + # TODO(sourabhbajaj): Replace size with a combination of steps_per_run # and batch_size - batch_logs = {'batch': step_index} + batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: _, outputs = K.get_session().run([train_op, output_tensors]) @@ -360,12 +368,12 @@ def _experimental_fit_loop( def test_loop(model, iterator, verbose=0, steps=None): - """evaluate method to validate a model that uses DistributionStrategy. + """Test loop for evaluating with DistributionStrategy. Arguments: model: Keras Model instance. iterator: Iterator for input data. - verbose: verbosity mode. + verbose: Integer, Verbosity mode 0 or 1. steps: Total number of steps (batches of samples) before declaring predictions finished. Ignored with the default value of `None`. @@ -374,11 +382,16 @@ def test_loop(model, iterator, verbose=0, steps=None): Scalar loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics). The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. + the display labels for the outputs. """ current_strategy = model._distribution_strategy - clone_model_on_towers(model, current_strategy) + # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. 
+ if current_strategy.__class__.__name__ == 'TPUStrategy': + return _experimental_test_loop(model, iterator, verbose, steps) + + if not model._grouped_model: + clone_model_on_towers(model, current_strategy) def _per_device_test_function(model): model._make_test_function() @@ -429,25 +442,136 @@ def test_loop(model, iterator, verbose=0, steps=None): distributed_training_utils.set_weights( current_strategy, distributed_model, orig_model_weights) - if steps is not None: - for step in range(steps): - batch_outs = distributed_test_function(ins) - batch_outs = _aggregate_metrics_across_towers( - current_strategy.num_towers, model.metrics_names, batch_outs) - if isinstance(batch_outs, list): - if step == 0: - for _ in enumerate(batch_outs): - outs.append(0.) - for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out + assert steps is not None + for step in range(steps): + batch_outs = distributed_test_function(ins) + batch_outs = _aggregate_metrics_across_towers( + current_strategy.num_towers, model.metrics_names, batch_outs) + if isinstance(batch_outs, list): + if step == 0: + outs = [0.] * len(batch_outs) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out + else: + if step == 0: + outs.append(0.) + outs[0] += batch_outs + if verbose >= 1: + progbar.update(step + 1) + for i in range(len(outs)): + outs[i] /= steps + + if len(outs) == 1: + return outs[0] + return outs + + +def _experimental_test_loop(model, iterator, verbose=0, steps=None): + """Test loop for evaluating with TPU DistributionStrategy. + + Arguments: + model: Keras Model instance. + iterator: Iterator for input data. + verbose: Integer, Verbosity mode 0 or 1. + steps: Total number of steps (batches of samples) + before declaring predictions finished. + Ignored with the default value of `None`. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). 
The attribute `model.metrics_names` will give you + the display labels for the outputs. + """ + current_strategy = model._distribution_strategy + K.get_session().run(current_strategy.initialize()) + + def _per_device_test_function(model): + model._make_test_function() + return (model.test_function.inputs, + model.test_function.outputs, + model.test_function.updates_op, + model.test_function.session_kwargs) + + # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. + K.set_learning_phase(0) + + def step_fn(ctx, inputs, targets): + """Clones the model and calls make_test_function.""" + # TODO(priyag, sourabhbajaj): The model gets cloned every time + # fit/test/predict is called. We should look into caching this keyed on + # input shapes. + clone_model_on_towers( + model, + current_strategy, + make_callback_model=False, + inputs=inputs, + targets=targets) + + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_test_function, model._grouped_model) + + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) + + combined_fn = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_test_function', + **all_session_args) + + for label, output in zip(model.metrics_names, combined_fn.outputs): + if label == 'loss': + aggregation = distribute_lib.get_loss_reduction() else: - if step == 0: - outs.append(0.) - outs[0] += batch_outs - if verbose == 1: - progbar.update(step + 1) - for i in range(len(outs)): - outs[i] /= steps + # We aggregate all other metrics using mean for now. This is temporary + # workaround until new metrics are in place. 
+ aggregation = variable_scope.VariableAggregation.MEAN + ctx.set_last_step_output(label, output, aggregation) + + return combined_fn.updates_op + + # Add initial dummy values for loss and other metric tensors. + initial_loop_values = {} + initial_loop_values['loss'] = constant_op.constant(1e7) + for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors): + initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) + + with current_strategy.scope(): + # TODO(priyag): Use steps_per_run when we use new metrics as they will + # allow handling metric computation at each step using variables. + ctx = current_strategy.run_steps_on_dataset( + step_fn, iterator, iterations=1, + initial_loop_values=initial_loop_values) + + test_op = ctx.run_op + output_tensors = ctx.last_step_outputs + + if verbose == 1: + progbar = Progbar(target=steps) + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + assert steps is not None + outs = [0.] * len(model.metrics_names) + for step in range(steps): + _, batch_outs = K.get_session().run([test_op, output_tensors]) + for i, label in enumerate(model.metrics_names): + outs[i] += batch_outs[label] + if verbose >= 1: + progbar.update(step + 1) + for i in range(len(outs)): + outs[i] /= (steps) + + K.get_session().run(current_strategy.finalize()) if len(outs) == 1: return outs[0] @@ -455,12 +579,12 @@ def test_loop(model, iterator, verbose=0, steps=None): def predict_loop(model, iterator, verbose=0, steps=None): - """Abstract method to loop over some data in batches. + """Predict loop for predicting with DistributionStrategy. Arguments: model: Keras Model instance. iterator: Iterator for input data. - verbose: verbosity mode. 
+ verbose: Integer, Verbosity mode 0 or 1. steps: Total number of steps (batches of samples) before declaring `_predict_loop` finished. Ignored with the default value of `None`. @@ -472,7 +596,12 @@ def predict_loop(model, iterator, verbose=0, steps=None): """ current_strategy = model._distribution_strategy - clone_model_on_towers(model, current_strategy) + # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. + if current_strategy.__class__.__name__ == 'TPUStrategy': + return _experimental_predict_loop(model, iterator, verbose, steps) + + if not model._grouped_model: + clone_model_on_towers(model, current_strategy) def _per_device_predict_function(model): model._make_predict_function() @@ -528,9 +657,11 @@ def predict_loop(model, iterator, verbose=0, steps=None): if step == 0: for _ in batch_outs: unconcatenated_outs.append([]) + # TODO(anjalisridhar): Should combine the outputs from multiple towers + # correctly here. for i, batch_out in enumerate(batch_outs): unconcatenated_outs[i].append(batch_out) - if verbose == 1: + if verbose >= 1: progbar.update(step + 1) if len(unconcatenated_outs) == 1: return np.concatenate(unconcatenated_outs[0], axis=0) @@ -540,6 +671,122 @@ def predict_loop(model, iterator, verbose=0, steps=None): ] +def _experimental_predict_loop(model, iterator, verbose=0, steps=None): + """Predict loop for predicting with TPU DistributionStrategy. + + Arguments: + model: Keras Model instance. + iterator: Iterator for input data. + verbose: Integer, Verbosity mode 0 or 1. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions + (if the model has multiple outputs). 
+ """ + current_strategy = model._distribution_strategy + K.get_session().run(current_strategy.initialize()) + + # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. + K.set_learning_phase(0) + + def _per_device_predict_function(model): + model._make_predict_function() + return (model.predict_function.inputs, + model.predict_function.outputs, + model.predict_function.updates_op, + model.predict_function.session_kwargs) + + def step_fn(ctx, inputs, targets): + """Clones the model and calls make_predict_function.""" + + # TODO(anjalisridhar): Support predict input correctly as it will not + # contain targets, only inputs. + del targets + + # TODO(priyag, sourabhbajaj): The model gets cloned every time + # fit/test/predict is called. We should look into caching this keyed on + # input shapes. + clone_model_on_towers( + model, + current_strategy, + make_callback_model=False, + inputs=inputs) + + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_predict_function, model._grouped_model) + + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) + + combined_fn = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_predict_function', + **all_session_args) + + for label, output in zip(model.output_names, combined_fn.outputs): + ctx.set_last_step_output(label, output) + + return combined_fn.updates_op + + # Add initial dummy values for outputs. + initial_loop_values = {} + batch_dimension = distributed_training_utils.get_batch_dimension(iterator) + for name, tensor in zip(model.output_names, model.outputs): + # TODO(priyag): This is a workaround as we do not know the batch dimension + # of the model's output at this point. 
+ tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:] + initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) + + with current_strategy.scope(): + # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed. + ctx = current_strategy.run_steps_on_dataset( + step_fn, iterator, iterations=1, + initial_loop_values=initial_loop_values) + + predict_op = ctx.run_op + output_tensors = ctx.last_step_outputs + + if verbose == 1: + progbar = Progbar(target=steps) + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + assert steps is not None + # Since we do not know how many samples we will see, we cannot pre-allocate + # the returned Numpy arrays. Instead, we store one array per batch seen + # and concatenate them upon returning. + unconcatenated_outs = [[] for _ in model.outputs] + for step in range(steps): + _, batch_outs = K.get_session().run([predict_op, output_tensors]) + # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy. 
+ for i, label in enumerate(model.output_names): + unconcatenated_outs[i].extend(batch_outs[label]) + if verbose >= 1: + progbar.update(step + 1) + + K.get_session().run(current_strategy.finalize()) + + if len(unconcatenated_outs) == 1: + return np.concatenate(unconcatenated_outs[0], axis=0) + return [ + np.concatenate(unconcatenated_outs[i], axis=0) + for i in range(len(unconcatenated_outs)) + ] + + def _clone_and_build_model(model, inputs=None, targets=None): """Clone and build the given keras_model.""" # We need to set the import here since we run into a circular dependency @@ -572,13 +819,12 @@ def _clone_and_build_model(model, inputs=None, targets=None): def clone_model_on_towers( model, strategy, make_callback_model=False, inputs=None, targets=None): - """Create a cloned model on each tower, unless already created.""" - if not model._grouped_model: - with strategy.scope(): - model._grouped_model = strategy.call_for_each_tower( - _clone_and_build_model, model, inputs, targets) - if make_callback_model: - model._make_callback_model() + """Create a cloned model on each tower.""" + with strategy.scope(): + model._grouped_model = strategy.call_for_each_tower( + _clone_and_build_model, model, inputs, targets) + if make_callback_model: + model._make_callback_model() def _aggregate_metrics_across_towers(num_devices, out_labels, outs): -- cgit v1.2.3 From 17a34ab8f214cd1f07d63ea238eda4ba3bf052c5 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 9 Sep 2018 20:42:48 -0700 Subject: Add support for numpy arrays with DistributionStrategy in Keras. 
PiperOrigin-RevId: 212210810 --- tensorflow/contrib/distribute/python/keras_test.py | 34 ++++++++ .../keras/engine/distributed_training_utils.py | 69 ++++++++++++++- tensorflow/python/keras/engine/training.py | 99 ++++++++++++++++++---- .../python/keras/engine/training_distributed.py | 14 ++- 4 files changed, 189 insertions(+), 27 deletions(-) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d46f0eb276..9e1762d92c 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -237,6 +237,40 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): distributed_training_utils.validate_distributed_dataset_inputs( strategy, x, y) + def test_calling_model_with_numpy_arrays(self): + with self.cached_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + @combinations.generate(all_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index fa7228ed7b..b28df75493 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.client import session as session_module +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks @@ -212,7 +213,10 @@ def validate_distributed_dataset_inputs(distribution_strategy, x, y): # validate the input and targets. x_values_list = validate_per_device_inputs(distribution_strategy, x) - y_values_list = validate_per_device_inputs(distribution_strategy, y) + if y is not None: + y_values_list = validate_per_device_inputs(distribution_strategy, y) + else: + y_values_list = None # Return the unwrapped values to avoid calling `unwrap` a second time. return x_values_list, y_values_list @@ -289,6 +293,69 @@ def configure_and_create_session(distribution_strategy): K.set_session(session) +def validate_inputs(x, y): + """Validate inputs when using DistributionStrategy. + + Args: + x: Model Inputs. + y: Model Targets. + + Raises: + ValueError: if input is not a Dataset or a numpy array. 
+ """ + if isinstance(x, list) or isinstance(y, list): + raise ValueError('DistributionStrategy does not support lists of numpy' + 'arrays. You must pass a Dataset object or a numpy array ' + 'as input.') + + if isinstance(x, dict) or isinstance(y, dict): + raise ValueError('DistributionStrategy does not support inputs of type ' + 'dict. You must pass a Dataset object or a numpy array as ' + 'input.') + + if isinstance(x, iterator_ops.Iterator) or \ + isinstance(y, iterator_ops.Iterator): + raise ValueError('DistributionStrategy does not support inputs of type ' + 'Iterator. You must pass a Dataset object or a numpy ' + 'array as input.') + + +def get_input_batch_params(first_x_value, batch_size, current_strategy): + """Calculate the number of batches and steps/steps_per_epoch. + + Args: + first_x_value: This is the first input numpy array that is passed in as the + model input. + batch_size: The specified batch_size or the default batch_size of 32. + current_strategy: The current DistributionStrategy used to compile the + model. + + Returns: + The steps or steps_per_epoch argument depending on if a user is + calling `fit`, `evaluate` or `predict`. + + Raises: + ValueError: If the number of batches or steps evaluates to 0. + + """ + num_batches = first_x_value.shape[0] // batch_size + if not num_batches: + raise ValueError('Please specify a batch_size that is smaller than' + 'the number of input samples %d.' % first_x_value.shape[0]) + # TODO(anjalisridhar): TPU currently supports using the num_towers property. + # We might want to look into implementing worker_devices. In multi worker + # strategy, perhaps num_towers works better? + steps = num_batches // current_strategy.num_towers + if not steps: + # TODO(anjalisridhar): Number of towers in the error message may not convey + # what we want to the user. Is there another terminology that we can use + # that is consistent across different strategies. 
+ raise ValueError('The number of batches %d is smaller than the number ' + 'of towers %d used for DistributionStrategy. ' % + num_batches, current_strategy.num_towers) + return steps + + def get_batch_dimension(iterator): shapes = nest.flatten(iterator.output_shapes) # Take the batch size from the first element, as it should be the same for diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index d224dfffdd..49b25e307e 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -20,9 +20,11 @@ from __future__ import print_function import weakref import numpy as np +import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -754,9 +756,8 @@ class Model(Network): the model. Args: - x: Input data. A `tf.data` dataset. - y: Since `x` is a dataset, `y` should not be specified - (since targets will be obtained from the iterator). + x: Input data. A numpy array or `tf.data` dataset. + y: Target data. A numpy array or None if x is a `tf.data` dataset. sample_weight: An optional sample-weight array passed by the user to weight the importance of each sample in `x`. class_weight: An optional class-weight array by the user to @@ -786,12 +787,51 @@ class Model(Network): raise NotImplementedError('`class_weight` is currently not supported ' 'when using DistributionStrategy.') + # Validates `steps` argument right at the beginning since we use it to + # construct the dataset object. + # TODO(anjalisridhar): This may not be a valid error since we now accept + # numpy array inputs. We still want to assert that we have a populated steps + # parameter. 
+ if check_steps: + if steps is None: + raise ValueError('When using DistributionStrategy, ' + 'you should specify the `{steps_name}` argument.' + .format(steps_name=steps_name)) + + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray): + x_shape = first_x_value.shape + x_dtype = first_x_value.dtype + if batch_size is None: + batch_size = x_shape[0] // steps + if y is not None: + first_y_value = nest.flatten(y)[0] + x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y), + output_types=(x_dtype, first_y_value.dtype), + output_shapes=(x_shape[1:], + first_y_value.shape[1:])) + # TODO(anjalisridhar): What should the buffer size be? + x = x.shuffle(10000) + x = x.repeat() + x = x.batch(batch_size) + y = None + else: + # This case is for the predict call where the dataset only contains + # inputs and no targets i.e it does not return a tuple. + # TODO(anjalisridhar): Raise an error if we are not able to process + # all the predict samples. This can happen if the number of batches is + # not evenly divisible by the number of worker devices. + x = Dataset.from_generator(lambda x=x: x, + output_types=x_dtype, + output_shapes=x_shape[1:]) + x = x.repeat() + x = x.batch(batch_size) + # TODO(anjalisridhar): Can we use the iterator and getnext op cache? # We require users to pass Datasets since we distribute the dataset across # multiple devices. - if not isinstance(x, dataset_ops.Dataset): - raise ValueError('When using DistributionStrategy, model inputs should be' - ' Dataset instances; found instead %s.' % type(x)) + assert isinstance(x, dataset_ops.Dataset) + # TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a # function which returns a Dataset. Currently distribute_dataset() only # accepts a function that returns a Dataset. 
Once we add support for being @@ -799,12 +839,6 @@ class Model(Network): result = self._distribution_strategy.distribute_dataset(lambda: x) iterator = result.make_initializable_iterator() K.get_session().run(iterator.initializer) - # Validates `steps` argument based on x's type. - if check_steps: - if steps is None: - raise ValueError('When using a Dataset instance as input to a model, ' - 'you should specify the `{steps_name}` argument.' - .format(steps_name=steps_name)) training_utils.validate_iterator_input(x, y, sample_weight, validation_split) @@ -1428,6 +1462,13 @@ class Model(Network): if self._distribution_strategy: distributed_training_utils.validate_callbacks(callbacks) + distributed_training_utils.validate_inputs(x, y) + + first_x_value = nest.flatten(x)[0] + if not steps_per_epoch and isinstance(first_x_value, np.ndarray): + steps_per_epoch = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) + x, y, sample_weights = self._standardize_user_data( x, y, @@ -1462,6 +1503,13 @@ class Model(Network): 'However we received `validation_data=%s`' % validation_data) # Validate and standardize validation data. + if self._distribution_strategy: + distributed_training_utils.validate_inputs(val_x, val_y) + first_valx_value = nest.flatten(val_x)[0] + if not validation_steps and isinstance(first_valx_value, np.ndarray): + validation_steps = distributed_training_utils.get_input_batch_params( + first_valx_value, batch_size, self._distribution_strategy) + val_x, val_y, val_sample_weights = self._standardize_user_data( val_x, val_y, @@ -1599,6 +1647,13 @@ class Model(Network): batch_size = 32 # Validate and standardize user data. 
+ if self._distribution_strategy: + distributed_training_utils.validate_inputs(x, y) + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray) and not steps: + steps = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) + x, y, sample_weights = self._standardize_user_data( x, y, @@ -1669,14 +1724,22 @@ class Model(Network): if batch_size is None and steps is None: batch_size = 32 - # Turn off prefetching since this is currently not deterministic. Once - # b/112498930 is fixed we can turn it back on. - # `_prefetch_on_device` is currently a property of only `MirroredStrategy`. - if (self._distribution_strategy and - hasattr(self._distribution_strategy, '_prefetch_on_device')): - self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access + if self._distribution_strategy: + # Turn off prefetching since this is currently not deterministic. Once + # b/112498930 is fixed we can turn it back on. + # `_prefetch_on_device` is currently a property of only + # `MirroredStrategy`. + if hasattr(self._distribution_strategy, '_prefetch_on_device'): + self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access + distributed_training_utils.validate_inputs(x, None) + first_x_value = nest.flatten(x)[0] + if isinstance(first_x_value, np.ndarray) and not steps: + steps = distributed_training_utils.get_input_batch_params( + first_x_value, batch_size, self._distribution_strategy) # Validate and standardize user data. + # TODO(anjalisridhar): We don't pass batch_size here for some reason. This + # means that we end up calculating it twice which we should avoid. 
x, _, _ = self._standardize_user_data( x, check_steps=True, steps_name='steps', steps=steps) diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index b35903d3fe..53291c3956 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -861,14 +861,12 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs): def _get_input_from_iterator(iterator, model): """Get elements from the iterator and verify the input shape and type.""" next_element = iterator.get_next() - # TODO(anjalisridhar): Support predict input correctly as it will not contain - # targets, only inputs. - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide model inputs as a list or tuple of 2 ' - 'elements: input and target pair. ' - 'Received %s' % next_element) - - x, y = next_element + + if isinstance(next_element, tuple): + x, y = next_element + else: + x = next_element + y = None # Validate that all the elements in x and y are of the same type and shape. # We can then pass the first element of x and y to `_standardize_weights` # below and be confident of the output. -- cgit v1.2.3 From 5d62202cb1491cf97f0cd34a9c7b0d691984ff5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 9 Sep 2018 21:00:38 -0700 Subject: Fix code section in documentation of tf.enable_eager_execution(). PiperOrigin-RevId: 212211691 --- tensorflow/python/framework/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 9401309c19..75678cbc01 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5364,6 +5364,7 @@ def enable_eager_execution(config=None, computational graph). 
For example: + ```python tf.enable_eager_execution() -- cgit v1.2.3 From cb92ac2041f196487415ced1e0081866ef8a0f15 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 10 Sep 2018 01:00:34 -0700 Subject: Move HloConstantFolding to the end of the conv_canonicalization pass pipeline. This will also fold the added pad instructions into constants if possible. PiperOrigin-RevId: 212227161 --- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f6325b3368..dfdcf1875d 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); - // CudnnConvolutionRewriter may add instructions of the form - // reverse(constant), which it expects will be simplified by constant - // folding. - pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // pairs that TupleSimplifier fixes. pipeline.AddPass(); } + // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add + // instructions which can be simplified by constant folding. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } -- cgit v1.2.3 From 7624156f03549e1822969d9eb2395b9357f74aa7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 10 Sep 2018 02:01:20 -0700 Subject: compat: Update forward compatibility horizon to 2018-09-10 PiperOrigin-RevId: 212233410 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5c50be2367..af58a6f841 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 9) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 10) @tf_export("compat.forward_compatible") -- cgit v1.2.3 From cfddd182f71147eaf5ee8dc50113de3c0e622655 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 10 Sep 2018 18:51:42 +0800 Subject: fix comments for _dynamic_rnn_loop and LSTMCell::call --- tensorflow/python/ops/rnn.py | 2 +- tensorflow/python/ops/rnn_cell_impl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 4f3d8c2318..259aca5a81 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -709,7 +709,7 @@ def _dynamic_rnn_loop(cell, Raises: ValueError: If the input depth cannot be inferred via shape inference from the inputs. - ValueError: If time is not the same for all the elements in the + ValueError: If time_step is not the same for all the elements in the input. ValueError: If batch_size is not the same for all the elements in the input. diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c11c9ccaae..3e19183ff5 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -954,7 +954,7 @@ class LSTMCell(LayerRNNCell): """Run one step of LSTM. Args: - inputs: input Tensor, 2D, `[batch, num_units]. 
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`. state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and -- cgit v1.2.3 From 4b0d12bb8c62a44e895ebd515c0145d1c18e9191 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 10 Sep 2018 18:54:52 +0800 Subject: minor format --- tensorflow/python/ops/rnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 259aca5a81..dcc17db632 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -711,8 +711,8 @@ def _dynamic_rnn_loop(cell, from the inputs. ValueError: If time_step is not the same for all the elements in the input. - ValueError: If batch_size is not the same for all the elements - in the input. + ValueError: If batch_size is not the same for all the elements in the + input. """ state = initial_state assert isinstance(parallel_iterations, int), "parallel_iterations must be int" -- cgit v1.2.3 From 192e842e78475310ae0a36287570a1edcb2fbdaf Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 10 Sep 2018 06:21:43 -0700 Subject: Enable grouped convolutions for CudnnConvBackwardInput. So far, for grouped convolutions we always use forward convolution, which means we can't "fuse" the reverse into the cuDNN call. With this CL, we can also allow to use grouped convolutions if we match the backward convolution case. To make this work, we need to insert another reshape op. Also, refactor the code so that it returns the new "rhs" operand. 
PiperOrigin-RevId: 212256924 --- .../xla/service/gpu/cudnn_convolution_rewriter.cc | 80 ++++++++++++++++------ tensorflow/compiler/xla/tests/convolution_test.cc | 37 ++++++++++ 2 files changed, 95 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 9bf721ecd2..4a6a84d87d 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include #include #include @@ -59,8 +60,6 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward filter. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -218,13 +217,16 @@ std::tuple MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardInput( - HloInstruction* conv) { +std::tuple +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward input. + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. 
if (conv->feature_group_count() > 1) { return no_match_result; } @@ -401,10 +403,18 @@ std::tuple MatchBackwardInput( } } - // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. - // This simplifies things for our caller, and algebraic-simplifier will later - // remove any unnecessary reverses. - if (reverse_filter->opcode() != HloOpcode::kReverse) { + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. + dnums.set_kernel_input_feature_dimension( + conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { // Create a double-reverse, which is a nop. HloComputation* c = conv->parent(); reverse_filter = c->AddInstruction( @@ -416,11 +426,41 @@ std::tuple MatchBackwardInput( TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } - dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); - dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. 
Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); } // Tries to rewrite a single convolution into a call to cudnn. 
@@ -431,6 +471,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { @@ -439,13 +480,8 @@ StatusOr RunOnInstruction(HloInstruction* conv) { window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. - HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput(conv->shape(), conv->mutable_operand(0), rhs, window, dnums, conv->feature_group_count()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index d2c6478b02..e0a1538850 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -896,6 +896,43 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { std::move(*LiteralUtil::CreateFromArray(filter_data))}); } +XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100}); + Array4D input_data(1, 64, 100, 100); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321); + Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64}); + Array4D filter_data(7, 7, 1, 64); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = ConstantR4FromArray4D(&builder, filter_data); + + // Specify bf01_01io->bf01 as dimension numbers. 
+ ConvolutionDimensionNumbers dnums; + // Input + dnums.set_input_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + // Kernel + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + ConvGeneral(input, filter, /*window_strides=*/{1, 1}, + /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, + /*feature_group_count=*/64); + + ComputeAndCompare(&builder, + {std::move(*LiteralUtil::CreateFromArray(input_data))}, + error_spec_); +} + class ConvolutionHloTest : public HloTestBase {}; XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { -- cgit v1.2.3 From 7ede7c78a1e1fccd6f2c083dad4e2629dfd43714 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 10 Sep 2018 07:28:20 -0700 Subject: [tf.data] Expose `tf.contrib.data.Optional` and `tf.contrib.data.get_next_as_optional()`. PiperOrigin-RevId: 212263849 --- tensorflow/contrib/data/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5e6c1520a2..baec238c62 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@CheckpointInputPipelineHook @@CsvDataset @@LMDBDataset +@@Optional @@RandomDataset @@Reducer @@SqlDataset @@ -38,7 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. 
@@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset - +@@get_next_as_optional @@get_single_element @@group_by_reducer @@group_by_window @@ -46,7 +47,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator - @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -107,6 +107,8 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch from tensorflow.contrib.data.python.ops.unique import unique from tensorflow.contrib.data.python.ops.writers import TFRecordWriter +from tensorflow.python.data.ops.iterator_ops import get_next_as_optional +from tensorflow.python.data.ops.optional_ops import Optional # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented -- cgit v1.2.3 From bdbf4a4ab5e612487f0ee3699391956c6c472d88 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 08:20:55 -0700 Subject: Changing run_mode to run_as in documentation. PiperOrigin-RevId: 212270429 --- tensorflow/contrib/autograph/docs/pyfunc_dtypes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md index bcbb920cc5..c2427f5f4f 100644 --- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md +++ b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md @@ -4,7 +4,7 @@ The `py_func` op requires specifying a [data type](https://www.tensorflow.org/guide/tensors#data_types). 
When wrapping a function with `py_func`, for instance using -`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two +`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two options to specify the returned data type: * explicitly, with a specified `tf.DType` value -- cgit v1.2.3 From 73fd552491252494f71ec1fbf39daa5b41a48749 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Mon, 10 Sep 2018 08:59:32 -0700 Subject: Don't print control dependencies when dumping HLO profile PiperOrigin-RevId: 212275570 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 2 +- tensorflow/compiler/xla/service/hlo_instruction.h | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 25ae344ea5..f06c98f2e7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2096,7 +2096,7 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5581c17c2d..bf25157395 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -82,6 +82,7 @@ class HloPrintOptions { print_operand_shape_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -94,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) 
.set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -108,6 +110,7 @@ class HloPrintOptions { .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -153,6 +156,12 @@ class HloPrintOptions { return *this; } + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { @@ -190,6 +199,9 @@ class HloPrintOptions { bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -205,6 +217,7 @@ class HloPrintOptions { bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; -- cgit v1.2.3 From 7d3884bb87dc02c4548f55749f3d6db1b8364ddc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 09:47:23 -0700 Subject: Fix bug in copy optimization in Tensor slicing. 
PiperOrigin-RevId: 212283065 --- tensorflow/python/kernel_tests/slice_op_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 4a1fc1d9a9..40d384c623 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -260,6 +261,21 @@ class SliceTest(test.TestCase): grad_actual = gradients_impl.gradients(out, inp)[0].eval() self.assertAllClose([0., 1., 1.], grad_actual) + def _testGradientVariableSize2D(self): + # Regression test for bug in slice. A low-level bug in Eigen was causing + # incorrect results for negative indices in multi-dimensional tensors. + # See b/114318298. + with self.test_session(use_gpu=True) as sess: + x = constant_op.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 7]]) + loss1 = math_ops.reduce_sum(x[:-1, :-1] * 1.0) + loss2 = math_ops.reduce_sum(x[:-1][:, :-1]) + + g1 = gradients_impl.gradients(loss1, x)[0] + g2 = gradients_impl.gradients(loss2, x)[0] + + g1_val, g2_val = sess.run([g1, g2]) + self.assertAllEqual(g1_val, g2_val) + def testGradientsAll(self): # Slice the middle square out of a 4x4 input self._testGradientSlice([4, 4], [1, 1], [2, 2]) @@ -276,6 +292,9 @@ class SliceTest(test.TestCase): # Use -1 as a slice dimension. self._testGradientVariableSize() + # Use -1 as a slice dimension on a 2D tensor. 
+ self._testGradientVariableSize2D() + def testNotIterable(self): # NOTE(mrry): If we register __getitem__ as an overloaded # operator, Python will valiantly attempt to iterate over the -- cgit v1.2.3 From 5f004516a3c104ed7632ff4a31b65c49f620d199 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 10:23:14 -0700 Subject: Automated rollback of commit d6f107761459dfdf8773a148e11193a3512a51a6 PiperOrigin-RevId: 212289067 --- .../compiler/aot/embedded_protocol_buffers.h | 1 - tensorflow/compiler/aot/tfcompile_main.cc | 6 +- .../compiler/jit/mark_for_compilation_pass_test.cc | 2 +- tensorflow/compiler/jit/xla_cluster_util.h | 1 - tensorflow/compiler/jit/xla_device_context.cc | 6 +- tensorflow/compiler/jit/xla_device_context.h | 8 +- tensorflow/compiler/tf2xla/BUILD | 1 - .../compiler/tf2xla/resource_operation_table.cc | 18 ++-- tensorflow/compiler/tf2xla/tf2xla_util.h | 1 - tensorflow/compiler/tf2xla/xla_op_kernel.cc | 11 +- tensorflow/compiler/tf2xla/xla_op_registry.h | 1 - tensorflow/compiler/xla/packed_literal_reader.cc | 5 +- .../contrib/makefile/proto_text_cc_files.txt | 1 - tensorflow/core/lib/core/stringpiece.cc | 54 ---------- tensorflow/core/lib/core/stringpiece.h | 117 +-------------------- tensorflow/core/lib/strings/strcat.h | 3 + 16 files changed, 30 insertions(+), 206 deletions(-) delete mode 100644 tensorflow/core/lib/core/stringpiece.cc diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index bd270045e3..cf5c04ac4b 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,7 +20,6 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 1c9d30d7b0..b95b063348 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -93,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3..9473ac0a4c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = 
[](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 94c96ac7c5..ba218f3315 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,7 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ -#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6d4160a968..af83c792e5 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name), - device, cpu_tensor, done); + manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, + done); } void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1effd6628f..df82421294 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -111,12 +110,9 @@ class XlaDeviceContext : public DeviceContext { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const override; - // TODO(rlahaye): Replace StringPiece with absl::string_view when the - // StringPiece->absl::string_view change is rolled forward. void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, // non-ABSL OK - Device* device, Tensor* cpu_tensor, - StatusCallback done) override; + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3821dced63..ab289a2b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -215,7 +215,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 92577b5bc8..20f2ce2919 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -31,10 +30,11 @@ namespace tensorflow { } } -static gtl::FlatMap* CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap; +static gtl::FlatMap* +CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap; - auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,17 +103,17 @@ static gtl::FlatMap* CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const gtl::FlatMap& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = + static gtl::FlatMap* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap& op_infos = + const gtl::FlatMap& op_infos = GetStaticResourceOpInfoMap(); - auto it = op_infos.find(StringPiece(op.data(), op.length())); + auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; } @@ -121,7 +121,7 @@ namespace resource_op_table_internal { std::vector GetKnownResourceOps() { std::vector result; for (const auto& p : GetStaticResourceOpInfoMap()) { - result.push_back(absl::string_view(p.first)); + result.push_back(p.first); } absl::c_sort(result); return result; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index dcddef8418..a29e764466 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -18,7 +18,6 @@ limitations under the License. 
#include -#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c7baee27f9..d1534e9a15 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -102,8 +102,7 @@ Status XlaOpKernelContext::ConstantInput(int index, static xla::StatusOr InputIndex(XlaOpKernelContext* context, absl::string_view name) { int start, stop; - TF_RETURN_IF_ERROR(context->op_kernel().InputRange( - StringPiece(name.data(), name.length()), &start, &stop)); + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -366,8 +365,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; - TF_RETURN_IF_ERROR( - context_->input_list(StringPiece(name.data(), name.size()), &inputs)); + TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { @@ -380,8 +378,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, Status XlaOpKernelContext::ConstantInputList( absl::string_view name, std::vector* outputs) { int start, stop; - TF_RETURN_IF_ERROR(op_kernel().InputRange( - StringPiece(name.data(), name.size()), &start, &stop)); + TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); for (int i = start; i < stop; ++i) { TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); @@ -615,7 +612,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(StringPiece(name.data(), 
name.length()), &tensor).ok()); + CHECK(context_->input(name, &tensor).ok()); return *tensor; } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 5d53169f68..74a4885f1f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index bddb664149..f9473d372b 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -65,7 +64,7 @@ StatusOr> PackedLiteralReader::Read( absl::Span field = result->data(); char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -86,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. 
char single_byte[1]; - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index b5c781ad76..9ea94c7433 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -2,7 +2,6 @@ tensorflow/core/framework/resource_handle.cc tensorflow/core/lib/core/arena.cc tensorflow/core/lib/core/coding.cc tensorflow/core/lib/core/status.cc -tensorflow/core/lib/core/stringpiece.cc tensorflow/core/lib/core/threadpool.cc tensorflow/core/lib/hash/crc32c.cc tensorflow/core/lib/hash/crc32c_accelerate.cc diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc deleted file mode 100644 index 4c488066e4..0000000000 --- a/tensorflow/core/lib/core/stringpiece.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/lib/core/stringpiece.h" - -#include -#include - -namespace tensorflow { - -std::ostream& operator<<(std::ostream& o, StringPiece piece) { - o.write(piece.data(), piece.size()); - return o; -} - -size_t StringPiece::find(char c, size_t pos) const { - if (pos >= size_) { - return npos; - } - const char* result = - reinterpret_cast(memchr(data_ + pos, c, size_ - pos)); - return result != nullptr ? result - data_ : npos; -} - -// Search range is [0..pos] inclusive. If pos == npos, search everything. -size_t StringPiece::rfind(char c, size_t pos) const { - if (size_ == 0) return npos; - for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) { - if (*p == c) { - return p - data_; - } - } - return npos; -} - -StringPiece StringPiece::substr(size_t pos, size_t n) const { - if (pos > size_) pos = size_; - if (n > size_ - pos) n = size_ - pos; - return StringPiece(data_ + pos, n); -} - -} // namespace tensorflow diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index 02dded42c1..e7b17c9b36 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -31,124 +31,13 @@ limitations under the License. #include #include #include -#include +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -class StringPiece { - public: - typedef size_t size_type; - - // Create an empty slice. - StringPiece() : data_(nullptr), size_(0) {} - - // Create a slice that refers to d[0,n-1]. 
- StringPiece(const char* d, size_t n) : data_(d), size_(n) {} - - // Create a slice that refers to the contents of "s" - StringPiece(const string& s) : data_(s.data()), size_(s.size()) {} - - // Create a slice that refers to s[0,strlen(s)-1] - StringPiece(const char* s) : data_(s), size_(strlen(s)) {} - - // Return a pointer to the beginning of the referenced data - const char* data() const { return data_; } - - // Return the length (in bytes) of the referenced data - size_t size() const { return size_; } - - // Return true iff the length of the referenced data is zero - bool empty() const { return size_ == 0; } - - typedef const char* const_iterator; - typedef const char* iterator; - iterator begin() const { return data_; } - iterator end() const { return data_ + size_; } - - static const size_t npos = size_type(-1); - - // Return the ith byte in the referenced data. - // REQUIRES: n < size() - char operator[](size_t n) const { - assert(n < size()); - return data_[n]; - } - - // Drop the first "n" bytes from this slice. - void remove_prefix(size_t n) { - assert(n <= size()); - data_ += n; - size_ -= n; - } - - void remove_suffix(size_t n) { - assert(size_ >= n); - size_ -= n; - } - - size_t find(char c, size_t pos = 0) const; - size_t rfind(char c, size_t pos = npos) const; - - StringPiece substr(size_t pos, size_t n = npos) const; - - // Three-way comparison. Returns value: - // < 0 iff "*this" < "b", - // == 0 iff "*this" == "b", - // > 0 iff "*this" > "b" - int compare(StringPiece b) const; - - // Converts to various kinds of strings, including `std::basic_string`. 
- template - explicit operator S() const { - static_assert( - std::is_same::value, - "Type mismatch: S must be a string with character type char."); - static_assert( - std::is_same, typename S::traits_type>::value, - "Type mismatch: S must be a string with traits type " - "std::char_traits."); - if (!data()) return {}; - return S(data(), size()); - } - - private: - const char* data_; - size_t size_; - - // Intentionally copyable -}; - -inline bool operator==(StringPiece x, StringPiece y) { - return ((x.size() == y.size()) && - (memcmp(x.data(), y.data(), x.size()) == 0)); -} - -inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } - -inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; } -inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; } -inline bool operator<=(StringPiece x, StringPiece y) { - return x.compare(y) <= 0; -} -inline bool operator>=(StringPiece x, StringPiece y) { - return x.compare(y) >= 0; -} - -inline int StringPiece::compare(StringPiece b) const { - const size_t min_len = (size_ < b.size_) ? size_ : b.size_; - int r = memcmp(data_, b.data_, min_len); - if (r == 0) { - if (size_ < b.size_) - r = -1; - else if (size_ > b.size_) - r = +1; - } - return r; -} - -// allow StringPiece to be logged -extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece); +// Deprecated: please use absl::string_view directly. 
+using StringPiece = absl::string_view; } // namespace tensorflow diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index 351b6f5de3..a620f59447 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -124,6 +124,9 @@ class AlphaNum { AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit) : piece_(str) {} + template + AlphaNum(const std::basic_string, A> &str) + : piece_(str) {} // NOLINT(runtime/explicit) StringPiece::size_type size() const { return piece_.size(); } const char *data() const { return piece_.data(); } -- cgit v1.2.3 From 07c0f308ecce579ec69ad53541332ccf506ca280 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 10 Sep 2018 10:23:34 -0700 Subject: Make checkpointable list and dict wrappers copyable and deepcopyable Also tests copying Checkpointable objects, which seems to just work. PiperOrigin-RevId: 212289140 --- .../training/checkpointable/data_structures.py | 43 ++++++++++ .../checkpointable/data_structures_test.py | 99 ++++++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index f06cbbfa15..c29e5db075 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import copy import six @@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence): self._storage[index] = self._track_value( element, name=self._name_element(index)) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): 
"""Determines the backing storage (overridden in subclasses).""" return list(*args, **kwargs) @@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence, super(_ListWrapper, self).__init__(wrapped_list) self._last_wrapped_list_snapshot = list(self._storage) + # pylint: disable=protected-access + def __copy__(self): + copied = super(_ListWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + + def __deepcopy__(self, memo): + copied = super(_ListWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + # pylint: enable=protected-access + def _make_storage(self, wrapped_list): """Use the user's original list for storage.""" return wrapped_list @@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): value, name=self._name_element(key)) for key, value in self._storage.items()}) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): return dict(*args, **kwargs) @@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping): super(_DictWrapper, self).__init__(wrapped_dict) self._update_snapshot() + # pylint: disable=protected-access + def __copy__(self): + copied = super(_DictWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied + + def __deepcopy__(self, memo): + copied = super(_DictWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied 
+ # pylint: enable=protected-access + def _make_storage(self, wrapped_dict): """Re-use the wrapped dict for storage (to force them to be in sync).""" return wrapped_dict diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index 4638917b4c..5597c7c772 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import numpy @@ -424,6 +425,104 @@ class MappingTests(test.TestCase): new_dict.update(model.d) self.assertEqual({1: 3}, new_dict) + def testListShallowCopy(self): + root = tracking.Checkpointable() + orig_list = [[1.]] + root.a = orig_list + copied = copy.copy(root.a) + self.assertAllEqual([[1.]], copied) + self.assertIsNot(root.a, copied) + self.assertIs(root.a[0], copied[0]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_list.append(1.) + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.copy(root.a)) + + def testListDeepCopy(self): + root = tracking.Checkpointable() + orig_list = [[1.]] + root.a = orig_list + copied = copy.deepcopy(root.a) + self.assertAllEqual([[1.]], copied) + self.assertIsNot(root.a, copied) + self.assertIsNot(root.a[0], copied[0]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_list.append(1.) 
+ with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.deepcopy(root.a)) + + def testDictShallowCopy(self): + root = tracking.Checkpointable() + orig_dict = {"a": [1.]} + root.a = orig_dict + copied = copy.copy(root.a) + self.assertAllEqual([1.], copied["a"]) + self.assertIsNot(root.a, copied) + self.assertIs(root.a["a"], copied["a"]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_dict["b"] = [] + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.copy(root.a)) + + def testDictDeepCopy(self): + root = tracking.Checkpointable() + orig_dict = {"a": [1.]} + root.a = orig_dict + copied = copy.deepcopy(root.a) + self.assertAllEqual([1.], copied["a"]) + self.assertIsNot(root.a, copied) + self.assertIsNot(root.a["a"], copied["a"]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_dict["b"] = [] + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.deepcopy(root.a)) + + def testShallowCopyCheckpointable(self): + original = tracking.Checkpointable() + original_sub = tracking.Checkpointable() + original.a = [[1.]] + original.b = {"a": original_sub} + shallow_copied = copy.copy(original) + self.assertIs(original_sub, shallow_copied.b["a"]) + self.assertIsNot(original, shallow_copied) + self.assertEqual([[1.]], shallow_copied.a) + shallow_deps = util.list_objects(shallow_copied) + self.assertIn(shallow_copied.a, shallow_deps) + self.assertIn(shallow_copied.b, shallow_deps) + self.assertIn(shallow_copied.b["a"], shallow_deps) + + def testDeepCopyCheckpointable(self): + original = tracking.Checkpointable() + original_sub = tracking.Checkpointable() + original.a = [[1.]] + original.b = {"a": original_sub} + deep_copied = copy.deepcopy(original) + self.assertIsNot(original, deep_copied) + 
self.assertIsNot(original_sub, deep_copied.b["a"]) + self.assertEqual([[1.]], deep_copied.a) + self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable) + deps = util.list_objects(deep_copied) + self.assertIn(deep_copied.a, deps) + self.assertIn(deep_copied.b, deps) + self.assertIn(deep_copied.b["a"], deps) + self.assertNotIn(original_sub, deps) + def testConstructableFromSequence(self): result = data_structures._DictWrapper([(1, 2), (3, 4)]) self.assertIsInstance(result, dict) -- cgit v1.2.3 From 3e137b24b06a81772402b86392dbd158653d487b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 10:43:05 -0700 Subject: Remove note in TF for Android build instructions about Bazel not supporting NDK 15/16. PiperOrigin-RevId: 212292791 --- tensorflow/contrib/lite/java/demo/README.md | 6 +----- tensorflow/examples/android/README.md | 8 -------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index e3cea19e16..6a3f0651d0 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -20,9 +20,6 @@ code to merge. - Make sure to install the latest version of Bazel. Some distributions ship with Bazel 0.5.4, which is too old. - Bazel requires Android Build Tools `26.0.1` or higher. - - **Bazel is incompatible with NDK revisions 15 and above,** with revision - 16 being a compile-breaking change. [Download an older version manually - instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) - You also need to install the Android Support Repository, available through Android Studio under `Android SDK Manager -> SDK Tools -> Android Support Repository`. @@ -37,8 +34,7 @@ code to merge. - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that you have installed. 
- By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual - download until Bazel supports NDK 16. See bullet points under (1)). + the NDK to `~/Android/Sdk/ndk-bundle`. 2. Build the app with Bazel. The demo needs C++11: diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index dac9b7ab82..82bc3ffda9 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -121,10 +121,6 @@ the Android NDK and SDK must be installed on your system. 2. The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). - - * NDK 16, the revision released in November 2017, is **incompatible** with - Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918). - 3. The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android @@ -132,10 +128,6 @@ the Android NDK and SDK must be installed on your system. 23 is required to build the TF Android demo (though it will run on API >= 21 devices). - - The Android Studio SDK Manager's NDK installer will install the latest - revision of the NDK, which is **incompatible** with Bazel. You'll need - to download an older version manually, as (2) suggests. - ##### Edit WORKSPACE NOTE: As long as you have the SDK and NDK installed, the `./configure` script -- cgit v1.2.3 From 54273565a7b877ef448c29650409a60021cf6c5e Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 10 Sep 2018 10:47:25 -0700 Subject: Log all tensor allocations in eager mode when VLOG_IS_ON. 
PiperOrigin-RevId: 212293675 --- tensorflow/core/common_runtime/eager/context.cc | 1 + tensorflow/core/common_runtime/eager/context.h | 4 ++++ tensorflow/core/common_runtime/eager/execute.cc | 2 +- tensorflow/core/common_runtime/eager/kernel_and_device.cc | 1 + tensorflow/core/common_runtime/eager/kernel_and_device.h | 8 ++++++-- tensorflow/core/common_runtime/eager/kernel_and_device_test.cc | 4 ++-- 6 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 879a794368..37fc031985 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -56,6 +56,7 @@ EagerContext::EagerContext(const SessionOptions& opts, log_device_placement_(opts.config.log_device_placement()), num_active_steps_(0), async_default_(async), + log_memory_(LogMemory::IsEnabled()), env_(opts.env), use_send_tensor_rpc_(false) { if (device_mgr_owned) { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index eb6eb0d55a..5ed6057ec6 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #endif +#include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -141,6 +142,7 @@ class EagerContext { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() { return log_device_placement_; } + bool LogMemory() { return log_memory_; } Rendezvous* GetRendezvous() { return rendezvous_; } @@ -261,6 +263,8 @@ class EagerContext { std::unordered_map thread_local_async_ GUARDED_BY(async_map_mu_); + const bool log_memory_; + Env* const env_; #ifndef __ANDROID__ diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 5b3a64ba98..1da1326a9a 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -296,7 +296,7 @@ Status EagerLocalExecute(EagerOperation* op, LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new KernelAndDevice(ctx->GetRendezvous()); + kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); auto* flr = ctx->func_lib(device); if (flr == nullptr) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 3d61ff4dc2..59f94506b7 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -95,6 +95,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; params.cancellation_manager = &cm_; + params.log_memory = log_memory_; if (stats != nullptr) { params.track_allocations = true; } diff --git 
a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 0ef419cbaa..ed76c4f601 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -56,8 +56,11 @@ class KernelAndDevice { static Status InitOp(Device* device, const NodeDef& ndef, KernelAndDevice* out); - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} + KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory) + : device_(nullptr), + flib_(nullptr), + rendez_(rendez), + log_memory_(log_memory) {} // TODO(ashankar): Handle list-valued inputs. Status Run(std::vector* inputs, std::vector* outputs, @@ -87,6 +90,7 @@ class KernelAndDevice { DataTypeVector output_dtypes_; std::function)>* runner_; std::function)> default_runner_; + const bool log_memory_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 6abe98f53c..da280b2317 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) { .NumInputs(2) .BuildNodeDef()); TestEnv env; - KernelAndDevice k(nullptr); + KernelAndDevice k(nullptr, false); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), @@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) { .NumInputs(inputs.size()) .BuildNodeDef()); TestEnv env; - KernelAndDevice kernel(nullptr); + KernelAndDevice kernel(nullptr, false); TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), nullptr, &kernel)); tensorflow::testing::StartTiming(); -- cgit v1.2.3 From a0bec62c0219e143a8b0d8e3dd3fb5b577db388e Mon Sep 17 00:00:00 2001 From: 
Eugene Brevdo Date: Mon, 10 Sep 2018 10:47:54 -0700 Subject: Add helper functions that allow users to write TFRecords in memory. PiperOrigin-RevId: 212293765 --- tensorflow/core/lib/io/record_reader.cc | 3 --- tensorflow/core/lib/io/record_reader.h | 8 ++++++++ tensorflow/core/lib/io/record_writer.cc | 15 ++++----------- tensorflow/core/lib/io/record_writer.h | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index c24628be57..f93ebea771 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -109,9 +109,6 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { } Status RecordReader::ReadRecord(uint64* offset, string* record) { - static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); - static const size_t kFooterSize = sizeof(uint32); - // Position the input stream. int64 curr_pos = input_stream_->Tell(); int64 desired_pos = static_cast(*offset); diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index c05f9e1b36..11af1366b0 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -58,6 +58,14 @@ class RecordReaderOptions { // Note: this class is not thread safe; external synchronization required. class RecordReader { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + // Create a reader that will return log records from "*file". // "*file" must remain live while this Reader is in use. 
explicit RecordReader( diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 6e71d23e71..2c6db2487e 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() { } } -static uint32 MaskedCrc(const char* data, size_t n) { - return crc32c::Mask(crc32c::Value(data, n)); -} - Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { return Status(::tensorflow::error::FAILED_PRECONDITION, @@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { // uint32 masked crc of length // byte data[length] // uint32 masked crc of data - char header[sizeof(uint64) + sizeof(uint32)]; - core::EncodeFixed64(header + 0, data.size()); - core::EncodeFixed32(header + sizeof(uint64), - MaskedCrc(header, sizeof(uint64))); - char footer[sizeof(uint32)]; - core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); - + char header[kHeaderSize]; + char footer[kFooterSize]; + PopulateHeader(header, data.data(), data.size()); + PopulateFooter(footer, data.data(), data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); TF_RETURN_IF_ERROR(dest_->Append(data)); return dest_->Append(StringPiece(footer, sizeof(footer))); diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 6a2bf66d12..1212e1fafb 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -16,8 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ #define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ +#include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/crc32c.h" #if !defined(IS_SLIM_BUILD) #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" @@ -47,6 +49,14 @@ class RecordWriterOptions { class RecordWriter { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + // Create a writer that will append data to "*dest". // "*dest" must be initially empty. // "*dest" must remain live while this Writer is in use. @@ -72,13 +82,35 @@ class RecordWriter { // are invalid. Status Close(); + // Utility method to populate TFRecord headers. Populates record-header in + // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. + inline static void PopulateHeader(char* header, const char* data, size_t n); + + // Utility method to populate TFRecord footers. Populates record-footer in + // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1]. 
+ inline static void PopulateFooter(char* footer, const char* data, size_t n); + private: WritableFile* dest_; RecordWriterOptions options_; + inline static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); + } + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; +void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) { + core::EncodeFixed64(header + 0, n); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); +} + +void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) { + core::EncodeFixed32(footer, MaskedCrc(data, n)); +} + } // namespace io } // namespace tensorflow -- cgit v1.2.3 From b5c0161db4546dd8a71239ab563cd7398c9cff2c Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Mon, 10 Sep 2018 10:49:18 -0700 Subject: Automated rollback of commit e258e52d2c4060fc26fda43e4ce068d5ba2ab1ff PiperOrigin-RevId: 212294062 --- .../python/kernel_tests/stats_dataset_ops_test.py | 25 ++++++++++++++++++++++ .../python/kernel_tests/stats_dataset_test_base.py | 10 +++++++++ .../core/kernels/data/prefetch_dataset_op.cc | 25 +++++++++++++++++----- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 43067b4245..e25570c5ad 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + def testPrefetchBufferUtilization(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + 
-1).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + float(i + 1)) + self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", + 0, 1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + 100) + def testReinitialize(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index 9a13acf8f0..2f5a44408f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertLessEqual(min_value, value.histo.min) + self.assertGreaterEqual(max_value, value.histo.max) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasSum(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) diff --git 
a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index baf448e572..ad7d5eb3ff 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include + #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace data { @@ -71,7 +73,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { public: explicit Iterator(const Params& params) : DatasetIterator(params), - auto_tuner_(params.dataset->buffer_size_) {} + auto_tuner_(params.dataset->buffer_size_) { + std::vector components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } ~Iterator() override { // Signal the prefetch thread to terminate it. We will then @@ -98,6 +104,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { bool* end_of_sequence) override { { mutex_lock l(mu_); + auto stats_aggregator = ctx->stats_aggregator(); TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); // Wait until the next element in the buffer has been // produced, or we are shutting down. 
@@ -113,7 +120,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } if (!buffer_.empty()) { - return Consume(out_tensors, end_of_sequence); + return Consume(out_tensors, end_of_sequence, stats_aggregator); } if (prefetch_thread_finished_) { @@ -201,8 +208,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { std::vector value; }; - Status Consume(std::vector* out_tensors, bool* end_of_sequence) + Status Consume(std::vector* out_tensors, bool* end_of_sequence, + const std::shared_ptr& stats_aggregator) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (stats_aggregator) { + stats_aggregator->AddToHistogram( + strings::StrCat(prefix_end_, "::buffer_utilization"), + {static_cast(buffer_.size()) / + static_cast(auto_tuner_.buffer_limit())}); + } // A new element is available. Forward the status from computing it, and // (if we successfully got an element) the output values. Status s = buffer_.front().status; @@ -326,6 +340,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex parent_mu_ ACQUIRED_BEFORE(mu_); std::unique_ptr input_impl_ GUARDED_BY(parent_mu_); condition_variable cond_var_; + string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::unique_ptr prefetch_thread_ GUARDED_BY(mu_); -- cgit v1.2.3 From 8a752ecd583846aa5b3157c4d9c2c7c654beb6fb Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Mon, 10 Sep 2018 11:00:30 -0700 Subject: Update internal-only tags PiperOrigin-RevId: 212296477 --- tensorflow/contrib/lite/testing/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index aad1ecaeb6..3a6c16cafc 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -36,7 +36,7 @@ load( tags = [ "gen_zip_test", "no_oss", - "tflite_not_portable", + "tflite_not_portable_intentional", ], test_name = test_name, deps = [ -- cgit v1.2.3 From 
c5b14b334e89b9bcb0fd0199481318b8fdd65762 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 11:04:38 -0700 Subject: Bug fix: consult graph's op registry to look up ops. This is needed when the graph contains custom call ops. These functions are found only in the graph's registry and not the default one. PiperOrigin-RevId: 212297305 --- .../compiler/jit/mark_for_compilation_pass.cc | 2 +- .../compiler/jit/mark_for_compilation_pass_test.cc | 47 ++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 44caf0be52..e6cc6e52ae 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -443,7 +443,7 @@ Status FindCompilationCandidates( !registration->requires_compilation) { const OpDef* op_def; TF_RETURN_IF_ERROR( - OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // We need to be able to constant fold the nodes in // compile_time_const_nodes given constant inputs (required by XLA) and diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9473ac0a4c..c59770a4c8 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) { EXPECT_EQ(clusters["shape"], ""); } +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["fn_call"], ""); +} + } // namespace } // namespace tensorflow -- cgit v1.2.3 
From a8b2dd9f72fe78cca59d525230f5358430fec45c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 11:35:24 -0700 Subject: Fix unhelpful error message For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was. PiperOrigin-RevId: 212303323 --- tensorflow/python/framework/test_util.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4bece9e25e..d63abd7f01 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase): def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, - "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + # When the array rank is small, print its contents. Numpy array printing is + # implemented using inefficient recursion so prints can cause tests to + # time out. + if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): + shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " + "%s.") % (a.shape, b.shape, b) + else: + shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, + b.shape) + self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + if not np.allclose(a, b, rtol=rtol, atol=atol): # Prints more details than np.testing.assert_allclose. # -- cgit v1.2.3 From 96b77a647b1391d43cae869306628b479a22daa4 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Mon, 10 Sep 2018 11:37:05 -0700 Subject: [TF:XLA] Migrate unit tests to use the HLO verifier (only tests where the conversion is mostly automated). 
PiperOrigin-RevId: 212303594 --- tensorflow/compiler/xla/service/BUILD | 12 ++++++++++ .../service/bfloat16_conversion_folding_test.cc | 18 +++++++++------ .../xla/service/bfloat16_normalization_test.cc | 22 ++++++++++-------- tensorflow/compiler/xla/service/call_graph_test.cc | 26 +++++++++++----------- tensorflow/compiler/xla/service/cpu/BUILD | 4 ++++ .../xla/service/cpu/conv_canonicalization_test.cc | 8 +++---- .../xla/service/cpu/cpu_copy_insertion_test.cc | 8 +++---- .../service/cpu/cpu_hlo_support_checker_test.cc | 8 +++---- .../xla/service/cpu/shape_partition_test.cc | 8 +++---- tensorflow/compiler/xla/service/cpu/tests/BUILD | 1 + .../xla/service/cpu/tests/cpu_fusion_test.cc | 20 ++++++++--------- .../xla/service/flatten_call_graph_test.cc | 22 +++++++++--------- tensorflow/compiler/xla/service/gpu/BUILD | 3 +++ .../xla/service/gpu/gpu_hlo_schedule_test.cc | 4 ++-- .../service/gpu/gpu_hlo_support_checker_test.cc | 8 +++---- .../xla/service/gpu/stream_assignment_test.cc | 4 ++-- .../compiler/xla/service/heap_simulator_test.cc | 8 +++---- .../compiler/xla/service/hlo_reachability_test.cc | 4 ++-- .../xla/service/hlo_rematerialization_test.cc | 20 ++++++++--------- .../xla/service/hlo_tfgraph_builder_test.cc | 4 ++-- .../compiler/xla/service/tuple_simplifier_test.cc | 20 ++++++++--------- 21 files changed, 130 insertions(+), 102 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 6ace6d3271..1965ba1204 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -352,6 +354,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,6 +405,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -498,6 +502,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -568,6 +573,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1131,6 +1137,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1709,6 +1716,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2237,6 
+2245,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2315,6 +2324,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2428,6 +2438,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2888,6 +2899,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 6363a21c3b..5f93740887 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto 
module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 933cf873e0..cef0eba14e 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); @@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + 
EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, 
DoNotAddUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b74843..34f3f914d5 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. 
std::unique_ptr MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); 
EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. 
auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 039cbbff6c..8cc522a59e 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -801,6 +801,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -822,6 +823,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -946,6 +948,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -971,6 +974,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 05792795a1..2083f440fd 100644 --- 
a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446e..c9fb34be1c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 
+25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de6..be1208fb2d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 7d8e51f909..1a3d82de95 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index f11aff0573..c55206eee7 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,6 +48,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 22721051e5..6bf3810967 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,7 +75,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,7 +122,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. 
- auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, @@ -184,7 +184,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,7 +209,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, @@ -256,7 +256,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +315,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241e..5fbd73a536 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ 
TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // The true and false computations must now be different. 
EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 569381f5b0..af953a2a16 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -832,6 +833,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -901,6 +903,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 59ade96f7d..b857fa775a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c..27a4d0b601 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); 
module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8f0dedfa40..c4f43cc9a6 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,14 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr CreateNewModule() { HloModuleConfig config; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 00a25db467..957c4a6891 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -29,14 +29,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module.get()); + HloSchedule schedule(module); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -233,7 +233,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b..d9848cee0b 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 4b611fe450..f7e82fb1f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // with rematerialization so pick a memory limit between these values (14KB). 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, module.get())); + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, module.get())); + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // bit lower (17KB) to force rematerialization of the entry computation. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, module.get())); + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, module.get())); + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // ~12K so pick something slightly larger. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, module.get())); + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. 
@@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module.get())); + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module.get())); + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2..6fd734a2b9 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872d..516754e211 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = 
CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); -- cgit v1.2.3 From 
656b3e9c847c187ff011982fe806f9f48853ed1a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 12:08:32 -0700 Subject: Match the presubmit test machine setup in the Dockerfile. PiperOrigin-RevId: 212309247 --- .../Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 | 83 ++++++++++++++++++++++ tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu | 43 ----------- 2 files changed, 83 insertions(+), 43 deletions(-) create mode 100644 tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 delete mode 100644 tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 new file mode 100644 index 0000000000..a30858db82 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 @@ -0,0 +1,83 @@ +# To push a new version, run: +# $ docker build -f Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 \ +# --tag "gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04" . +# $ docker push gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04 +# +# TODO(klimek): Include clang in this image so we can also target clang +# builds. 
+ +FROM ubuntu:14.04 +LABEL maintainer="Manuel Klimek " + +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates apt-transport-https gnupg-curl && \ + rm -rf /var/lib/apt/lists/* && \ + NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \ + NVIDIA_GPGKEY_FPR=ae09fe4bbd223a84b2ccfce3f60f4b3d7fa2af80 && \ + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/7fa2af80.pub && \ + apt-key adv --export --no-emit-version -a $NVIDIA_GPGKEY_FPR | tail -n +2 > cudasign.pub && \ + echo "$NVIDIA_GPGKEY_SUM cudasign.pub" | sha256sum -c --strict - && rm cudasign.pub && \ + echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ + echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list + +ENV CUDA_VERSION 9.0.176 +ENV CUDA_PKG_VERSION 9-0=$CUDA_VERSION-1 +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0" +ENV NCCL_VERSION 2.2.13 +ENV CUDNN_VERSION 7.2.1.38 + +# TODO(b/110903506): /usr/local/cuda/lib64/stubs should not be needed in +# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to be used at runtime. The +# correct way to pass the path to bfd-ld is to pass +# -Wl,-rpath-link=/usr/local/cuda/lib64/stubs to all binaries transitively +# depending on libcuda. Optimally, builds targeting cuda would do that +# internally. 
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64/stubs + +LABEL com.nvidia.volumes.needed="nvidia_driver" +LABEL com.nvidia.cuda.version="${CUDA_VERSION}" +LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + cuda-cudart-$CUDA_PKG_VERSION \ + cuda-libraries-$CUDA_PKG_VERSION \ + cuda-cublas-9-0=9.0.176.4-1 \ + libnccl2=$NCCL_VERSION-1+cuda9.0 \ + cuda-libraries-dev-$CUDA_PKG_VERSION \ + cuda-nvml-dev-$CUDA_PKG_VERSION \ + cuda-minimal-build-$CUDA_PKG_VERSION \ + cuda-command-line-tools-$CUDA_PKG_VERSION \ + cuda-core-9-0=9.0.176.3-1 \ + cuda-cublas-dev-9-0=9.0.176.4-1 \ + libnccl-dev=$NCCL_VERSION-1+cuda9.0 \ + libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 \ + libcudnn7=$CUDNN_VERSION-1+cuda9.0 && \ + ln -s cuda-9.0 /usr/local/cuda && \ + apt-mark hold libnccl2 && \ + apt-mark hold libcudnn7 libcudnn7-dev && \ + rm -rf /var/lib/apt/lists/* + +RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \ + echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf + +# TODO(b/110903506): Provide a link to the SONAME of libcuda.so. +# https://github.com/NVIDIA/nvidia-docker/issues/775 +RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 + +# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find +# libnccl is resolved, delete this block. +RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \ + && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2 + +# Copy and run the install scripts. 
+COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa && \ + add-apt-repository -y ppa:george-edison55/cmake-3.x +RUN /install/install_deb_packages.sh +RUN /install/install_pip_packages.sh +RUN /install/install_golang.sh + diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu deleted file mode 100644 index 08dc026328..0000000000 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu +++ /dev/null @@ -1,43 +0,0 @@ -# To push a new version, run: -# $ docker build -f Dockerfile.rbe.gcc.gpu \ -# --tag "gcr.io/asci-toolchain/nosla-nvidia-gcc" . -# $ docker push gcr.io/asci-toolchain/nosla-nvidia-gcc -FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 - -LABEL maintainer="Manuel Klimek " - -# TODO(b/110903506): Fix the nvidia docker image by providing a link to the -# SONAME of libcuda.so. Alternatively, consider using gold or lld which do not -# run into the same problem - that will only work once the tensorflow build does -# not link to libcuda from generators anymore. -# https://github.com/NVIDIA/nvidia-docker/issues/775 -RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 - -# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find -# libnccl is resolved, delete this block. -RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \ - && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2 - -# TODO(b/110903506): Fix tensorflow to not require the use of LD_LIBRARY_PATH. -# The stubs/libcuda.so is not meant to used at runtime. The correct way to -# pass the path to bfd-ld is to pass -Wl,-rpath-link=/usr/local/cuda/lib64/stubs -# to all binaries transitively depending on libcuda. Optimally the tensorflow -# build would do that internally. -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs - -# Copy and run the install scripts. 
-COPY install/*.sh /install/ -ARG DEBIAN_FRONTEND=noninteractive -RUN /install/install_bootstrap_deb_packages.sh -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:george-edison55/cmake-3.x -RUN /install/install_deb_packages.sh -RUN /install/install_pip_packages.sh -RUN /install/install_golang.sh - -# Install nccl2. -RUN apt-get update && apt-get install -y \ - libnccl2 \ - libnccl-dev \ - && rm -rf /var/lib/apt-lists/* - -- cgit v1.2.3 From 470305c95c6b607e87ca476e5a109e5993f3cf6f Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Mon, 10 Sep 2018 15:24:22 -0400 Subject: Use random_seed for the process input --- tensorflow/contrib/tensor_forest/kernels/stats_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index f80a34ece6..fe2c91c104 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -246,7 +246,8 @@ class ProcessInputOp : public OpKernel { const Tensor& input_weights = context->input(7); const Tensor& leaf_ids_tensor = context->input(8); - std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + std::unique_ptr data_set( + new TensorDataSet(input_spec_, random_seed_)); data_set->set_input_tensors(input_data, sparse_input_indices, sparse_input_values, sparse_input_shape); -- cgit v1.2.3 From dd6d7c5c586b541b9d4793b7578feadd0c2da8f6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 12:33:49 -0700 Subject: Global de-std::unique_ptr cleanup for xla::Literal. 
PiperOrigin-RevId: 212313258 --- tensorflow/compiler/tf2xla/graph_compiler.cc | 2 +- .../compiler/tf2xla/kernels/index_ops_cpu.cc | 6 +- tensorflow/compiler/tf2xla/lib/util.cc | 26 +- tensorflow/compiler/tf2xla/literal_util_test.cc | 23 +- tensorflow/compiler/tf2xla/tf2xla_test.cc | 8 +- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 198 ++--- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 7 +- tensorflow/compiler/xla/client/client.cc | 12 +- tensorflow/compiler/xla/client/client.h | 10 +- tensorflow/compiler/xla/client/lib/testing.cc | 4 +- tensorflow/compiler/xla/client/local_client.cc | 20 +- tensorflow/compiler/xla/client/local_client.h | 10 +- tensorflow/compiler/xla/client/xla_builder.cc | 2 +- tensorflow/compiler/xla/client/xla_builder.h | 38 +- tensorflow/compiler/xla/literal.cc | 133 ++- tensorflow/compiler/xla/literal.h | 53 +- tensorflow/compiler/xla/literal_test.cc | 910 ++++++++++----------- tensorflow/compiler/xla/literal_util.cc | 273 +++---- tensorflow/compiler/xla/literal_util.h | 228 +++--- tensorflow/compiler/xla/packed_literal_reader.cc | 10 +- tensorflow/compiler/xla/packed_literal_reader.h | 3 +- .../xla/python/local_computation_builder.cc | 20 +- .../xla/python/local_computation_builder.h | 8 +- .../xla/python/local_computation_builder.i | 18 +- tensorflow/compiler/xla/python/numpy_bridge.cc | 7 +- tensorflow/compiler/xla/python/numpy_bridge.h | 2 +- tensorflow/compiler/xla/reference_util.cc | 28 +- tensorflow/compiler/xla/reference_util_test.cc | 50 +- tensorflow/compiler/xla/rpc/grpc_client_test.cc | 5 +- .../compiler/xla/service/algebraic_simplifier.cc | 19 +- .../xla/service/algebraic_simplifier_test.cc | 6 +- .../compiler/xla/service/batchnorm_expander.cc | 12 +- .../xla/service/bfloat16_propagation_test.cc | 4 +- .../compiler/xla/service/buffer_assignment_test.cc | 5 +- .../compiler/xla/service/buffer_liveness_test.cc | 14 +- .../service/convolution_feature_group_converter.cc | 4 +- .../xla/service/cpu/tests/cpu_fusion_test.cc | 15 +- 
.../xla/service/cpu/tests/cpu_infeed_test.cc | 66 +- .../xla/service/cpu/tests/cpu_noalias_test.cc | 3 +- .../xla/service/elemental_ir_emitter_test.cc | 6 +- .../xla/service/generic_transfer_manager.cc | 4 +- .../service/gpu/cudnn_convolution_rewriter_test.cc | 2 +- .../xla/service/gpu/pad_for_tensor_cores.cc | 5 +- .../compiler/xla/service/gpu/pad_insertion.cc | 16 +- .../xla/service/gpu/tests/gpu_copy_test.cc | 3 +- .../compiler/xla/service/gpu/tests/infeed_test.cc | 32 +- .../compiler/xla/service/hlo_constant_folding.cc | 4 +- .../xla/service/hlo_constant_folding_test.cc | 4 +- .../compiler/xla/service/hlo_creation_utils.cc | 11 +- .../xla/service/hlo_creation_utils_test.cc | 68 +- tensorflow/compiler/xla/service/hlo_cse_test.cc | 6 +- tensorflow/compiler/xla/service/hlo_evaluator.cc | 237 +++--- tensorflow/compiler/xla/service/hlo_evaluator.h | 57 +- .../compiler/xla/service/hlo_evaluator_test.cc | 484 +++++------ .../xla/service/hlo_evaluator_typed_visitor.h | 195 +++-- tensorflow/compiler/xla/service/hlo_instruction.cc | 4 +- tensorflow/compiler/xla/service/hlo_instruction.h | 3 +- .../compiler/xla/service/hlo_instructions.cc | 15 +- tensorflow/compiler/xla/service/hlo_instructions.h | 12 +- tensorflow/compiler/xla/service/hlo_parser.cc | 54 +- tensorflow/compiler/xla/service/hlo_runner.cc | 28 +- tensorflow/compiler/xla/service/hlo_runner.h | 25 +- .../compiler/xla/service/hlo_verifier_test.cc | 8 +- .../compiler/xla/service/indexed_array_analysis.cc | 6 +- .../compiler/xla/service/indexed_array_analysis.h | 14 +- tensorflow/compiler/xla/service/inliner_test.cc | 6 +- .../compiler/xla/service/interpreter/executable.cc | 15 +- .../compiler/xla/service/layout_assignment_test.cc | 2 +- tensorflow/compiler/xla/service/service.cc | 42 +- .../compiler/xla/service/transfer_manager.cc | 12 +- tensorflow/compiler/xla/service/transfer_manager.h | 8 +- .../xla/service/tuple_points_to_analysis_test.cc | 8 +- .../compiler/xla/service/while_loop_analysis.cc | 19 +- 
.../xla/tests/array_elementwise_ops_test.cc | 256 +++--- .../compiler/xla/tests/batch_normalization_test.cc | 128 ++- tensorflow/compiler/xla/tests/bfloat16_test.cc | 26 +- .../compiler/xla/tests/broadcast_simple_test.cc | 89 +- tensorflow/compiler/xla/tests/broadcast_test.cc | 53 +- tensorflow/compiler/xla/tests/call_test.cc | 19 +- .../xla/tests/check_execution_arity_test.cc | 14 +- .../compiler/xla/tests/client_library_test_base.cc | 71 +- .../compiler/xla/tests/client_library_test_base.h | 101 ++- tensorflow/compiler/xla/tests/client_test.cc | 29 +- .../compiler/xla/tests/compilation_cache_test.cc | 19 +- .../compiler/xla/tests/compute_constant_test.cc | 26 +- tensorflow/compiler/xla/tests/concat_test.cc | 20 +- tensorflow/compiler/xla/tests/conditional_test.cc | 64 +- tensorflow/compiler/xla/tests/constants_test.cc | 25 +- tensorflow/compiler/xla/tests/convert_test.cc | 40 +- .../tests/convolution_dimension_numbers_test.cc | 3 +- tensorflow/compiler/xla/tests/convolution_test.cc | 115 ++- .../xla/tests/convolution_variants_test.cc | 24 +- tensorflow/compiler/xla/tests/copy_test.cc | 60 +- .../compiler/xla/tests/cross_replica_sum_test.cc | 11 +- tensorflow/compiler/xla/tests/custom_call_test.cc | 12 +- .../compiler/xla/tests/deconstruct_tuple_test.cc | 41 +- .../compiler/xla/tests/dot_operation_test.cc | 69 +- tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 117 ++- .../compiler/xla/tests/execution_profile_test.cc | 2 +- .../tests/exhaustive_f32_elementwise_op_test.cc | 2 +- tensorflow/compiler/xla/tests/fusion_test.cc | 130 ++- .../compiler/xla/tests/gather_operation_test.cc | 161 ++-- tensorflow/compiler/xla/tests/hlo_test_base.cc | 23 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 12 +- tensorflow/compiler/xla/tests/literal_test_util.h | 30 +- .../compiler/xla/tests/literal_test_util_test.cc | 43 +- .../xla/tests/local_client_allocation_test.cc | 6 +- .../xla/tests/local_client_execute_test.cc | 253 +++--- 
.../compiler/xla/tests/local_client_test_base.cc | 2 +- .../compiler/xla/tests/local_client_test_base.h | 3 +- tensorflow/compiler/xla/tests/map_test.cc | 150 ++-- .../compiler/xla/tests/matrix_ops_simple_test.cc | 22 +- .../compiler/xla/tests/multioutput_fusion_test.cc | 87 +- .../tests/outfeed_in_nested_computation_test.cc | 30 +- tensorflow/compiler/xla/tests/pad_test.cc | 46 +- tensorflow/compiler/xla/tests/params_test.cc | 149 ++-- tensorflow/compiler/xla/tests/prng_test.cc | 62 +- tensorflow/compiler/xla/tests/reduce_hlo_test.cc | 2 +- .../compiler/xla/tests/reduce_precision_test.cc | 37 +- tensorflow/compiler/xla/tests/reduce_test.cc | 123 ++- .../compiler/xla/tests/reduce_window_test.cc | 184 ++--- tensorflow/compiler/xla/tests/replay_test.cc | 16 +- tensorflow/compiler/xla/tests/reshape_test.cc | 308 ++++--- tensorflow/compiler/xla/tests/reverse_test.cc | 14 +- .../xla/tests/round_trip_packed_literal_test.cc | 42 +- .../compiler/xla/tests/round_trip_transfer_test.cc | 51 +- .../compiler/xla/tests/scalar_computations_test.cc | 38 +- tensorflow/compiler/xla/tests/scatter_test.cc | 172 ++-- tensorflow/compiler/xla/tests/slice_test.cc | 16 +- tensorflow/compiler/xla/tests/test_utils.cc | 74 +- tensorflow/compiler/xla/tests/test_utils.h | 12 +- tensorflow/compiler/xla/tests/test_utils_test.cc | 16 +- tensorflow/compiler/xla/tests/token_hlo_test.cc | 20 +- .../compiler/xla/tests/transfer_manager_test.cc | 204 +++-- tensorflow/compiler/xla/tests/tuple_test.cc | 152 ++-- tensorflow/compiler/xla/tests/unary_op_test.cc | 18 +- tensorflow/compiler/xla/tests/while_test.cc | 66 +- .../compiler/xla/tests/xla_hlo_profile_test.cc | 4 +- tensorflow/compiler/xla/text_literal_reader.cc | 11 +- tensorflow/compiler/xla/text_literal_reader.h | 4 +- .../compiler/xla/text_literal_reader_test.cc | 17 +- .../compiler/xla/text_literal_writer_test.cc | 2 +- .../compiler/xla/tools/replay_computation.cc | 17 +- tensorflow/compiler/xrt/kernels/xrt_state_ops.h | 10 +- 
tensorflow/compiler/xrt/tests/raw_api_test.cc | 36 +- tensorflow/compiler/xrt/xrt_state.cc | 2 +- tensorflow/compiler/xrt/xrt_state.h | 2 +- 147 files changed, 3797 insertions(+), 4195 deletions(-) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index bc2e640559..82e9eef005 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -81,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11..3d81ae9eb8 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. 
args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c267848524..804671fbc7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::C64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, 
case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0(static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0( - static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 7dc16b5a46..ed452bceeb 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. { std::vector int64_values = {1, 2, 3}; - std::unique_ptr int64_values_literal = + xla::Literal int64_values_literal = xla::LiteralUtil::CreateR1(absl::Span(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); test::ExpectTensorEqual(host_tensor, test::AsTensor(int64_values)); } @@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) { // Repeat tests with int32. 
Tensor host_tensor; std::vector int32_values = {10, 11}; - std::unique_ptr int32_values_literal = + xla::Literal int32_values_literal = xla::LiteralUtil::CreateR1(absl::Span(int32_values)); EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); + LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok()); test::ExpectTensorEqual(host_tensor, test::AsTensor(int32_values)); EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) + LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor) .ok()); std::vector qint32_values = {10, 11}; test::ExpectTensorEqual(host_tensor, test::AsTensor(qint32_values)); EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) + LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor) .error_message()); } } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98..ab26d939cc 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. 
auto x_literal = xla::LiteralUtil::CreateR0(10); auto y_literal = xla::LiteralUtil::CreateR0(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 40ce9fb41c..70efa7781d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -208,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -264,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -408,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -443,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -672,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. 
- std::unique_ptr input_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr input_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr output_read = - xla::LiteralUtil::CreateR0(42); - std::unique_ptr output_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr output_grad1 = - xla::LiteralUtil::CreateR1({0, 1}); - std::unique_ptr output_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + 
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -866,29 +843,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a simple graph that reads and writes a variable. 
@@ -952,20 +924,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ReturnResourceHandle) { @@ -1069,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1138,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({4, 55, 1, -3}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. 
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index d1534e9a15..d10a504da0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -213,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 8818f81312..5dde5b432f 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr> Client::Transfer( - const GlobalData& data, const Shape* shape_with_layout) { +StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, return Status::OK(); } -StatusOr> Client::TransferFromOutfeed( +StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -162,7 +162,7 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> 
Client::ExecuteAndTransfer( +StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -177,8 +177,8 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::ComputeConstant( - const XlaComputation& computation, const Layout* output_layout) const { +StatusOr Client::ComputeConstant(const XlaComputation& computation, + const Layout* output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 7960b07868..6f4d33c469 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -96,8 +96,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr> Transfer( - const GlobalData& data, const Shape* shape_with_layout = nullptr); + StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -122,7 +122,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr> TransferFromOutfeed( + StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -132,7 +132,7 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). 
- StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options = nullptr, @@ -153,7 +153,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - StatusOr> ComputeConstant( + StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 6861521acc..25cc37edc4 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -76,7 +76,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr> literal_status = MakeFakeLiteral(shape); + StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. @@ -84,7 +84,7 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, tensorflow::error::UNIMPLEMENTED); return MakeFakeDataViaDeviceOrDie(shape, client); } - return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie(); + return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. 
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 4402ba8762..f96b6c9c26 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments( HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); + *hlo_snapshot->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments( Status LocalExecutable::RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); + *hlo_snapshot->mutable_result() = literal.ToProto(); return Status::OK(); } -StatusOr> LocalExecutable::LiteralFromShapedBuffer( +StatusOr LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, backend_->BorrowStream(shaped_buffer.device_ordinal())); @@ -277,7 +275,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr> LocalClient::ShapedBufferToLiteral( +StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal, literal); } -StatusOr> LocalClient::TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal) { +StatusOr 
LocalClient::TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, shape, literal.get())); + executor, shape, &literal)); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 56c3a3da02..feb2f8ec9d 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -84,8 +84,7 @@ class LocalExecutable { Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr> LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer); + StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -132,8 +131,7 @@ class LocalClient : public Client { // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. - StatusOr> ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. @@ -151,8 +149,8 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. - StatusOr> TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal); + StatusOr TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal); // Returns the device ordinal that corresponds to the given replica number. 
// diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 887b970661..4e1ff9e5c0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -738,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 58e8f4e7fa..833eafcf85 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -2112,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0(value)); + return ConstantLiteral(LiteralUtil::CreateR0(value)); } template XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template @@ -2129,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*LiteralUtil::CreateR2(values)); + return ConstantLiteral(LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, 
layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2189,12 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); } template XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template @@ -2207,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); + return ConstantLiteral(builder, 
LiteralUtil::CreateR2(values)); } template @@ -2221,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray(values)); + LiteralUtil::CreateFromArray(values)); } template @@ -2236,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D(values)); + LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2253,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3f7635bd40..f1f255efae 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = absl::make_unique(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { 
memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> -MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { +/* static */ StatusOr MutableLiteralBase::CreateFromProto( + const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = absl::make_unique(proto.shape()); + Literal literal(proto.shape()); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. 
Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = absl::make_unique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr> LiteralBase::Broadcast( +StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); @@ -598,14 +597,14 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = absl::make_unique(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. 
std::vector scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast(result->untyped_data()); + char* dest_data = static_cast(result.untyped_data()); const char* source_data = static_cast(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); @@ -627,37 +626,36 @@ StatusOr> LiteralBase::Broadcast( return std::move(result); } -StatusOr> LiteralBase::Reshape( +StatusOr LiteralBase::Reshape( absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", ShapeUtil::HumanString(shape()), - ShapeUtil::HumanString(output->shape())); + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr LiteralBase::Transpose( - absl::Span permutation) const { +Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -687,32 +685,31 @@ std::unique_ptr LiteralBase::Transpose( for 
(auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = absl::make_unique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template -std::unique_ptr LiteralBase::SliceInternal( +Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { - auto result_literal = absl::make_unique(result_shape); + Literal result_literal(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( + result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); - result_literal->Set(indices, value); + result_literal.Set(indices, value); }); return result_literal; } -std::unique_ptr LiteralBase::Slice( - absl::Span start_indices, - absl::Span limit_indices) const { +Literal LiteralBase::Slice(absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = absl::make_unique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); @@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString( namespace { template -std::unique_ptr 
ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); + auto dest_data = result_literal.template data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1208,8 +1199,7 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1217,7 +1207,7 @@ std::unique_ptr ConvertBetweenNativeTypes( template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); @@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. 
template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal->data(); + absl::Span dest_data = result_literal.data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1254,8 +1244,7 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1273,9 +1262,9 @@ std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, } template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1307,12 +1296,12 @@ StatusOr> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type)); } -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType 
primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1342,12 +1331,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> LiteralBase::Convert( +StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> LiteralBase::BitcastConvert( +StatusOr LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1362,8 +1351,8 @@ StatusOr> LiteralBase::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape, + bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && dest_shape.element_type() == BF16) { @@ -1381,11 +1370,9 @@ StatusOr> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = absl::make_unique(); - *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); - return std::move(converted); + return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index b928cb6374..fa5b5f7fab 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -223,25 +223,21 @@ class 
LiteralBase { // // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr ConvertToShape(const Shape& dest_shape, + bool round_f32_to_bf16 = false) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -259,24 +255,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. 
- std::unique_ptr Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr> Reshape( - absl::Span dimensions) const; + StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, absl::Span dimensions) const; + StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,7 +280,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose(absl::Span permutation) const; + Literal Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -293,15 +288,15 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice(absl::Span start_indices, - absl::Span limit_indices) const; + Literal Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. 
template - std::unique_ptr Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive @@ -312,7 +307,7 @@ class LiteralBase { // initialization, then reinitialization. Conside if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -539,8 +534,8 @@ class LiteralBase { private: template - std::unique_ptr SliceInternal( - const Shape& result_shape, absl::Span start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -687,8 +682,7 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static StatusOr CreateFromProto(const LiteralProto& proto); protected: // Returns the piece at the given ShapeIndex. 
@@ -1137,15 +1131,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) { } template -std::unique_ptr LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = absl::make_unique( - ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } @@ -1157,7 +1150,7 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { bool done = false; while (!done) { const auto element = Get(input_indices); - literal->Set(output_indices, element); + literal.Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 1a64594db8..ba7fd29a62 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test { Layout layout_r3_dim0minor_; Layout layout_r4_dim0major_; Layout layout_r4_dim0minor_; - std::unique_ptr literal_r4_2x2x3x3_dim0major_; - std::unique_ptr literal_r4_2x2x3x3_dim0minor_; + Literal literal_r4_2x2x3x3_dim0major_; + Literal literal_r4_2x2x3x3_dim0minor_; }; TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit.ToString()); auto s32_lit = 
LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); const string expected = 
R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -171,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - EXPECT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple.ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) { }; std::vector expected_values = {8, 9, 7, 10}; - EXPECT_EQ(literal->sparse_indices()->data(), + EXPECT_EQ(literal.sparse_indices()->data(), absl::Span(expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(literal->data(), absl::Span(expected_values)); + EXPECT_EQ(literal.data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { { /*i0=0*/ { /*i1=0*/ @@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = literal_r4_2x2x3x3_dim0major_->ToString(); + string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { { /*i0=0*/ { 
/*i1=0*/ @@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { }); // clang-format on std::vector> seen; - literal->EachCellAsString( + literal.EachCellAsString( [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) { auto f32_42 = LiteralUtil::CreateR0(42.0); auto f32_42_clone = LiteralUtil::CreateR0(42.0); - EXPECT_EQ(*f32_42, *f32_42); - EXPECT_EQ(*f32_42, *f32_42_clone); + EXPECT_EQ(f32_42, f32_42); + EXPECT_EQ(f32_42, f32_42_clone); auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*f32_42, *f32_123); + EXPECT_NE(f32_42, f32_123); auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_NE(*f32_42, *f64_42); + EXPECT_NE(f32_42, f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { @@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(*matrix, *matrix); - EXPECT_EQ(*matrix, *matrix_clone); - EXPECT_NE(*matrix, *matrix_different); - EXPECT_NE(*matrix, *vector_literal); - EXPECT_NE(*matrix, *scalar); - EXPECT_NE(*matrix, nil); + EXPECT_EQ(matrix, matrix); + EXPECT_EQ(matrix, matrix_clone); + EXPECT_NE(matrix, matrix_different); + EXPECT_NE(matrix, vector_literal); + EXPECT_NE(matrix, scalar); + EXPECT_NE(matrix, nil); EXPECT_EQ(nil, nil); } @@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) { auto token1 = LiteralUtil::CreateToken(); auto scalar = LiteralUtil::CreateR0(1.0); - EXPECT_EQ(*token0, *token1); - EXPECT_NE(*token0, *scalar); + EXPECT_EQ(token0, token1); + EXPECT_NE(token0, scalar); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), - *LiteralUtil::MakeTuple({token0.get()})); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({scalar.get(), 
token1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0}), + LiteralUtil::MakeTuple({&token0})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&token1, &scalar})); + EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&scalar, &token1})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); - colmajor->Set({0, 0}, 1.0); - colmajor->Set({0, 1}, 2.0); - colmajor->Set({1, 0}, 3.0); - colmajor->Set({1, 1}, 4.0); + Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + colmajor.Set({0, 0}, 1.0); + colmajor.Set({0, 1}, 2.0); + colmajor.Set({1, 0}, 3.0); + colmajor.Set({1, 1}, 4.0); - auto rowmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); - rowmajor->Set({0, 0}, 1.0); - rowmajor->Set({0, 1}, 2.0); - rowmajor->Set({1, 0}, 3.0); - rowmajor->Set({1, 1}, 4.0); + Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + rowmajor.Set({0, 0}, 1.0); + rowmajor.Set({0, 1}, 2.0); + rowmajor.Set({1, 0}, 3.0); + rowmajor.Set({1, 1}, 4.0); - EXPECT_EQ(*rowmajor, *colmajor); + EXPECT_EQ(rowmajor, colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_EQ(*tuple1, *tuple2); + auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix}); + EXPECT_EQ(tuple1, tuple2); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_NE(*tuple1, *reversed_tuple); + auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar}); + EXPECT_NE(tuple1, reversed_tuple); // Tuple with different value. auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_NE(*tuple1, *different_tuple); + auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix}); + EXPECT_NE(tuple1, different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { @@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) { // tuple, the other is a clone of the element in the original tuple. auto vector_clone = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - EXPECT_EQ(*vector, *vector_clone); + EXPECT_EQ(vector, vector_clone); auto vector_reversed = LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); - EXPECT_NE(*vector, *vector_reversed); + EXPECT_NE(vector, vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto tuple = LiteralUtil::MakeTuple({&element1, &element1}); // Tuples should always return false for IsAll. - EXPECT_FALSE(tuple->IsAll(0)); - EXPECT_FALSE(tuple->IsAll(1)); + EXPECT_FALSE(tuple.IsAll(0)); + EXPECT_FALSE(tuple.IsAll(1)); } // Verifies that CreateFromShape works for tuples. 
TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto scalar = LiteralUtil::CreateR0(0.0); auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_EQ(*tuple, *x); + auto x = Literal::CreateFromShape(tuple.shape()); + EXPECT_EQ(tuple, x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false).IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255).IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0).IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001).IsAll(42)); - EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100}).IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001}).IsAll(100)); - EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}}).IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}}).IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}}).IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, 
{b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) - ->IsAll(-1)); + .IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) - ->IsAllFloat(.5)); + .IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); 
+ EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. 
- EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE( - LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(scalar_zero->IsZero({})); - EXPECT_FALSE(scalar_one->IsZero({})); + EXPECT_TRUE(scalar_zero.IsZero({})); + EXPECT_FALSE(scalar_one.IsZero({})); auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(array->IsZero({0, 1})); - EXPECT_TRUE(array->IsZero({0, 2})); - EXPECT_TRUE(array->IsZero({1, 1})); - 
EXPECT_FALSE(array->IsZero({1, 2})); + EXPECT_FALSE(array.IsZero({0, 1})); + EXPECT_TRUE(array.IsZero({0, 2})); + EXPECT_TRUE(array.IsZero({1, 1})); + EXPECT_FALSE(array.IsZero({1, 2})); auto complex_zero = LiteralUtil::CreateR0(0.0f); auto complex_nonzero = LiteralUtil::CreateR0(0.5f); - EXPECT_TRUE(complex_zero->IsZero({})); - EXPECT_FALSE(complex_nonzero->IsZero({})); + EXPECT_TRUE(complex_zero.IsZero({})); + EXPECT_FALSE(complex_nonzero.IsZero({})); } template @@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = data->Relayout(layout01); - EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_EQ(*data, *data01); + auto data01 = data.Relayout(layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01)); + EXPECT_EQ(data, data01); - auto data10 = data->Relayout(layout10); - EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_EQ(*data, *data10); + auto data10 = data.Relayout(layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10)); + EXPECT_EQ(data, data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, 
ReshapeR4Dim0Minor) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Transpose(/*permutation=*/{}); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](absl::Span indices, float value) { - EXPECT_EQ(value, original->Get( + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get( {indices[2], indices[3], indices[0], indices[1]})); }); } @@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. 
auto dim0minor_relaid_to_dim0major = - literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); + literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = - literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); + literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->element_count(), 6); - EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_EQ(mat_dim0minor.element_count(), 6); + EXPECT_THAT(mat_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->data(), + auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_); + EXPECT_THAT(relaid_mat_to_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->element_count(), 6); - EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_EQ(mat_dim0major.element_count(), 6); + EXPECT_THAT(mat_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->data(), + auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_); + EXPECT_THAT(relaid_mat_to_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->element_count(), 12); + EXPECT_EQ(lit_dim0minor.element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->data(), + EXPECT_THAT(lit_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->data(), + EXPECT_THAT(relaid_lit_to_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->element_count(), 12); - EXPECT_THAT(lit_dim0major->data(), + EXPECT_EQ(lit_dim0major.element_count(), 12); + EXPECT_THAT(lit_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->data(), + auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_); + EXPECT_THAT(relaid_lit_to_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { auto input = LiteralUtil::CreateR0(1); - auto result = input->Slice({}, {}); - EXPECT_EQ(*input, *result); + auto result = input.Slice({}, {}); + EXPECT_EQ(input, result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = input->Slice({3}, {4}); + auto result = input.Slice({3}, {4}); auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR2U32) { auto input_3x4 = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto result = input_3x4.Slice({0, 2}, {2, 4}); auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_EQ(*input_2x3x2, *result); + auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_EQ(input_2x3x2, result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = LiteralUtil::CreateR1({77}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal 
output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { @@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { @@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { bfloat16 h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { @@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { bfloat16 h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { @@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { bfloat16 h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); auto expected = 
LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { @@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { output.PopulateWithValue({4, 2}); auto expected = LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = input->Replicate(3); + auto output = input.Replicate(3); auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_EQ(*output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, CopySliceFrom) { @@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](absl::Span indexes) { - source->Set(indexes, ++seqnr); + source.Set(indexes, ++seqnr); return true; }; - 
ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step, init_proc); auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); @@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = blank->Get(blank_indexes); - matched = (bval != 0 && bval == source->Get(source_indexes)); + auto bval = blank.Get(blank_indexes); + matched = (bval != 0 && bval == source.Get(source_indexes)); return matched; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); } @@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = LiteralUtil::CreateR0(0); auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(zero->CopyFrom(*nine)); - EXPECT_EQ(*zero, *nine); + TF_EXPECT_OK(zero.CopyFrom(nine)); + EXPECT_EQ(zero, nine); auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); - EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); - EXPECT_EQ(vect->Get({4}), 17); + TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {})); + EXPECT_EQ(zero.Get({}), 17); + TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {})); + EXPECT_EQ(vect.Get({4}), 17); } TEST_F(LiteralUtilTest, 
CopyFromAndToZeroElement) { @@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); - EXPECT_EQ(*nine, *const_nine); + TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0})); + EXPECT_EQ(nine, const_nine); } { // Copy 0 element to destination with zero elements. - const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); - EXPECT_EQ(*empty, *const_empty); + TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0})); + EXPECT_EQ(empty, const_empty); } } @@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { TEST_F(LiteralUtilTest, CopyFromArrays) { auto scalar_42 = LiteralUtil::CreateR0(42.0); auto scalar_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*scalar_42, *scalar_123); - TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*scalar_42, *scalar_123); - EXPECT_EQ(scalar_42->Get({}), 123.0f); + EXPECT_NE(scalar_42, scalar_123); + TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(scalar_42, scalar_123); + EXPECT_EQ(scalar_42.Get({}), 123.0f); auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); - EXPECT_NE(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); - TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); + EXPECT_NE(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{}, + 
/*src_shape_index=*/{})); + EXPECT_EQ(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 5.0f); } TEST_F(LiteralUtilTest, CopyFromTuples) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {matrix.get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get()}); + Literal inner_elements[] = {LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0})}; + Literal inner_tuple = LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}); + Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(-5).get(), - LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); + Literal int32_minus5 = LiteralUtil::CreateR0(-5); + Literal double_2_4 = LiteralUtil::CreateR1({2.0, 4.0}); + Literal tuple = + LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal}); - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 44.0); // Overwrite the inner tuple element of nested_tuple with the contents of // 'tuple'. - TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{})); + TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); // The matrix element should be unchanged. 
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), - LiteralUtil::CreateR0(4).get()}); + Literal elements[] = {LiteralUtil::CreateR0(-2), + LiteralUtil::CreateR0(4)}; + Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), 4); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), 4); // Copy from one element to the other. - TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{0})); + TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), -2); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), -2); } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto vector = LiteralUtil::CreateR1({5.0, 7.0}); - Status status = matrix->CopyFrom(*vector); + Status status = matrix.CopyFrom(vector); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); @@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - Literal* l1 = m1.get(); - const char* d1 
= reinterpret_cast(l1->data().data()); + Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + const char* d1 = reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l2 = m2.get(); - const char* d2 = reinterpret_cast(l2->data().data()); + const char* d2 = reinterpret_cast(m2.data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->Populate(generator)); + TF_EXPECT_OK(literal.Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = 
[&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->PopulateParallel(generator)); + TF_EXPECT_OK(literal.PopulateParallel(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32)); - EXPECT_EQ(*expected, *converted); + EXPECT_EQ(expected, converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); // clang-format on - std::unique_ptr conv; + Literal conv; - conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u32); + conv = s8.Convert(U32).ConsumeValueOrDie(); + EXPECT_EQ(conv, u32); - conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = s8.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u64); + conv = s8.Convert(U64).ConsumeValueOrDie(); + EXPECT_EQ(conv, u64); - conv = 
s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s64); + conv = s8.Convert(S64).ConsumeValueOrDie(); + EXPECT_EQ(conv, s64); - conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *pred); + conv = s8.Convert(PRED).ConsumeValueOrDie(); + EXPECT_EQ(conv, pred); - conv = bf16->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = bf16.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = bf16->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = bf16.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *int32_pred); + conv = pred.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, int32_pred); - conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f32.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f64.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = s32.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f64.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = s32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = u32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = s32.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - conv = f16->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = 
f16.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(S32).status().code(), + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) { tensorflow::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->BitcastConvert(F32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); - Status status = literal->BitcastConvert(F64).status(); + Status status = literal.BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE( absl::StrContains(status.error_message(), "bit widths are different")); @@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { p.add_preds((i % 2) == (len % 2)); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - ASSERT_EQ(len, literal->data().size()); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal.data().size()); int i = 0; - for (bool value : literal->data()) { + 
for (bool value : literal.data()) { EXPECT_EQ((i % 2) == (len % 2), value); ++i; } @@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h2(2.0f); auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l = m.get(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->data().size()); + EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape())); + EXPECT_EQ(4, m.data().size()); - LiteralProto p = l->ToProto(); + LiteralProto p = m.ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); @@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - auto r = literal->data(); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); ASSERT_EQ(4, r.size()); EXPECT_EQ(h1, r[0]); EXPECT_EQ(h2, r[1]); @@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); - EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); - EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(scalar, {}), scalar); + EXPECT_EQ(LiteralSlice(matrix, {}), matrix); + EXPECT_EQ(LiteralSlice(tuple, {}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple); EXPECT_EQ(LiteralSlice(nil, {}), nil); - 
EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(tuple, {0}), scalar); + EXPECT_EQ(LiteralSlice(tuple, {1}), matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix); + EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar); } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. 
- const auto nested_tuple_view = LiteralSlice(*nested_tuple); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 1.0f); + const auto nested_tuple_view = LiteralSlice(nested_tuple); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); - nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 555.0f); + nested_tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 555.0f); @@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); - const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { } TEST_F(LiteralUtilTest, LiteralMove) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - Literal literal(std::move(*matrix)); + Literal 
matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(matrix)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get(), - &nil_literal}); - - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); - std::vector elements = nested_tuple->DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + Literal inner_elements[] = { + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0}), + }; + Literal tuple_elements[] = { + LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}), + }; + Literal nested_tuple = LiteralUtil::MakeTuple( + {&tuple_elements[0], &tuple_elements[1], &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + std::vector elements = nested_tuple.DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); - elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); - elements.push_back(std::move(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get()}) - - )); + elements.push_back(LiteralUtil::CreateR0(1.0)); + elements.push_back(LiteralUtil::CreateR1({4, 8})); + std::vector inner_elements; + inner_elements.push_back(LiteralUtil::CreateR0(42)); + inner_elements.push_back(LiteralUtil::CreateR1({23.0, 44.0})); + 
elements.push_back( + LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { Literal literal; EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - literal = std::move(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(matrix); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { } TEST_F(LiteralUtilTest, LiteralSliceCopy) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralSlice(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralSlice(matrix); LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); @@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42.0).get(), - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); - tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); - - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), - 3.0); - tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + Literal elements[] = { + LiteralUtil::CreateR0(42.0), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + }; + auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, 
/*shape_index=*/{0}), 42.0); + tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0); + tuple.Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), -4.0); } TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { // Literals constructed using CreateFromShape should be zero initialized. - std::unique_ptr scalar_f32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); - EXPECT_EQ(scalar_f32->Get({}), 0.0); - EXPECT_TRUE(scalar_f32->IsAll(0)); - - std::unique_ptr vector_s32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); - EXPECT_EQ(vector_s32->Get({0}), 0); - EXPECT_EQ(vector_s32->Get({1}), 0); - EXPECT_EQ(vector_s32->Get({2}), 0); - EXPECT_TRUE(vector_s32->IsAll(0)); - - std::unique_ptr tuple = - Literal::CreateFromShape(ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); - - EXPECT_EQ(tuple->Get({}, {0}), 0.0); - EXPECT_EQ(tuple->Get({0}, {1}), false); - EXPECT_EQ(tuple->Get({1}, {1}), false); - EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); + Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32.Get({}), 0.0); + EXPECT_TRUE(scalar_f32.IsAll(0)); + + Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32.Get({0}), 0); + EXPECT_EQ(vector_s32.Get({1}), 0); + EXPECT_EQ(vector_s32.Get({2}), 0); + EXPECT_TRUE(vector_s32.IsAll(0)); + + Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + 
+ EXPECT_EQ(tuple.Get({}, {0}), 0.0); + EXPECT_EQ(tuple.Get({0}, {1}), false); + EXPECT_EQ(tuple.Get({1}, {1}), false); + EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1665,25 +1648,25 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( - {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + {&one_f32, &vector_half, &matrix_pred, &matrix_pred}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + auto nested_tuple = + LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { - return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + return Literal::CreateFromProto(literal.ToProto()).ValueOrDie(); }; - EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); - EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); - EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); - EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); - EXPECT_EQ(*tuple, to_from_proto(*tuple)); - EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); - EXPECT_NE(*one_f32, *two_f32); - EXPECT_NE(*one_f32, to_from_proto(*two_f32)); + EXPECT_NE(one_f32, two_f32); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { @@ -1802,11 
+1785,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { TEST_F(LiteralUtilTest, SortSparseElements) { auto literal = LiteralUtil::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); - literal->AppendSparseElement({2, 3, 4}, 2.0); - literal->AppendSparseElement({3, 4, 5}, 3.0); - literal->AppendSparseElement({1, 2, 3}, 1.0); - literal->SortSparseElements(); - EXPECT_EQ(literal->ToString(false), + literal.AppendSparseElement({2, 3, 4}, 2.0); + literal.AppendSparseElement({3, 4, 5}, 3.0); + literal.AppendSparseElement({1, 2, 3}, 1.0); + literal.SortSparseElements(); + EXPECT_EQ(literal.ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1816,57 +1799,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), "false"); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(int64{2})); EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(double{2.0})); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(static_cast(half{2.0}))); EXPECT_EQ(LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{0})); - 
EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = LiteralUtil::CreateR0(9); + Literal literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), - /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 613449cf10..0cb1ae35f4 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -45,7 +45,7 @@ using absl::StrCat; // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. 
template -std::unique_ptr ConvertType(LiteralSlice literal) { +Literal ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -56,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. @@ -67,14 +67,14 @@ std::unique_ptr ConvertType(LiteralSlice literal) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); - auto dest = result->data(shape_index); + auto dest = result.data(shape_index); for (int64 i = 0; i < src.size(); ++i) { dest[i] = static_cast(src[i]); } } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); + TF_CHECK_OK(result.CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); } } }); @@ -83,53 +83,52 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( +/* static */ Literal LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( +/* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( +/* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -/* static */ std::unique_ptr LiteralUtil::CreateToken() { - return absl::make_unique(ShapeUtil::MakeTokenShape()); +/* static */ Literal LiteralUtil::CreateToken() { + 
return Literal(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case C64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -145,30 +144,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S32: - return std::move(*LiteralUtil::CreateR0(1)); + return 
LiteralUtil::CreateR0(1); case S64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case C64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -184,42 +182,36 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case F32: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + 
-std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -232,40 +224,34 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case F32: - return std::move(*LiteralUtil::CreateR0( - 
std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -275,31 +261,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ Literal LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = absl::make_unique( + Literal literal( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( - absl::string_view value) { - auto literal = absl::make_unique( - ShapeUtil::MakeShape(U8, {static_cast(value.size())})); +/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { + Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { - literal->Set({i}, value[i]); + literal.Set({i}, value[i]); } return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { auto value = 
MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( +/* static */ Literal LiteralUtil::ReshapeSlice( absl::Span new_dimensions, absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; @@ -309,13 +293,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = absl::make_unique( + Literal new_literal( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); + Shape shape_with_layout = new_literal.shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Copy data into new literal, element-by-element. 
@@ -326,40 +310,40 @@ std::unique_ptr ConvertType(LiteralSlice literal) { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " @@ -376,97 +360,82 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 8 bit types. 
case S8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 16 bit types. case BF16: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 32 bit types. case F32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 64 bit types. 
case C64: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); } } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ Literal LiteralUtil::MakeTuple( absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( +/* static */ Literal LiteralUtil::MakeTupleFromSlices( absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( - std::vector> elements) { +/* static */ Literal LiteralUtil::MakeTupleOwned( 
+ std::vector elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_shapes.push_back(element->shape()); + element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( - literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2d6084a67a..2b181621ed 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -69,36 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. template - static std::unique_ptr CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template - static std::unique_ptr CreateR1(absl::Span values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template - static std::unique_ptr CreateR2( + static Literal CreateR2( std::initializer_list> values); template - static std::unique_ptr CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list> values, const Layout& layout); template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list>> + values); template - static std::unique_ptr CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list>> values, const Layout& layout); template - static std::unique_ptr CreateR4( + static 
Literal CreateR4( std::initializer_list>>> values); template - static std::unique_ptr CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list>>> values, @@ -139,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template - static std::unique_ptr CreateSparse( - absl::Span dimensions, SparseIndexArray indices, - absl::Span values, bool sort = true); + static Literal CreateSparse(absl::Span dimensions, + SparseIndexArray indices, + absl::Span values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -155,130 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithDescendingLayout( + static Literal CreateFullWithDescendingLayout( absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. 
template - static std::unique_ptr CreateFromArray(const Array& values); + static Literal CreateFromArray(const Array& values); template - static std::unique_ptr CreateFromArrayWithLayout( - const Array& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array& values, + const Layout& layout); template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values); + static Literal CreateR2FromArray2D(const Array2D& values); template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values); + static Literal CreateR3FromArray3D(const Array3D& values); template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values); + static Literal CreateR4FromArray4D(const Array4D& values); template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(absl::string_view value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". 
template - static std::unique_ptr CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template - static std::unique_ptr CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template - static std::unique_ptr MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr MakeTuple( - absl::Span elements); + static Literal MakeTuple(absl::Span elements); - static std::unique_ptr MakeTupleFromSlices( - absl::Span elements); + static Literal MakeTupleFromSlices(absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector> elements = ...; + // std::vector elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. - static std::unique_ptr MakeTupleOwned( - std::vector> elements); + static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of unique_ptrs to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // - // Simply relying on the MakeTupleOwned(std::vector>) + // Simply relying on the MakeTupleOwned(std::vector) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr. + // The arguments to this function must all be Literal. 
template - static std::unique_ptr MakeTupleOwned( - std::unique_ptr... elements) { - std::array, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array arr{std::move(elements)...}; + std::vector v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, absl::Span dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. 
- static std::unique_ptr ReshapeSlice( - absl::Span new_dimensions, - absl::Span minor_to_major, const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span new_dimensions, + absl::Span minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -286,7 +275,7 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( + static StatusOr CreateRandomLiteral( const Shape& shape, const std::function)>& generator); @@ -297,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -307,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. 
@@ -322,44 +311,43 @@ class LiteralUtil { std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = absl::make_unique(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( - absl::Span values) { - auto literal = absl::make_unique( +/* static */ Literal LiteralUtil::CreateR1(absl::Span values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -384,14 +372,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ Literal LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, 
LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -422,23 +410,22 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateSparse( +/* static */ Literal LiteralUtil::CreateSparse( absl::Span dimensions, SparseIndexArray indices, absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal.PopulateSparse(indices, values, sort); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -446,50 +433,48 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArray( +/* static */ Literal LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr 
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -514,7 +499,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -542,21 +527,20 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. 
template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -565,33 +549,29 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout(absl::Span dimensions, - NativeT value) { - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); - literal->PopulateWithValue(value); +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal.PopulateWithValue(value); return literal; } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = absl::make_unique(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate( [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr 
LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index f9473d372b..0f86f9f35e 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr> PackedLiteralReader::Read( - const Shape& shape, const Layout* layout) { +StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ShortDebugString()); @@ -57,11 +57,11 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type())); } - auto result = absl::make_unique(literal_shape); - result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + Literal result(literal_shape); + result.PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - absl::Span field = result->data(); + absl::Span field = result.data(); char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); absl::string_view sp; diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 98dccaa9a2..d6d2ff1521 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -41,8 +41,7 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. 
- StatusOr> Read(const Shape& shape, - const Layout* layout = nullptr); + StatusOr Read(const Shape& shape, const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. all available // packed literals have been read and we're at the end of the file. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cd6e20b693..9da5dc0d2d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, return client->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr> TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number) { +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; LocalClient* client = GetOrCreateLocalClient(); @@ -141,9 +141,8 @@ StatusOr LocalShapedBuffer::FromLiteral( LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, relaid); } return ToBuffer(client, /*device_ordinal=*/0, argument); }(); @@ -151,7 +150,7 @@ StatusOr LocalShapedBuffer::FromLiteral( return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -StatusOr> LocalShapedBuffer::ToLiteral() const { +StatusOr LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} 
-StatusOr> CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( const std::vector& arguments, const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); @@ -169,7 +168,7 @@ StatusOr> CompiledLocalComputation::Execute( // Each replica populates a StatusOr result, but only replica zero actually // retrieves its literal value. - std::vector>> results(GetReplicaCount()); + std::vector> results(GetReplicaCount()); { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", GetReplicaCount()); @@ -198,9 +197,8 @@ StatusOr> CompiledLocalComputation::Execute( StatusOr pushed; if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, relaid); } else { pushed = ToBuffer(client, device_ordinal, argument); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 78b3c598b9..1d5dfe5911 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); // Transfers a literal of the given shape from the outfeed of the given replica. // // The replica number is resolved to an appropriate device ordinal. -StatusOr > TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number); +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number); // Wraps a ScopedShapedBuffer produced by copying a literal "to // device," i.e. 
copying a literal to a scoped buffer via the local @@ -65,7 +65,7 @@ class LocalShapedBuffer { LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - StatusOr > ToLiteral() const; + StatusOr ToLiteral() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -117,7 +117,7 @@ class CompiledLocalComputation { // with optionally-specified argument layouts. The literals will be // re-laid out according to the corresponding elements of // shapes_with_layout. - StatusOr > Execute( + StatusOr Execute( const std::vector& arguments, const std::vector >& shapes_with_layout); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 450d3fe5af..521490e76c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -216,9 +216,9 @@ tensorflow::ImportNumpy(); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if ($1.ok()) { - std::unique_ptr value = $1.ConsumeValueOrDie(); + Literal value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -346,25 +346,25 @@ tensorflow::ImportNumpy(); // Literal -%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { +%typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); SWIG_fail; } - $1 = literal_status.ValueOrDie().get(); + $1 = &literal_status.ValueOrDie(); } -%typemap(out) std::unique_ptr { +%typemap(out) Literal { $result = numpy::PyObjectFromXlaLiteral(*$1); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if (!$1.ok()) { 
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); + $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); } %typemap(in) const std::vector& (std::vector temps) { @@ -375,13 +375,13 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); SWIG_fail; } - temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + temps.push_back(literal_status.ConsumeValueOrDie()); Py_DECREF(o); } $1 = &temps; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index fc6511bef5..b0aa024c74 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } } -StatusOr> XlaLiteralFromPyObject(PyObject* o) { +StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); - std::vector> elements; + std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); @@ -389,8 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR( - CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( diff --git 
a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 8cae175185..40ff2d9ad2 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // To avoid transferring ownership of the data buffers that underlie // PyArrays and XLA literals, this function makes deep copies of all // array data. -StatusOr > XlaLiteralFromPyObject(PyObject* o); +StatusOr XlaLiteralFromPyObject(PyObject* o); // The following functions copy array data from the buffers underlying Numpy // ndarrays into those underlying XLA literals, and vice versa. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 9f1afa2671..05325367f5 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -529,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -546,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + 
rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -556,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -565,7 +565,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( *window.add_dimensions() = dim2; const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal->shape(), rhs_literal->shape(), + lhs_literal.shape(), rhs_literal.shape(), /*feature_group_count=*/1, window, dnums) .ConsumeValueOrDie(); @@ -585,18 +585,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr result_literal = + Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - absl::make_unique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); result->Each([&](absl::Span indices, float* value) { - *value = result_literal->Get(indices); + *value = result_literal.Get(indices); }); return result; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 3ec0192148..a1b0f4045f 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc 
+++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, actual_literal, ErrorSpec(0.0001)); } @@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, actual_literal, ErrorSpec(0.0001)); } @@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); - LiteralTestUtil::ExpectR1Equal({0}, *result); + LiteralTestUtil::ExpectR1Equal({0}, result); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = 
ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, ErrorSpec(0.0001)); } @@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + 
LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray3D) { @@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal, ErrorSpec(0.0001)); } @@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, - *actual_literal, ErrorSpec(0.0001)); + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray4D) { @@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray4D) { @@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { @@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + 
LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 43fd8fe1bd..84fe5b17d1 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { std::vector expected = { 1.85840735, 
-1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal expected_literal = LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal, ErrorSpec(0.0001))); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 3d18fe3be2..2a0823aeca 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -527,7 +527,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -546,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. 
if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = absl::make_unique( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -676,7 +676,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -1469,7 +1469,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -1572,7 +1572,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1607,7 +1607,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - 
LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2062,7 +2062,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. @@ -2223,8 +2223,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction::CreateBroadcast( convolution->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), + LiteralUtil::Zero(convolution->shape().element_type()))), {})); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a0db4563fb..3fc1ba2427 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2932,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = 
module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec281ae68f..30d33e0d35 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - 
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 388fd5df99..e032b5c624 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -163,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index c30abd1d3e..795beb9ff5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1245,9 +1245,10 
@@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 414bfe7999..17e5090505 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. 
auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0826380f65..0ac4a65ec6 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); auto new_filter = add( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc 
b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 6bf3810967..1deb412064 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -78,13 +78,13 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -125,8 +125,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. 
- LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, result, error_spec_); } TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { @@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -213,7 +212,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. 
auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c661..5cc6d01c0f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 
5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. 
- LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({3, 4}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({5, 6}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8}), + LiteralUtil::CreateR0(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). 
sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8, 9}), + LiteralUtil::CreateR0(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({4, 5, 6}), + LiteralUtil::CreateR0(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index bb105194f1..7af51db55a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); diff --git 
a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 1b3be199f6..852f34e06d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -56,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 4ed91ef187..bec02e14f9 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. 
relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index bda8ebe579..d237f8930b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -590,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); ParseAndVerifyModule(absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index fa84d77223..b0061fa655 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { - // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. 
static constexpr int64 kDesiredNumFeaturesFactor = 8; @@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9d85d746d8..2a6415d0b6 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. 
HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 4550f36fdf..780539c164 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317..f8120a5fa0 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - 
TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + 
TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D(array).get(), - LiteralUtil::CreateR0(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D(array), + LiteralUtil::CreateR0(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 8a45939c61..f837816cea 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,10 +76,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } - std::unique_ptr result = evaluator->TryEvaluate(instruction); + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. 
- if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 07cd1efc12..4da42844bd 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { root->literal().EachCell( [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get(rindexes)); + matched = matched && (value == literal_clone.Get(rindexes)); }); EXPECT_TRUE(matched); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a3fcc0fefa..b76c50bb5b 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -321,18 +321,17 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + 
computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, absl::Span broadcast_dimensions) { - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index eb6affadc8..e07a196d11 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,10 @@ 
TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, 
+ LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR0(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index e09d5868f2..9b18b0284f 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -73,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -105,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -135,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index a2f683b690..064b86493d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -54,9 +54,8 @@ namespace xla { namespace { template -StatusOr> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -94,9 +93,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto 
result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -105,9 +104,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr> Compare( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -125,9 +124,9 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -193,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); @@ -206,11 +205,21 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(module, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -224,11 +233,21 @@ StatusOr> 
HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(computation, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); @@ -247,18 +266,27 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate(instruction, arg_literal_ptrs); } -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction) { +StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -274,21 +302,22 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - 
return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr> HloEvaluator::EvaluateWithSubstitutions( +StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions) { @@ -299,7 +328,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -316,12 +345,12 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -331,10 +360,10 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + 
HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -343,14 +372,14 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr> HloEvaluator::EvaluateDotOp( +StatusOr HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -824,7 +853,7 @@ class OutputOffsetIndexToInputIndex { // there is one) to `reshaped_start_indices`. 
static StatusOr> ReshapedGatherIndices( int64 index_vector_dim, const Literal& start_indices, - std::unique_ptr* reshaped_start_indices) { + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -834,16 +863,16 @@ static StatusOr> ReshapedGatherIndices( new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_start_indices, start_indices.Reshape(new_shape)); - return std::cref(**reshaped_start_indices); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_start_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), @@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + result.CopyElementFrom(operand, input_index, output_index)); return true; }; @@ -977,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = absl::make_unique( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + 
/*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -1004,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1036,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator .Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1056,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr result; + Literal result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -1081,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. 
if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1097,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1108,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. - auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); @@ -1118,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {lcv.get()})); - VLOG(3) << "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); 
cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1139,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. namespace { template -StatusOr> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) @@ -1179,57 +1206,55 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = absl::make_unique(keys_literal.shape()); - result_keys_literal->PopulateR1(absl::Span(result_keys)); - auto result_values_literal = - absl::make_unique(values_literal.shape()); - result_values_literal->PopulateR1( + Literal result_keys_literal(keys_literal.shape()); + result_keys_literal.PopulateR1(absl::Span(result_keys)); + Literal result_values_literal(values_literal.shape()); + result_values_literal.PopulateR1( absl::Span(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; - std::unique_ptr result_tuple; + Literal result_tuple; if (rank == 1) { auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); + result_tuple = + LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. 
- auto keys_result_literal = absl::make_unique(keys_literal.shape()); - auto values_result_literal = - absl::make_unique(values_literal.shape()); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); + .Reshape({r1_length})); TF_ASSIGN_OR_RETURN(auto values_r1_slice, values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + .Reshape({r1_length})); + auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); + r1_result_pair.first.Reshape({1, r1_length})); TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + r1_result_pair.second.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values, {0, 0}, {row, 0}, {1, r1_length})); } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template -StatusOr> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const 
Literal& values_literal) { +StatusOr EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { case F32: return EvaluateSortInternal(sort, keys_literal, @@ -1248,9 +1273,9 @@ StatusOr> EvaluateSortCurried( } } -StatusOr> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried(sort, keys_literal, values_literal); @@ -1319,28 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - const HloModule& module, - absl::Span> arg_literals); - -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( + +template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span> arg_literals); + absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - HloInstruction* instruction, - absl::Span> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 72252bafc7..21e676d671 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals 
correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloModule& module, absl::Span arg_literals); + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloComputation& computation, - absl::Span arg_literals); + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -82,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - HloInstruction* instruction, absl::Span arg_literals); + StatusOr Evaluate(HloInstruction* instruction, + absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. 
all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr> Evaluate(HloInstruction* instruction); + StatusOr Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr> EvaluateWithSubstitutions( + StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions); - StatusOr> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, const Literal& lhs, - const Literal& rhs); + StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -197,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. 
@@ -205,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use FlatMap any more. + std::unordered_map evaluated_; private: template - static StatusOr> ElementWiseUnaryOpImpl( + static StatusOr ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { @@ -227,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape())); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 102ebb24ab..16411eb078 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -56,8 +56,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, evaluator_ = absl::make_unique(); } - std::unique_ptr Evaluate( - absl::Span arg_literals = {}) { + Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
auto type_converter = HloElementTypeConverter(F32, BF16); @@ -69,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr lhs, - std::unique_ptr rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -117,7 
+114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -126,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -138,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -147,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -161,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto 
on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -172,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -295,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -313,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. 
@@ -327,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -337,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - result->EachCell( - [&](absl::Span indices, NativeT value) { - std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); - }); + result.EachCell([&](absl::Span indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -356,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -374,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. 
b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -398,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -420,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + 
b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -495,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -522,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, 
r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -535,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -566,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -577,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -611,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -650,7 +646,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -662,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format 
on auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -696,11 +692,11 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -740,7 +736,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -794,12 +790,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -849,7 +845,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ 
-862,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -933,7 +929,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -943,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1011,7 +1007,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1021,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1071,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1085,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1135,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1150,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1207,7 +1203,7 @@ TEST_P(HloEvaluatorTest, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1223,7 +1219,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } 
TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { @@ -1261,14 +1257,14 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), -7); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), -31); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); @@ -1278,13 +1274,13 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 1, 8); expected_array.FillWithYX( Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1317,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + 
LiteralTestUtil::ExpectR0Equal(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1396,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1448,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1505,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1516,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. 
std::vector input_dims(6, 4); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1566,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1598,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1632,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1668,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, 
DynamicSliceUpdate) { @@ -1705,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1741,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1780,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1820,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1842,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + 
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1858,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, ¶m0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1883,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. 
HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1906,12 +1901,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1930,12 +1925,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1954,14 +1949,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - 
LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3( + LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), start_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1980,15 +1974,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -2008,15 +2001,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -2035,12 +2027,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); 
- EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -2059,13 +2050,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2084,11 +2074,10 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2108,12 +2097,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr start_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - 
*Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2138,15 +2127,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2171,15 +2158,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2205,15 +2191,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 
6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2239,15 +2223,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { @@ -2273,17 +2255,15 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); EXPECT_TRUE(LiteralTestUtil::Near( - 
*LiteralUtil::CreateR2( + LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), - ErrorSpec{0.1, 0.01})); + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2309,15 +2289,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2343,15 +2321,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); } 
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2376,21 +2353,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + Literal expected = LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, @@ -2416,21 +2390,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + Literal expected = LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2455,16 +2426,14 @@ ENTRY main { } )"; 
ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2489,17 +2458,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2524,13 +2490,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); 
EXPECT_TRUE(LiteralTestUtil::Equal( - *operand, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + operand, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2557,16 +2521,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - std::unique_ptr expected = - LiteralUtil::CreateR1({10, 61, 32}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1({10, 61, 32}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2603,11 +2564,10 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr arg = LiteralUtil::CreateR1( + Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); - std::unique_ptr expected = - LiteralUtil::CreateR0(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); + Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 63303aef1e..7f090a52db 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -246,15 +246,14 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { parent_->evaluated_[convert] = std::move(result); } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); + parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); } return Status::OK(); } @@ -262,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { parent_->evaluated_[convert] = std::move(result); } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); + parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); } return Status::OK(); } @@ -978,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - 
result->Populate([&](absl::Span out_index) { + result.Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1157,8 +1155,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1231,9 +1229,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = absl::make_unique(dot->shape()); + Literal result(dot->shape()); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span result_index) { + result.Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1280,8 +1278,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = absl::make_unique(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate( [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = @@ -1289,7 +1287,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. 
@@ -1311,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set(target_index, - evaluated_operand.Get(input_index)); + result.Set(target_index, + evaluated_operand.Get(input_index)); return true; }; @@ -1439,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> MapImpl(HloInstruction* map) { + StatusOr MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { - std::vector> arg_literals; + result.Populate([&](absl::Span multi_index) { + std::vector arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1463,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. 
embedded_evaluator.ResetVisitStates(); - return computed_result->Get({}); + return computed_result.Get({}); })); return std::move(result); } @@ -1557,9 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = absl::make_unique(keys_literal.shape()); - result_literal->PopulateR1(absl::Span(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + Literal result_literal(keys_literal.shape()); + result_literal.PopulateR1(absl::Span(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); return result_literal; }; @@ -1568,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto result_literal = absl::make_unique(keys_literal.shape()); + Literal result_literal(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + .Reshape({r1_length})); + auto r1_result = sort_r1(r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + r1_result, {0, 0}, {row, 0}, {1, r1_length})); } parent_->evaluated_[sort] = std::move(result_literal); } @@ -1651,9 +1647,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector, 1> results(num_args); + absl::InlinedVector results(num_args); for (int64 i = 0; i < num_args; ++i) { - 
results[i] = absl::make_unique(result_shape); + results[i] = Literal(result_shape); } Status eval_status; @@ -1667,7 +1663,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate( + TF_RETURN_IF_ERROR(results[input].Populate( [&](absl::Span multi_index) { if (!eval_status.ok()) { return init_scalars[input]; @@ -1703,8 +1699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Evaluate computation with specified literal operands. - absl::InlinedVector, 1> - embedded_operands; + absl::InlinedVector embedded_operands; for (ReturnT value : result_values) { embedded_operands.push_back( LiteralUtil::CreateR0(value)); @@ -1717,11 +1712,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_operands.size()); std::transform(embedded_operands.begin(), embedded_operands.end(), embedded_operands_ptrs.begin(), - [](const std::unique_ptr& ptr) { - return ptr.get(); - }); + [](Literal& literal) { return &literal; }); - TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, + TF_ASSIGN_OR_RETURN(Literal computed_result, embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on @@ -1729,10 +1722,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. 
if (!has_tuple_output) { - result_values[0] = computed_result->Get({}); + result_values[0] = computed_result.Get({}); } else { for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get( + result_values[i] = computed_result.Get( /*multi_index=*/{}, /*shape_index=*/{i}); } } @@ -1748,9 +1741,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!has_tuple_output) { parent_->evaluated_[reduce] = std::move(results[0]); } else { - auto tuple_result = absl::make_unique(reduce->shape()); + Literal tuple_result(reduce->shape()); for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); } parent_->evaluated_[reduce] = std::move(tuple_result); } @@ -1781,10 +1774,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = absl::make_unique(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. 
- TF_RETURN_IF_ERROR(result->Populate( + TF_RETURN_IF_ERROR(result.Populate( [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; @@ -1834,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); + bool selected = !computed_result.Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1856,16 +1848,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr computed_result = + auto scattered = result.Get(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); + result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again // on the same computation. 
embedded_evaluator.ResetVisitStates(); @@ -1916,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span output_index) { + result.Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1935,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(curr_val); const auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = + Literal computed_result = embedded_evaluator .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get({}); + result_val = computed_result.Get({}); }); return result_val; @@ -1961,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. 
StatusOr> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1970,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -2230,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2260,7 +2251,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. 
- std::unique_ptr result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = [&](absl::Span update_window_index, @@ -2299,19 +2290,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result_value_literal = - LiteralUtil::CreateR0(result->Get(input_index)); + LiteralUtil::CreateR0(result.Get(input_index)); auto update_value_literal = LiteralUtil::CreateR0(updates.Get(update_index)); - std::unique_ptr updated_result = + Literal updated_result = embedded_evaluator .Evaluate( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. embedded_evaluator.ResetVisitStates(); - result->Set(input_index, updated_result->Get({})); + result.Set(input_index, updated_result.Get({})); return true; }; @@ -2361,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); + TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2575,7 +2566,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (ShapeUtil::Rank(iota->shape()) > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], - result->Broadcast(iota->shape(), {iota->iota_dimension()})); + result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); parent_->evaluated_[iota] = std::move(result); @@ -2645,9 +2636,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr 
DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2660,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2676,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2689,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), 
std::plus()); - result->Set(result_index, - update_literal.Get(update_index)); + result.Set(result_index, + update_literal.Get(update_index)); return true; }; @@ -2710,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr> ElementWiseUnaryOp( + StatusOr ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -2723,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr> ElementWiseBinaryOp( + StatusOr ElementWiseBinaryOp( HloInstruction* instruction, const std::function& binary_op) { @@ -2745,10 +2736,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2757,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> ElementwiseTernaryOp( + StatusOr ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -2782,10 +2773,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git 
a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f06c98f2e7..85fa3ce964 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -250,7 +250,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -527,7 +527,7 @@ StatusOr> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr HloInstruction::CreateConstant( - std::unique_ptr literal) { + Literal literal) { return absl::make_unique(std::move(literal)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index bf25157395..4f6cac1396 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -359,8 +359,7 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr CreateConstant( - std::unique_ptr literal); + static std::unique_ptr CreateConstant(Literal literal); // Creates an Iota instruction. 
static std::unique_ptr CreateIota(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index fb7345a2ad..e92882c22a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -845,8 +845,8 @@ std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -854,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; @@ -876,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -893,7 +893,8 @@ std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -901,7 +902,7 @@ string 
HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple @@ -936,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index c3a7801164..2d7bc83855 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; @@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + absl::optional literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. 
- std::unique_ptr literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c54360b063..11caa89c54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -105,16 +105,13 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template - bool ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. 
@@ -577,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -1810,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1821,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1830,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1857,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1867,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return ParseDenseLiteral(literal, shape); } 
-bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1962,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1973,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1984,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1996,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -2041,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, } template -bool 
HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2121,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2129,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 66ac1f66fd..fa7f216321 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -118,16 +118,16 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const absl::Span> literals) { + const absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr> HloRunner::TransferLiteralFromDevice( +StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,7 +135,7 @@ StatusOr> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr> 
HloRunner::Execute( +StatusOr HloRunner::Execute( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -150,15 +150,15 @@ StatusOr> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes, ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -204,7 +204,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr>> HloRunner::ExecuteReplicated( +StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -290,9 +290,9 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -310,10 +310,10 @@ StatusOr>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector> exec_results; + std::vector exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + 
TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 76d8b92bed..2e934bf66a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. If // nullptr, the values will be read and discarded. - std::vector>* outfeed_values = nullptr; + std::vector* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -106,24 +106,23 @@ class HloRunner { StatusOr> TransferLiteralsToDevice( const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const absl::Span> literals); - StatusOr> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span literals); + StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - StatusOr> Execute( - std::unique_ptr module, - const absl::Span arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr> Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. 
@@ -140,7 +139,7 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr>> ExecuteReplicated( + StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0cac210c24..8f0423bb1c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); auto module = CreateNewModule(); @@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 37b774b8a5..06f0e1ed25 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -918,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in 
// BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -928,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 9746d176cc..df9cbab915 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -347,21 +347,19 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr TakeOwnership( - StatusOr> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - std::move(literal_or_error)); + StatusOr TakeOwnership(StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector> owned_tensors_; - std::vector> owned_literals_; + std::vector owned_literals_; tensorflow::gtl::FlatMap cache_; }; diff --git 
a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 5695bc2420..93a74dbfa6 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 5dea124768..a06d6113e8 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -73,30 +73,29 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. 
- std::vector> arg_literals; + std::vector arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 69c7e42601..f8baba03c3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); diff --git 
a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index f0e2566a3f..922ebdf0e3 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span arguments, module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -959,9 +958,9 @@ std::unique_ptr CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - 
TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), *literal)); - *result->mutable_literal() = literal->ToProto(); + executor, arg->shape_with_layout(), literal)); + *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModule::CreateFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( + *module, 
/*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index b8d2d546e5..a21e586efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr> TransferManager::TransferLiteralFromDevice( +StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr> ret; + StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); @@ -63,7 +63,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferLiteralFromDevice( @@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr> TransferManager::TransferArrayFromDevice( +StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { - StatusOr> ret; + StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
@@ -122,7 +122,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferArrayToDevice( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 21725946b3..f952e64af2 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -57,7 +57,7 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. - virtual StatusOr> TransferLiteralFromDevice( + virtual StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, @@ -113,9 +113,9 @@ class TransferManager { Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 2b2a2eb42a..e9a07b14ed 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. 
auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index c3c2603c7e..541b117e02 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -183,8 +183,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -197,31 +196,27 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); // The initial value of the induction variable. 
- std::unique_ptr indvar_iter_val = - std::move(indvar_init_result).ValueOrDie(); + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); for (int64 trip_count = 0; trip_count != max_value_returned + 1; ++trip_count) { auto* while_cond = while_op->while_condition(); auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + StatusOr result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == absl::Span{false}) { + if (result.ValueOrDie().data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr> indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, - {{while_body_indvar, indvar_iter_val.get()}}); + StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0bf4556b43..c257566fb2 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -41,7 +41,6 @@ limitations under the License. 
namespace xla { namespace { - class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; @@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{1, 0x7FFFFFFFFFFFFFFLL, @@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Add(lhs_param, rhs_param); @@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{-1, 0, @@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - 
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Sub(lhs_param, rhs_param); @@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { XlaBuilder b(TestName()); std::vector lhs{static_cast(0x8000000000000000ULL)}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); Lt(lhs_param, rhs_param); - ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); + ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); } TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { @@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + Literal a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); - auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); + auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); - std::unique_ptr b_literal = 
LiteralUtil::CreateR1({b_values}); + Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); auto b_param = ConstantR1(&builder, b_values); auto sum1 = Add(a_constant, b_constant); @@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); - auto param = Parameter(&b, 0, param_literal->shape(), "param"); + auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } @@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + 
client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); @@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); @@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 
= Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); @@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Div(param0, Exp(param1)); std::vector expected(values0.size()); @@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + 
client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + client_->TransferToServer(literal2).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); @@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = 
Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); @@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); @@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + 
client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); + Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = - client_->TransferToServer(*literal3).ConsumeValueOrDie(); + client_->TransferToServer(literal3).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); - auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); @@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = 
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, @@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); @@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); - auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p = Parameter(&builder, 0, param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, @@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( @@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - 
expected_result.push_back(std::exp(input_literal->Get({i}))); + expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::log(input_literal->Get({i}))); + expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + 
{LiteralUtil::CreateR2({{true, true}, {true, false}}), + LiteralUtil::CreateR2({{true, false}, {false, false}})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { @@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = ConstantLiteral(&builder, *a_literal); + Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); auto y_literal = LiteralUtil::CreateR1({4, 5}); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); - auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); + auto y = Parameter(&builder, 1, y_literal.shape(), "y"); auto slice = Slice(x, {1}, {2}, {1}); Sub(slice, y); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index ac90a3adb6..bc2ba151a3 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - 
input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); + input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { @@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { @@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( - 
{LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), - LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 0.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR3FromArray3D( - {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), - LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}), + LiteralUtil::CreateR1(std::vector(1, 15.0f)), + LiteralUtil::CreateR1(std::vector(1, 125.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - LiteralUtil::CreateR1({0, 0}).get(), - LiteralUtil::CreateR1({16, 20}).get()}); + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}), + LiteralUtil::CreateR1({0, 0}), + LiteralUtil::CreateR1({16, 20})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } struct BatchNormTestParam { 
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); - auto expected = LiteralUtil::MakeTuple( - {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), - LiteralUtil::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_normalized, LiteralUtil::CreateR1(mean), + LiteralUtil::CreateR1(var)}); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); @@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { // testcase. 
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, *expected, + &builder, expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } @@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean"); auto variance_activations = - Parameter(&builder, 4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal.shape(), "variance"); Array4D expected = normalized; std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr variance_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); BatchNormInference(input_activations, scale_activations, 
offset_activations, mean_activations, variance_activations, epsilon, @@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = LiteralUtil::CreateR4FromArray4D(grad_output_array); - auto input_parameter = - Parameter(&builder, 0, input_literal->shape(), "input"); - auto scale_parameter = - Parameter(&builder, 1, scale_literal->shape(), "scale"); - auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); - auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); + auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input"); + auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance"); auto grad_output_parameter = - Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal.shape(), "grad_output"); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr var_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); std::unique_ptr grad_output_data = - client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + client_->TransferToServer(grad_output_literal).ConsumeValueOrDie(); BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, grad_output_parameter, epsilon, feature_index); - auto expected = - 
LiteralUtil::MakeTuple({expected_grad_activation.get(), - LiteralUtil::CreateR1(grad_scale).get(), - LiteralUtil::CreateR1(grad_offset).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_grad_activation, LiteralUtil::CreateR1(grad_scale), + LiteralUtil::CreateR1(grad_offset)}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 65589b0d6a..e9728e636f 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, - {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) - .get(), + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}), LiteralUtil::CreateR1( - {static_cast(4), static_cast(5)}) - .get(), + {static_cast(4), static_cast(5)}), LiteralUtil::CreateR1( - {static_cast(5), static_cast(5)}) - .get()}); + {static_cast(5), static_cast(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { 
BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, - {{static_cast(3.f)}, {static_cast(3.f)}}}}) - .get(), + {{static_cast(3.f)}, {static_cast(3.f)}}}}), LiteralUtil::CreateR1( - {static_cast(0), static_cast(0)}) - .get(), + {static_cast(0), static_cast(0)}), LiteralUtil::CreateR1( - {static_cast(16), static_cast(20)}) - .get()}); + {static_cast(16), static_cast(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index fe4267c73b..dde19fb65d 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = - client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + client_->TransferToServer(r3_data).ConsumeValueOrDie(); return r3_global_data; } @@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = 
LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = - client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + client_->TransferToServer(r2_data).ConsumeValueOrDie(); return r2_global_data; } @@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R3ImplicitBroadcastSpec { @@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, - {r3_implicit_global_data.get(), r3_global_data.get()}, + &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, ErrorSpec(1e-7, 1e-7)); } @@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 
Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, 
ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R2ImplicitBroadcastSpec { @@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, + &builder, expected, {r2_implicit_global_data1.get(), r2_global_data.get(), r2_implicit_global_data2.get()}, ErrorSpec(1e-6, 1e-6)); @@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, 
LiteralUtil::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = 
ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); auto expected = LiteralUtil::CreateR3( {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { @@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { @@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { @@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, 
{4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 74d4d2eb10..9966e4606e 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, error_spec_)); } @@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(*result, {0}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + LiteralSlice(result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(*result, {1}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + LiteralSlice(result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), 
result, error_spec_)); } @@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, error_spec_)); } @@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_)); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), + result, 
error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index b1d18210ea..8b31e53707 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = - ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); 
@@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); - auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({1.0f, 2.0f})); auto y = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); + client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } @@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); - Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); + auto tuple = LiteralUtil::MakeTuple({&elem}); + Call(&builder, callee, {ConstantLiteral(&builder, elem)}); - ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); + ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f)); } } // namespace diff --git 
a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index a4eb57fc7b..2f1510ff69 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); auto param0_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto param1_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); @@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto computation = computation_status.ConsumeValueOrDie(); auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie(); auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = - client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + client_->TransferToServer(f32_4_literal).ConsumeValueOrDie(); auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie(); // Match auto status = client_->Execute( diff --git 
a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 8a236db0ff..fbdf0fcb65 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -101,7 +101,7 @@ StatusOr> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -113,7 +113,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. 
@@ -121,8 +121,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr> -ClientLibraryTestBase::ExecuteAndTransferReference( +StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, absl::Span arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const string& error_message)>& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. 
std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, + verify_output(actual, absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); @@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -256,7 +255,7 @@ Status 
ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } @@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. 
const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } @@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void 
ClientLibraryTestBase::ComputeAndCompare( @@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr, std::unique_ptr>> +StatusOr> ClientLibraryTestBase::ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's @@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? 
LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr @@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 22dfdfb0e4..9d32f4f517 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> Execute( XlaBuilder* builder, absl::Span arguments); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test { // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected values of the // computation. - StatusOr> ExecuteAndTransferReference( + StatusOr ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*LiteralUtil::CreateFromArray(argument), builder); + return AddParam(LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. 
When the @@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), + return CreateConstantFromLiteral(LiteralUtil::CreateR0(value), builder); } @@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - absl::Span arguments); + StatusOr> ComputeValueAndReference( + XlaBuilder* builder, absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. 
@@ -412,9 +411,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -438,9 +435,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, absl::Span expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -464,9 +459,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, absl::Span 
arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -490,9 +485,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -516,9 +511,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + 
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -542,13 +537,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR0(value); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -556,13 +551,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR1(values); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - 
client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -570,13 +565,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -584,13 +579,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, 
parameter_number, literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf48..6f2ca84bb6 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); 
EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 03d5696499..6ef7ca035f 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase { absl::Span arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - std::unique_ptr result = + Literal result = client_ ->ExecuteAndTransfer(computation, arguments, /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); + LiteralUtil::CreateR0(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase { ->Execute(computation, arguments, &execution_options_, 
&execution_profile) .ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data_handle).ConsumeValueOrDie(); + Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); + LiteralUtil::CreateR2(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = - client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(rowmaj_array).ConsumeValueOrDie(); auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = - client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc 
index 8226b6de3f..3b0414a604 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr> ComputeConstantLiteral( - Client* client, const XlaOp& operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + StatusOr ComputeConstantLiteral(Client* client, const XlaOp& operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test { XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder, nullptr)); - return literal->Get({}); + return literal.Get({}); } bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { @@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR1({4, 6}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR0(5); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); - std::unique_ptr 
expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index be017477d8..9811a015e9 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); auto x_literal = LiteralUtil::CreateR0(2.f); auto y_literal = LiteralUtil::CreateR0(3.f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, f32_scalar, "x"); @@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = 
client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "z"); auto bcast = Broadcast(y, {5}); @@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "y"); auto y_bcast = Broadcast(y, {1, 5, 7}); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 25d10ab00a..32cac499c7 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), - LiteralUtil::CreateR0(25.0f).get()}), + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), + LiteralUtil::CreateR0(25.0f)}), {pred_arg.get()}, error_spec_); } @@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { 
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, CreateR1TupleFloorComputation()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), - LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({13.0f, 16.0f}), + LiteralUtil::CreateR1({26.0f, 30.0f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(true).get(), - LiteralUtil::CreateR0(12.2f).get(), - LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(12.2f), + LiteralUtil::CreateR1({12.8f, 14.6f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. 
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(46.6f).get(), - LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) - .get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), - LiteralUtil::CreateR0(9.3f).get()}) - .get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f})}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)})}), {pred_arg.get()}, error_spec_); } @@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), - LiteralUtil::CreateR0(b).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), {x_arg.get(), y_arg.get()}, error_spec_); }; @@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is true case. std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(true))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(true)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); @@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is false case. 
std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(false))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(false)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 4937574831..72ff1e74a4 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D( Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *input_literal); + ConstantLiteral(&builder, input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. 
TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})})); - std::unique_ptr result = - ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); + Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + LiteralSlice(result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. 
Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 7a203d6873..5f063e6784 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = 
Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); @@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr 
arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 38b6da4fa9..fd98bf29b8 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index e0a1538850..070b092d18 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + 
{LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + 
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, 
@@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + client_->TransferToServer(filter_r5).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r5, + ComputeAndCompareLiteral(&builder, expected_r5, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = 
LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(16029), static_cast(16218), static_cast(16407), @@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { static_cast(18369), static_cast(18576), static_cast(18783), static_cast(19620), static_cast(19836), static_cast(20052), static_cast(20925), static_cast(21150), static_cast(21375)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 
1, 15}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(5076), static_cast(5160), static_cast(5244), static_cast(5328), static_cast(6164), static_cast(6264), static_cast(6364), static_cast(6464), static_cast(7380), static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + 
ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(param0)), - std::move(*LiteralUtil::CreateFromArray(param1))}, + {LiteralUtil::CreateFromArray(param0), + LiteralUtil::CreateFromArray(param1)}, error_spec_); } @@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = - expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + client_->TransferToServer(input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - 
{std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } @@ -891,9 +890,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { Array4D filter_data(1, 1, 1, 2); filter_data.FillIota(10); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}); + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}); } XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { @@ -928,8 +926,7 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, /*feature_group_count=*/64); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data))}, + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 6784c16715..ba3e9c436e 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = - weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = ConstantLiteral(&builder, *weights_literal); + 
weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = ConstantLiteral(&builder, weights_literal); auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = Rev(weights, {2, 3, 4}); ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { @@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = - activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = ConstantLiteral(&builder, *activations_literal); + activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); + auto activations = ConstantLiteral(&builder, activations_literal); auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = ConvGeneralDilated(activations, gradients, @@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - 
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 526626c1dd..1407e68d9a 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase { protected: void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + auto constant = + builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); + Literal result = ExecuteAndTransfer(std::move(module), {}); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); + TestCopyOp(LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); + TestCopyOp(LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 
2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { @@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = LiteralUtil::CreateR0(42.0); - Shape shape = literal->shape(); + Shape shape = literal.shape(); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {literal.get()}); - LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {&literal}); + LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // 
Reverse the minor-to-major order of the literal. - Layout* literal_layout = - literal->mutable_shape_do_not_use()->mutable_layout(); + Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); @@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. - LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, error_spec_); } @@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + Literal literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR3EqualArray3D(a, *result); + LiteralTestUtil::ExpectR3EqualArray3D(a, result); } void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, @@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, HloComputation::Builder builder(TestName()); - std::unique_ptr literal = 
LiteralUtil::CreateR4FromArray4D(a); + Literal literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR4EqualArray4D(a, *result); + LiteralTestUtil::ExpectR4EqualArray4D(a, result); } XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { @@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { XlaBuilder builder(TestName()); Parameter(&builder, 0, in_shape, "input"); - auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index d12a4e7fcd..410732c07b 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 2, 3}); - EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); + EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); } XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { @@ -68,9 +68,8 @@ 
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ( - *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0, &literal1})); } // On the GPU backend, constants get special handling. Someone might pass a @@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 6f7fc0e6e5..a693fa3595 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { @@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = 
ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, @@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest, module->AddEntryComputation(b.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D( - Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index eb15fc0593..e0f23b0fa8 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, 
*literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); handles1[0].reset(); handles1[1].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have 
been deallocated because of reference counting. global_data.reset(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 5873516442..0171f51583 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -68,16 +68,16 @@ 
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), + LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = - 
LiteralUtil::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( - param.dot_lhs_row_major))); + Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = - client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie(); std::unique_ptr> dot_rhs_data = MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); - std::unique_ptr dot_rhs_lit = + Literal dot_rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = - client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie(); std::unique_ptr> addend_data; - std::unique_ptr addend_lit; + Literal addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { @@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() { addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); - addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie(); } XlaBuilder builder(TestName()); @@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + 
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, 
{{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) @@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) .ConsumeValueOrDie(); @@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { auto lhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + 
LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 9bf3767ca3..7501c6d957 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie(); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*LiteralUtil::CreateR0(input_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(input_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_value = - std::move(*LiteralUtil::CreateR0(update_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(update_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_value = - std::move(*LiteralUtil::CreateR0(expected_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(expected_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, absl::Span expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR1(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { - std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << literal->ToString(); + Literal literal = LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal.ToString(); } }; @@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = ConstantLiteral(&builder, *input_literal); + auto input = ConstantLiteral(&builder, input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); @@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) { auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), *start_indices_literal, buffer)); + stream.get(), start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 5116e60ca6..b08ece0e63 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc 
index bf1de02ba9..738f2600d4 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 7cb2f0cedf..9c94acb437 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -117,9 +117,9 @@ class FusionTest : public HloTestBase { auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4))); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } } @@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.5}, {2.72}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. 
@@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { // Every element of result should be y = x^2 = 4.0. for (int i = 0; i < rand_dim0_size; ++i) { for (int j = 0; j < dim1_size; ++j) { - EXPECT_EQ(4.0, result->Get({i, j})); + EXPECT_EQ(4.0, result.Get({i, j})); } } } @@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(5), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } 
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{7}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - 
*LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 2, 1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-3, -2, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { 
HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-2, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { @@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } 
std::unique_ptr MakeReduceTestComputation() { @@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(-15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({8}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -782,19 +782,17 @@ ENTRY main { } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, - test_runner_.Execute(std::move(module), {operand.get()}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + test_runner_.Execute(std::move(module), {&operand}, + /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - *result)); + LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + result)); } class FusionClientLibraryTest : public ClientLibraryTestBase {}; @@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // where overflow is OK. Array2D arr(32, 32); arr.FillUnique(); - std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({0, 1})); - std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({1, 0})); - XlaOp p0 = AddParam(*l1, &b); + XlaOp p0 = AddParam(l1, &b); XlaOp sum = p0; for (int i = 1; i < kNumParams; ++i) { - auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + auto pN = AddParam((i % 2 == 0 ? 
l1 : l2), &b); sum = sum + p0 * pN * pN; } @@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) { auto param0_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = - client->LiteralToShapedBuffer(*param0_literal, device_ordinal) + client->LiteralToShapedBuffer(param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = - client->LiteralToShapedBuffer(*param1_literal, device_ordinal) + client->LiteralToShapedBuffer(param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = - client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + client->LiteralToShapedBuffer(param2_literal, device_ordinal) .ConsumeValueOrDie(); // Build executable. diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 6d63498044..daa89398a6 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -58,10 +58,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -79,10 +79,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = 
LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -100,11 +100,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -122,11 +121,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -144,11 +143,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -166,13 +165,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, 
&operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -190,13 +188,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -214,10 +211,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -235,11 +232,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -257,9 +253,9 @@ ENTRY main { slice_sizes={1, 0} } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) 
{ @@ -281,11 +277,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -307,11 +303,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -333,11 +329,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -359,11 +355,11 @@ ENTRY main { ROOT result = u32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, 
&operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -381,10 +377,10 @@ ENTRY main { slice_sizes={1,3,2} } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -402,9 +398,9 @@ ENTRY main { slice_sizes={1} } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -422,10 +418,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -446,10 +442,10 @@ ENTRY main { ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -470,11 +466,10 @@ ENTRY main { ROOT 
result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -495,11 +490,11 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -520,13 +515,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, @@ -548,13 +542,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + 
RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -575,10 +568,10 @@ ENTRY main { ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -599,11 +592,10 @@ ENTRY main { ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr operand_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); + client_->TransferToServer(LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::vector> result_data, client_->ExecuteParallel(computation_instances)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + 
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, - *result_literal); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 3df99aac7d..bdd4fd7e3d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -136,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { return debug_options; } -StatusOr> HloTestBase::Execute( - std::unique_ptr module, absl::Span arguments) { +StatusOr HloTestBase::Execute(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } -std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) .ValueOrDie(); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -188,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( TF_ASSIGN_OR_RETURN(auto reference, reference_runner_.Execute(std::move(reference_module), arguments, run_hlo_passes)); - return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, error); } @@ -223,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( ::testing::AssertionResult HloTestBase::RunAndCompare( std::unique_ptr module, const optional& error, const 
std::function& reference_preprocessor) { - const auto& fake_arguments = - MakeFakeArguments(module.get()).ConsumeValueOrDie(); + auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompare(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -243,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -277,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return test_runner_ .Execute(std::move(module_or_status.ValueOrDie()), fake_argument_ptrs, /*run_hlo_passes=*/true) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 21d77c0cc4..0ae4bdc104 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -115,16 +115,16 @@ class HloTestBase : public ::testing::Test { } // Executes the given module and return the result as a Literal. 
- StatusOr> Execute( - std::unique_ptr module, absl::Span arguments); + StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. - std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); // Executes the given hlo module on two backends and compares results. // diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 96f72212f3..43cca91f64 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -155,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( absl::Span expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2(expected), actual)); } template @@ -176,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const 
LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2(expected), actual, error)); } template @@ -223,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3(expected), actual, error)); } template @@ -232,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), 
actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 4151bfae03..b6f9b8156b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,11 +31,11 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. 
The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal lhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + Literal rhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(64), + LiteralUtil::CreateR0(42), }); - CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal"; }; ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); } @@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto two = LiteralUtil::CreateR0(2); auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); - CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four"; }; tensorflow::Env* env = tensorflow::Env::Default(); @@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal->ToString()); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal->ToString()); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal->ToString()); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -103,8 
+103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto expected = LiteralUtil::CreateR1({1, 2, 3}); auto actual = LiteralUtil::CreateR1({4, 5, 6}); - ::testing::AssertionResult result = - LiteralTestUtil::Equal(*expected, *actual); + ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); EXPECT_THAT(result.message(), @@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { @@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) { {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { @@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); - EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 237a4a361e..dbdd20daf0 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc 
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { DefaultExecutableBuildOptions(), options); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_); // At least one allocation should have been performed when executing the // computation. @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. 
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 1a823cf189..a99b43f469 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(123.f, ShapedBufferToLiteral(result), error_spec_); } @@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(165.f, ShapedBufferToLiteral(result), error_spec_); } @@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR1Near({}, ShapedBufferToLiteral(result), error_spec_); } @@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, 
*ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { @@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. 
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. ScopedShapedBuffer result_param_swap = ExecuteLocallyOrDie(computation, {&y_array, &x_array}); - LiteralTestUtil::ExpectR2Near( - {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_param_swap), error_spec_); + LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, + ShapedBufferToLiteral(result_param_swap), + error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. 
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. @@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_rowmaj), + ShapedBufferToLiteral(result_rowmaj), error_spec_); } @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - 
LiteralSlice(*result_literal, {2})); + LiteralSlice(result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); + LiteralSlice(result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); + LiteralSlice(result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralSlice(result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ 
-268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), - LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}); + auto y_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}})}); - auto x_buffer = LiteralToShapedBuffer(*x_literal); - auto y_buffer = LiteralToShapedBuffer(*y_literal); + auto x_buffer = LiteralToShapedBuffer(x_literal); + auto y_buffer = LiteralToShapedBuffer(y_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); @@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + 
Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) - .get(), - LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}), + LiteralUtil::CreateR1({222.0, -2.0, 10.0})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - 
{LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); + Literal result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); + LiteralSlice(result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - LiteralSlice(*result_0_literal, {1})); + LiteralSlice(result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); + Literal result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - LiteralSlice(*result_1_literal, {0})); + LiteralSlice(result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - LiteralSlice(*result_1_literal, {1})); + LiteralSlice(result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. 
- std::vector> arg_elements; + std::vector arg_elements; for (int i = 0; i < kElementCount; ++i) { arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } - std::unique_ptr arg_literal = - LiteralUtil::MakeTupleOwned(std::move(arg_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); } } @@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::vector> outer_tuple_elements; + std::vector outer_tuple_elements; for (int i = 0; i < kFanout; ++i) { - std::vector> inner_tuple_elements; + std::vector inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } @@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { } auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { - LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), - error_spec_); + LiteralTestUtil::ExpectR0Near(i + j + i * kFanout + j, + LiteralSlice(result_literal, {i, j}), + error_spec_); } } } @@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); + Literal arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { - std::vector> arg_vector; + std::vector arg_vector; arg_vector.push_back(std::move(arg_literal)); arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal(165.0, - LiteralSlice(*result_literal, index)); + LiteralSlice(result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { 
DefaultExecutableRunOptions().set_device_ordinal(d)); EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + ShapedBufferToLiteral(result)); } } } @@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); - LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + LiteralTestUtil::ExpectR0Equal(42.0f, ShapedBufferToLiteral(result)); } } @@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); + Literal tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - LiteralSlice(*tuple_literal, {0})); + LiteralSlice(tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - LiteralSlice(*tuple_literal, {1})); + LiteralSlice(tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, 
ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; // Array shapes. - test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); - test_to_device_and_back(*LiteralUtil::CreateR0(true)); - test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(LiteralUtil::CreateR0(true)); + test_to_device_and_back(LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*LiteralUtil::MakeTuple({})); + test_to_device_and_back(LiteralUtil::MakeTuple({})); // Non-nested tuples. - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(12223.0)})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)})); // Nested tuple. 
- test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()}) - .get(), - LiteralUtil::CreateR0(false).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)}), + LiteralUtil::CreateR0(false)})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456789000LL).get()})); + LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456789000LL)})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); Add(in, constant); - std::unique_ptr result; + Literal result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); 
ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. thread.reset(); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { @@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + TF_ASSERT_OK_AND_ASSIGN(Literal result, local_client_->TransferFromOutfeedLocal( shape, local_client_->default_device_ordinal())); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } // Benchmark that measures the overhead of the LocalClient API when running a @@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) { auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, - buffer)); + ASSERT_IS_OK( + transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); const int kWarmups = 2; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index a8c68fc7fd..f90ef22d2d 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -std::unique_ptr 
LocalClientTestBase::ShapedBufferToLiteral( +Literal LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { return local_client_->ShapedBufferToLiteral(shaped_buffer) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 90095c5d41..4027c7b124 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test { // Construct and return a literal containing the array represented by // shaped_buffer. - std::unique_ptr ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Execute the given computation on the local client. With and without // options. diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 0732e195d4..4d327a6fe9 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + Literal param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, @@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, @@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, @@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { 
TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( @@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( @@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); @@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). 
XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( + Literal param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( + Literal param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1}); @@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = 
Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1, 2}); @@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - std::unique_ptr param2_literal = + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); + Literal param2_literal = LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = - client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + client_->TransferToServer(param2_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); - auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2"); Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( @@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) { Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 
4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); @@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, @@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // 
note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( @@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + Literal param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index edb592f43e..3f278115e0 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -63,11 +63,11 @@ 
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { }); Exp(data); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { }); Map(&builder, {data}, add_half, {0, 1}); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { }); Max(lhs, rhs); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -200,14 +200,12 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder 
builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index c5e0b9b097..56aaeb0e68 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -114,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); + Literal literal_r0 = LiteralUtil::CreateR0(-9.0f); auto actual = - ExecuteAndTransfer(std::move(hlo_module), - {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1}); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -178,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = - std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); + Literal expect = LiteralUtil::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } }; @@ -218,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { LiteralUtil::CreateR0(1.0)), LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), LiteralUtil::CreateR0(4))); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); + 
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -247,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -280,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); } const char* const kScalarOps = R"( @@ -324,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -356,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result 
= ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -389,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -422,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -456,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), 
LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -492,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1({14, 22}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -530,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest, LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); auto init2 = LiteralUtil::CreateR0(6); - std::unique_ptr result = ExecuteNoHloPasses( - std::move(module), {param.get(), init1.get(), init2.get()}); + Literal result = + ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{167, 172}, {176, 180}}), LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -565,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 
64}}), LiteralUtil::CreateR3( @@ -576,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}})), - *result)); + result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index 0a0426adcb..f2460822a6 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { GetTupleElement(result_tuple, 0); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { VLOG(1) << "Transferring trip count to computation"; // Transfer number of iterations to Infeed. TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(1))); // Pick up value from outfeed { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 1); + EXPECT_EQ(r.Get({}), 1); } VLOG(1) << "Writing data to infeed"; // Transfer some stuff to Infeed for use inside of loop. 
TF_ASSERT_OK(local_client_->TransferToInfeed( - *LiteralUtil::CreateR1({10, 20}))); + LiteralUtil::CreateR1({10, 20}))); // Pick up value from outfeed { VLOG(1) << "Reading from body outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&xfeed_shape)); - EXPECT_EQ(r->Get({0}), 11); - EXPECT_EQ(r->Get({1}), 21); + EXPECT_EQ(r.Get({0}), 11); + EXPECT_EQ(r.Get({1}), 21); } { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 0); + EXPECT_EQ(r.Get({}), 0); } // Joins the thread thread.reset(); - EXPECT_EQ(comp_result->Get({}), 0); + EXPECT_EQ(comp_result.Get({}), 0); } XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { @@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { })); TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(true))); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&result_shape)); - EXPECT_EQ(r->Get({}), true); + EXPECT_EQ(r.Get({}), true); // Join the thread thread.reset(); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index cbeddffacf..6e98167739 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); 
dimension->set_interior_padding(0); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*LiteralUtil::CreateR0(1.5), &b), + AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); @@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(2, 3, 3, 2); @@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const 
float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); Pad(AddParam(input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(8, 5, 1, 1); @@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); 
padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), - padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), - padding_config); + Pad(reduce, 
AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f6c762e7a4..dcb4c11c3c 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + Literal param0_literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); @@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { 
XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + Literal param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), @@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); @@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); @@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); - 
std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters // @@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + Literal literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + Parameter(&builder, 1, literal1.shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ 
-188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1.shape(), "param2"); // This add is unused. 
Add(param1, param2); @@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + Literal literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( - client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + client_->TransferToServer(literal).ConsumeValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR0(i); + Literal literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; 
i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); - parameter_shapes.push_back(literal->shape()); + parameter_shapes.push_back(literal.shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. 
- std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); + Literal bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( - std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + std::move(client_->TransferToServer(bool_literal)).ValueOrDie()); XlaOp bool_param = - Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param"); params.push_back(bool_param); - parameter_shapes.push_back(bool_literal->shape()); + parameter_shapes.push_back(bool_literal.shape()); auto init = Tuple(&builder, params); @@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest, param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({i, i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR1({4, 5, 6}), })) .ConsumeValueOrDie(); @@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); - const Shape original = literal->shape(); + const Shape original = literal.shape(); { // Reverse the layout present in original, and make that the layout of the // literal. 
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape_do_not_use()->mutable_layout() = + *literal.mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, literal->Get({0, 1})); + ASSERT_EQ(2, literal.Get({0, 1})); } // Use the original shape in building the computation. XlaBuilder builder(TestName()); @@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5f322b768d..8f2c26f0ee 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,8 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, absl::Span dims, - int64 seed = 42); + Literal UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. 
`expected_count` is the number of times each @@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest(T a, T b, - absl::Span dims, - int64 seed) { +Literal PrngTest::UniformTest(T a, T b, absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -60,8 +58,8 @@ std::unique_ptr PrngTest::UniformTest(T a, T b, SetSeed(seed); auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](absl::Span, T value) { + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions())); + actual.EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { constexpr int64 count = 100; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); - result->Literal::EachCell( - [&](absl::Span, bfloat16 value) { - int64 index = static_cast((value - low) / interval); - counts[index]++; - }); + result.EachCell([&](absl::Span, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); } // Each bucket should have similar amount of counts. That is, not more than // 10% of total counts. 
This mostly tests that we don't fall into a 1:2:2 @@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell( + actual.EachCell( [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { @@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { }; XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + client_->TransferToServer(param0_literal)); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto fn = build_sum_rng(builder); Map(&builder, {param0}, fn, {0}); @@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), - ShapeUtil::ElementsIn(param0_literal->shape())); - for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { - EXPECT_GE(actual->data()[i], param0_literal->data()[i]); - EXPECT_LT(actual->data()[i], - param0_literal->data()[i] + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()), + ShapeUtil::ElementsIn(param0_literal.shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) { + EXPECT_GE(actual.data()[i], param0_literal.data()[i]); + EXPECT_LT(actual.data()[i], param0_literal.data()[i] + 1.0f); } } @@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); - std::unique_ptr result1; + Literal result1; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( 
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options1)); } - std::unique_ptr result2; - std::unique_ptr result3; + Literal result2; + Literal result3; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options1)); } - std::unique_ptr result4; - std::unique_ptr result5; - std::unique_ptr result6; + Literal result4; + Literal result5; + Literal result6; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9af9ea4a22..c9096fb29b 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = LiteralUtil::CreateR4( + Literal reduce_input = LiteralUtil::CreateR4( {{ /*i0=0*/ {/*i1=0*/ {-0.246092796, -0.179497838, -0.161181688}, diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc 
b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 0916a07f4f..26e2bfde5c 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR1({input_values}); + Literal a_literal = LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); ReducePrecision(a, exponent_bits, mantissa_bits); @@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // Abs doesn't affect resolution. auto abs = Abs(a); @@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. 
auto abs = Abs(a); @@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. 
auto abs = Abs(a); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 57f7fed61f..83997cdac2 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. @@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + 
input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - 
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to // Span. @@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr literal_2d_; - std::unique_ptr literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; 
colno < cols; ++colno) { @@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 major = 0; major < 2; ++major) { @@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0(&builder, 
std::numeric_limits::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. 
XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); @@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); @@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, 
ConstantR0(&builder, 0.0f), add, {1}); @@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); @@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0(&builder, 0.0f), add, GetParam().reduce_dims); @@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = - LiteralUtil::CreateR1({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); @@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = - 
client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = - LiteralUtil::CreateR1(operand); + Literal input_literal = LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index a1001296a1..d5de9650f1 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), @@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { 
const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); + LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({100, 1}), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); 
ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), {}, ErrorSpec(0.00001)); } @@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -252,8 +252,8 @@ 
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. 
@@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -360,9 +359,9 @@ 
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = absl::make_unique(shape); - arg_literal->PopulateWithValue(1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + Literal arg_literal(shape); + arg_literal.PopulateWithValue(1.0f); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); @@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = absl::make_unique(result_shape); - expected->PopulateWithValue(27.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + Literal expected(result_shape); + expected.PopulateWithValue(27.0f); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); - std::unique_ptr 
input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - 
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), 
&builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { @@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -614,11 +608,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); input.FillRandom(0.1f, 0.1f); - std::unique_ptr input_literal = - 
LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(4); @@ -627,7 +620,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; if (use_bfloat16() && Product(param.window_bounds) > 128) { @@ -659,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - std::unique_ptr expected_literal = - LiteralUtil::CreateFromArray(*expected); + Literal expected_literal = LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( - input_literal->shape().element_type(), - AsInt64Slice(expected_literal->shape().dimensions()), param.layout); - ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + input_literal.shape().element_type(), + AsInt64Slice(expected_literal.shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -1008,12 +1000,11 @@ TEST_P(R3ReduceWindowTest, DoIt) { Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2]); input.FillRandom(0.1f, 0.1f); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( + input, 
LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; if (use_bfloat16()) { - input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal); + input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); if (Product(param.window_bounds) > 128) { // To avoid numerical issues, force the reducer to be kMax for large bf16 // windows. @@ -1021,9 +1012,9 @@ TEST_P(R3ReduceWindowTest, DoIt) { } } - XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input"); + XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) @@ -1035,7 +1026,7 @@ TEST_P(R3ReduceWindowTest, DoIt) { /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1147,12 +1138,11 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1162,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1178,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1352,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + auto input_arg = + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; @@ -1365,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1384,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d891451381..5cf87e565b 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. - LiteralTestUtil::ExpectR0Equal(4, *literal); + LiteralTestUtil::ExpectR0Equal(4, literal); } XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { @@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{x_data.get(), y_data.get()}, @@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { .ConsumeValueOrDie(); // Expect 5. 
- LiteralTestUtil::ExpectR0Equal(5, *literal); + LiteralTestUtil::ExpectR0Equal(5, literal); } TEST_F(ReplayTest, MapPlusTwoOverR1) { @@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. - LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); + LiteralTestUtil::ExpectR1Equal({3, 4, 5}, literal); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 17d12715f6..dedc95b5ae 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + 
ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - 
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, 
expected_literal, {input.get()}, zero_error_spec_); } @@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, 
R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ 
-258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, 
zero_error_spec_); } @@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, 
*expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, 
/*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); @@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = 
LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); @@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 
5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, zero_error_spec_); } } @@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}, {}); EXPECT_THAT( @@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), @@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { 
LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, {1, 0}); - std::unique_ptr actual = + Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); - std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(expected); } - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Array4D input(2, 1, 1, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, 
R4ToR2_2x1x4x1_To_4x2) { Array4D input(2, 1, 4, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Array4D input(5, 10, 2, 3); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); @@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, 
{input_data.get()}, + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { input_array.Each( [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, {2, 3, 0, 1}); - std::unique_ptr output_literal = + Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, &execution_options) @@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. 
if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); - EXPECT_EQ(expected->data(), output_literal->data()); + auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); } else { - EXPECT_EQ(input_literal->data(), output_literal->data()); + EXPECT_EQ(input_literal.data(), output_literal.data()); } } @@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { @@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { @@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 
0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { @@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - 
std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { @@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { @@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { @@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) - ->Relayout(input_literal->shape().layout()); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal) + .Relayout(input_literal.shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 74ded82ddf..4e55b0d7ac 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) { ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto a = AddParam(*input_literal, &builder); + auto a = AddParam(input_literal, &builder); Rev(a, spec.reversal); - std::unique_ptr expected = input_literal->CloneToUnique(); + Literal expected = input_literal.Clone(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell([&](absl::Span indices, float) { + expected.EachCell([&](absl::Span indices, float) { for (int64 i = 0; i < indices.size(); ++i) { output_indices[i] = indices[i]; } - float value = input_literal->Get(indices); + float value = input_literal.Get(indices); for (int64 dim : spec.reversal) { output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; } - expected->Set(output_indices, value); + expected.Set(output_indices, value); }); - ComputeAndCompareLiteral(&builder, *expected, {}); + ComputeAndCompareLiteral(&builder, expected, {}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 
e692b8c5d5..091a5d2cac 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -38,7 +38,7 @@ namespace { class RoundTripPackedLiteralTest : public ClientLibraryTestBase { protected: // Sends the literal to the server and retrieves it back. - std::unique_ptr RoundTripToServer(const Literal& original) { + Literal RoundTripToServer(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); return client_->Transfer(*data).ConsumeValueOrDie(); @@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, actual->Get({0})); - EXPECT_EQ(24.0, actual->Get({1})); + EXPECT_EQ(42.0, actual.Get({0})); + EXPECT_EQ(24.0, actual.Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({0, 1})); - EXPECT_EQ(64.0f, actual->Get({1, 0})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({0, 1})); + EXPECT_EQ(64.0f, actual.Get({1, 0})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = 
RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({1, 0})); - EXPECT_EQ(64.0f, actual->Get({0, 1})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({1, 0})); + EXPECT_EQ(64.0f, actual.Get({0, 1})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac..cd5a531603 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = 
client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 
2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(1.0), LiteralUtil::CreateR1({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. 
TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 07460a7e01..1dd937a6d0 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = LiteralUtil::CreateR0(value); + Literal a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), {a_data.get()}); } @@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + Literal a_literal = LiteralUtil::CreateR0(2.1f); + Literal b_literal = LiteralUtil::CreateR0(5.5f); + Literal c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); 
std::unique_ptr c_data = - client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + client_->TransferToServer(c_literal).ConsumeValueOrDie(); - XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); - XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); - XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c"); Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend / divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -430,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { 
.ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend % divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); + Literal literal = LiteralUtil::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 1858dcea61..d20dba028a 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -62,13 +62,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { @@ -92,13 +90,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = 
LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { @@ -123,13 +120,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { @@ -154,13 +149,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { @@ -185,13 +178,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); - RunTest(hlo_text, operand.get(), 
scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { @@ -216,13 +208,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { @@ -247,13 +237,12 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { @@ -277,15 +266,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal 
updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { @@ -309,15 +296,13 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { @@ -341,12 +326,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { @@ -370,13 +354,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); 
+ RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ZeroDimBounds) { @@ -400,11 +382,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { @@ -429,12 +410,11 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { @@ -458,13 +438,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { @@ -488,13 
+468,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NegativeIndex) { @@ -518,13 +498,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OneScalarIndex) { @@ -548,12 +528,12 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ScalarUpdate) { @@ -577,10 +557,10 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr 
operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = LiteralUtil::CreateR0(25); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, EmptyIndices) { @@ -604,10 +584,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); - std::unique_ptr updates = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } } // namespace diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c9a58aefb4..a40c2d7de6 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), - &expected_literal->shape()); + ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001), + &expected_literal.shape()); } struct R1Spec { @@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase, auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = Parameter(&builder, 0, literal->shape(), "p0"); + auto original = Parameter(&builder, 0, 
literal.shape(), "p0"); Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase, } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = Parameter(&builder, 0, literal->shape(), "p0"); + auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {arg.get()}); @@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 3ae31191a0..5155f0c652 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, // array. This is uniqueness is best-effort only. 
Some types (half and bfloat16) // are not supported and uniqueness cannot be guaranteed if the number of // elements exceeds the number of different values supported by the type. -StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { +StatusOr MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; + std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr element, + Literal element, MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } @@ -131,60 +132,52 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = absl::make_unique(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); 
break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); TF_CHECK_OK( - literal->Populate([&](absl::Span /*indices*/) { + literal.Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { +Literal MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { std::vector start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -293,7 +286,7 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). 
-StatusOr> CreateLiteralForConstrainedUses( +StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector index_space; @@ -358,9 +351,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. @@ -374,34 +367,33 @@ StatusOr> CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { +StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { auto engine = pseudo_random ? 
absl::make_unique() : nullptr; return MakeFakeArguments(module, engine.get()); } -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine) { +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::vector> arguments(params.size()); + std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index a260271b1b..b3c8a73905 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -57,8 +57,8 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random. -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true); +StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -84,14 +84,14 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random = true); +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random = true); // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. 
-StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine); +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 322c8ef090..181e5cbe29 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 3); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const float& value : key_arg.data()) { @@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> 
( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const int32& value : key_arg.data()) { diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index c7eb9e2dbe..b34fd0f2e8 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(42, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(42, result.Get({})); } { @@ -204,9 +201,8 @@ ENTRY 
%TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(false); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(7, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(7, result.Get({})); } } diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 125513ddfd..d6641d257a 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = LiteralUtil::CreateR0(42); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR0(42); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR0Equal(42, *result); + LiteralTestUtil::ExpectR0Equal(42, result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, - *result); + result); } XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1(test_vector); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR1Equal(test_vector, *result); + LiteralTestUtil::ExpectR1Equal(test_vector, result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1U8(test_string); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_EQ(result->GetR1U8AsString(), test_string); + EXPECT_EQ(result.GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( - LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LayoutUtil::Equal(result.shape().layout(), literal.shape().layout())); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple({}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = LiteralUtil::CreateR1( + Literal literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( + Literal literal = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1( - {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) - .get(), - LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), - LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}), + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f))}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { @@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { // supported. auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(456.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); - - auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); - auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + Literal literal1 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, 
-10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + Literal literal2 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(456.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-98.0f, 153.0f})}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1.shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; auto stream2 = stream_->GetOrCreateSubStream(); - std::unique_ptr result1, result2; + Literal result1, result2; // Round trip literals through device in multiple streams asynchronously. for (int i = 0; i < kIterationCount; ++i) { - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1, device_buffer1)); - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2, device_buffer2)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result1, + Literal this_result1, transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result2, + Literal this_result2, transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); result1 = std::move(this_result1); result2 = std::move(this_result2); } - EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2)); } class TransferDeviceToHostBenchmark : public TransferManagerTest { @@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; 
++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); } tensorflow::testing::StopTiming(); @@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); } tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index f2b3b49015..619d2a388b 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, 
TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar1).get(), - LiteralUtil::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar1), + LiteralUtil::CreateR0(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. 
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. @@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. 
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2(constant_matrix), + LiteralUtil::CreateR1(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), - LiteralUtil::CreateR0(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(direction), + LiteralUtil::CreateR0(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), 
tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec1), LiteralUtil::CreateR1(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), - }) - .get(), - 
LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = absl::make_unique(sum->shape()); - ASSERT_TRUE(prod->Populate( - [&sum](absl::Span indexes) { - return sum->Get(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate([&sum](absl::Span indexes) { + return sum.Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? 
complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. 
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = Literal::CreateFromShape(expected->shape()); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), *literal)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 8f80a9f3e4..4fbd7f2fb1 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper() { {-inf(), 0}}); Abs(arg); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = LiteralUtil::CreateR1( + Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper() { auto abs = Abs(arg); Sub(Mul(sign, 
ConvertElementType(abs, C64)), arg); - std::unique_ptr expected = - LiteralUtil::CreateR1({0, 0, 0, 0}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { @@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); - std::unique_ptr expected = - LiteralUtil::CreateR0({-2.6f, 0.8f}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9..7abd8651d5 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // have all reached 2.0. 
auto expected_data = LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = LiteralUtil::MakeTuple({expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { @@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple( + {&expected_counter, &expected_w2, &expected_w3, &expected_w1}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { @@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << 
ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPredicateTupleResult) { @@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_predicate = LiteralUtil::CreateR0(true); - auto expected = LiteralUtil::MakeTuple( - {expected_counter.get(), expected_predicate.get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); + auto expected = + LiteralUtil::MakeTuple({&expected_counter, &expected_predicate}); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { @@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR0(7); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests two while nodes when the result type T is a Tuple and the second @@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << 
ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector of S32. @@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({&expected_element, &expected_element}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(42))); + client_->TransferToServer(LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(1))); + client_->TransferToServer(LiteralUtil::CreateR0(1))); auto add1 = LiteralUtil::CreateR0(15); auto add2 = LiteralUtil::CreateR0(16); - auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + auto expected = 
LiteralUtil::MakeTuple({&add1, &add2}); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN( - auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + auto param_value, client_->TransferToServer(LiteralUtil::CreateR2( {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( @@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { XlaBuilder builder(TestName()); While(condition, body, ConstantR0(&builder, 0)); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(false))); ComputeAndCompareR0(&builder, 2, {}); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7fd42944de..db5a824de0 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); 
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 442e66321e..cdde88c135 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -39,8 +39,7 @@ limitations under the License. namespace xla { -StatusOr> TextLiteralReader::ReadPath( - absl::string_view path) { +StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -57,7 +56,7 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -StatusOr> TextLiteralReader::ReadAllLines() { +StatusOr TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); string shape_string; @@ -74,9 +73,9 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape)); } - auto result = absl::make_unique(shape); + Literal result(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill); + result.PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -116,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line); } - result->Set(coordinate_values, value); + result.Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index b265640802..c40b43279f 100644 --- 
a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -41,7 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath(absl::string_view path); + static StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -49,7 +49,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr> ReadAllLines(); + StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 92f9b4f9f0..1fab4e3a08 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) { tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) .ok()); - std::unique_ptr literal = - TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, literal->Get({0, 0, 0})); - EXPECT_EQ(43.5, literal->Get({0, 0, 1})); - EXPECT_EQ(44.5, literal->Get({0, 0, 2})); - EXPECT_EQ(45.5, literal->Get({0, 1, 0})); - EXPECT_EQ(46.5, literal->Get({0, 1, 1})); - EXPECT_EQ(47.5, literal->Get({0, 1, 2})); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape())); + EXPECT_EQ(42.5, literal.Get({0, 0, 0})); + EXPECT_EQ(43.5, literal.Get({0, 0, 1})); + EXPECT_EQ(44.5, literal.Get({0, 0, 2})); + EXPECT_EQ(45.5, literal.Get({0, 1, 0})); + EXPECT_EQ(46.5, literal.Get({0, 1, 1})); + EXPECT_EQ(47.5, literal.Get({0, 1, 2})); } } // namespace diff --git 
a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 4ea02faffc..5cbaf2fcc1 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); - ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &contents)); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ba814af476..0c41f227b3 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -121,11 +121,10 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } else { // use recorded data if available for (const auto& proto : module.arguments()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - Literal::CreateFromProto(proto)); + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer data, - client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); scoped_shaped_buffer_arguments.push_back(std::move(data)); } for (const auto& argument : scoped_shaped_buffer_arguments) { @@ -161,12 +160,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. 
absl::optional pool; - std::unique_ptr data; + Literal data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); } auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(*data)); + TF_CHECK_OK(client->TransferToInfeed(data)); }; if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", @@ -214,9 +213,9 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "s: " << module.hlo().hlo_module().name(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + TF_ASSIGN_OR_RETURN(Literal result_literal, client->ShapedBufferToLiteral(*result)); - return std::move(*result_literal); + return result_literal; } StatusOr ParseInputFile(const string& filename, @@ -305,11 +304,11 @@ int RealMain(absl::Span args, const Options& opts) { result.ToString().c_str()); auto& snapshot = snapshots[i]; if (snapshot.has_result()) { - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 478c9663a7..54b06558ad 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -49,7 +49,7 @@ class XRTStateHelpers { // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an // OpKernel::Compute method. 
static Status MakeLiteral(const xla::LiteralProto& proto, - std::unique_ptr* literal) { + xla::Literal* literal) { TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto)); return Status::OK(); } @@ -173,7 +173,7 @@ class XRTAllocateOp : public OpKernel { errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal)); @@ -189,7 +189,7 @@ class XRTAllocateOp : public OpKernel { XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - *literal, device_ref.backend(), + literal, device_ref.backend(), device_ref.device_ordinal(), &allocation)); // Intern takes ownership of our reference to allocation. @@ -381,11 +381,11 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); - xla::LiteralProto literal_proto = literal->ToProto(); + xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); literal_proto.SerializeToString(&output.scalar()()); diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 5b8516bf1d..2952feb16a 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -52,44 +52,44 @@ string DeviceFromFlag() { xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - return tuple->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + return tuple.ToProto(); } xla::LiteralProto 
ScalarLiteral() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); - return scalar->ToProto(); + return scalar.ToProto(); } xla::LiteralProto NestedTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); auto scalar = xla::LiteralUtil::CreateR0(12.0f); - auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); - return nested->ToProto(); + auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar}); + return nested.ToProto(); } xla::LiteralProto MakeTuple0() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()}); - auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()}); - return nested1->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple}); + auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0}); + return nested1.ToProto(); } -xla::LiteralProto FloatVector(gtl::ArraySlice v) { +xla::LiteralProto FloatVector(absl::Span v) { auto array = xla::LiteralUtil::CreateR1(v); - return array->ToProto(); + return array.ToProto(); } bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = *l_a == *l_b; + bool equal = l_a == l_b; if (!equal) { LOG(INFO) << "LiteralProtos don't match " << a.DebugString() << " != " << b.DebugString(); @@ -100,7 +100,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, bool 
CompareLiteralToLiteralProto(const xla::Literal& a, const xla::LiteralProto& b) { auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = a == *l_b; + bool equal = a == l_b; if (!equal) { LOG(INFO) << "Literal and LiteralProto don't match " << a.ToProto().DebugString() << " != " << b.DebugString(); @@ -211,7 +211,7 @@ TEST(RawApiTest, SubBuffer) { TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs)); auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie(); - auto base_elements = base_literal->DecomposeTuple(); + auto base_elements = base_literal.DecomposeTuple(); auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); xla::LiteralProto response_0; EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); @@ -343,7 +343,7 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } TEST(RawApiTest, CompileAndExecuteReturnTuple) { @@ -392,8 +392,8 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); - auto expected = xla::LiteralUtil::MakeTuple({sum.get()}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + auto expected = xla::LiteralUtil::MakeTuple({&sum}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } } // namespace diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 2c3b07da58..d05a1e7dcb 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -174,7 +174,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal) { + xla::Literal* 
literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 42705688dd..73b5584e38 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -135,7 +135,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal); + xla::Literal* literal); // True if none of the buffers in the allocation are aliased by any other live // handle. -- cgit v1.2.3 From d274948444a1edc846d4b488f14ed029bfc569dd Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 10 Sep 2018 13:23:35 -0700 Subject: Add experimental grappler plugin to selection function implementation at run time. PiperOrigin-RevId: 212321238 --- tensorflow/core/framework/function_testlib.cc | 16 ++ tensorflow/core/framework/function_testlib.h | 3 + tensorflow/core/grappler/optimizers/BUILD | 65 ++++++++ .../experimental_implementation_selector.cc | 93 ++++++++++++ .../experimental_implementation_selector.h | 115 ++++++++++++++ .../experimental_implementation_selector_test.cc | 139 +++++++++++++++++ .../core/grappler/optimizers/function_api_info.cc | 167 +++++++++++++++++++++ .../core/grappler/optimizers/function_api_info.h | 80 ++++++++++ .../grappler/optimizers/function_api_info_test.cc | 160 ++++++++++++++++++++ 9 files changed, 838 insertions(+) create mode 100644 tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc create mode 100644 tensorflow/core/grappler/optimizers/experimental_implementation_selector.h create mode 100644 tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc create mode 100644 
tensorflow/core/grappler/optimizers/function_api_info.cc create mode 100644 tensorflow/core/grappler/optimizers/function_api_info.h create mode 100644 tensorflow/core/grappler/optimizers/function_api_info_test.cc diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 46b169dddc..c5a4f661d2 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -110,6 +110,22 @@ FunctionDef XTimesTwo() { }); } +FunctionDef XAddX() { + return FDH::Define( + // Name + "XAddX", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}}, + }); +} + FunctionDef XTimesTwoInt32() { const Tensor kTwo = test::AsScalar(2); return FDH::Define( diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 6d6476b936..ad61a76f16 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice nodes, // x:T -> x * 2. FunctionDef XTimesTwo(); +// x:T -> x + x. +FunctionDef XAddX(); + // x:T -> x * 2, where x is int32. 
FunctionDef XTimesTwoInt32(); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index a24004dc16..f094c151e6 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -846,3 +846,68 @@ tf_cc_test( "//third_party/eigen3", ], ) + +cc_library( + name = "function_api_info", + srcs = ["function_api_info.cc"], + hdrs = ["function_api_info.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "function_api_info_test", + size = "small", + srcs = ["function_api_info_test.cc"], + deps = [ + ":function_api_info", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "experimental_implementation_selector", + srcs = ["experimental_implementation_selector.cc"], + hdrs = ["experimental_implementation_selector.h"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + ], +) + +tf_cc_test( + name = "experimental_implementation_selector_test", + size = "small", + srcs = ["experimental_implementation_selector_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":experimental_implementation_selector", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + 
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc new file mode 100644 index 0000000000..eeea269fb0 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector); + +Status ExperimentalImplementationSelector::LoadFunctions( + const GraphDef& graph) { + lib_info_.reset(new FunctionLibraryApiInfo); + TF_RETURN_IF_ERROR(lib_info_->Init(graph.library())); + return Status::OK(); +} + +Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall( + NodeDef* node_def) const { + const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op()); + if (info == nullptr) { + // A regular op, or a function which has no interface. 
+ return Status::OK(); + } + + string task, device; + if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) { + return errors::Internal("Could not split device name:", node_def->device()); + } + VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device() + << " = (" << task << ", " << device << ")"; + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseLocalName(device, &parsed_name); + + string best_function_name; + lib_info_->GetBestImplementation(node_def->op(), parsed_name.type, + &best_function_name); + if (node_def->op() != best_function_name) { + // The current implementation is not the best, swap the op to the best one. + // There will be duplicates in the graph and they will be pruned by other + // grappler plugin since no other node is using their output as inputs. + // TODO(scottzhu): Update the tf.eager.defun to register functions without + // having to call them with input data. That will reduce the graph size and + // save the work for prune them. 
+ node_def->set_op(best_function_name); + } + return Status::OK(); +} + +Status ExperimentalImplementationSelector::SelectImplementation( + GraphDef* graph) const { + for (int k = 0; k < graph->node_size(); ++k) + TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k))); + + return Status::OK(); +} + +Status ExperimentalImplementationSelector::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph)); + return SelectImplementation(optimized_graph); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h new file mode 100644 index 0000000000..82f7473a14 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +// -- EXPERIMENTAL -- +// This transformation replaces function calls by the appropriate function +// definition based on properties of the runtime system. For instance, +// we may choose one implementation over another if we have a GPU with +// enough memory available. +// +// It is a way for the programmer to specify alternative implementations +// of the same functionality in the graph, and let TensorFlow pick the +// most appropriate one at runtime. 
+// +// For instance, the python code might specify: +// @Defun(tf.float32, +// experimental_api_implements='plus_one', +// experimental_api_preferred_device='GPU') +// def plus_one_gpu(x): return x + 1.0 +// +// @Defun(tf.float32, +// experimental_api_implements='plus_one') +// def plus_one_reference_implementation(x): return x + 1.0 +// input = tf.constant(2.0, dtype=tf.float32) +// +// z = plus_one_reference_implementation(input) +// z = plus_one_gpu(input) +// print(sess.run(z)) +// +// At runtime, we will trim either `plus_one_gpu` or +// `plus_one_reference_implementation` based on the availability of the GPU. +// +// Available annotations: +// - experimental_api_implements(string): all functions mapping to the same +// string can be interchanged. For now, all functions must have the same +// signature and overloads are not allowed. Defuns within defuns are +// allowed. +// - experimental_api_preferred_device(string): sets which device is preferred. +class ExperimentalImplementationSelector : public CustomGraphOptimizer { + public: + ExperimentalImplementationSelector() = default; + ~ExperimentalImplementationSelector() override = default; + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + string name() const override { + return "experimental_implementation_selector"; + } + + // This call is not thread-safe. + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + // Does not take any feedback. + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + Status LoadFunctions(const GraphDef& graph); + Status MaybeOptimizeFunctionCall(NodeDef* node_def) const; + + // Finds all call sites for functions, then replace with the appropriate + // implementation. + // There are two ways of calling functions: + // 1. By specifying an op name as a function name, and + // 2. 
Via the functional interface, where the function name appears as an + // Attr. + // + // There may be multiple call sites for a given function. The function body + // may call into another function, so a function might have to be duplicated. + // For simplicity, we do not change function bodies. Also, we do not change + // gradients. + Status SelectImplementation(GraphDef* graph) const; + + std::unique_ptr lib_info_; + + TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector); +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc new file mode 100644 index 0000000000..2368e577c2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char CpuDevice[] = "/device:CPU:0"; +constexpr char GpuDevice[] = "/device:GPU:0"; + +class ExperimentalImplementationSelectorTest : public GrapplerTest {}; + +TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + std::unique_ptr optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull( + "ExperimentalImplementationSelector"); + ASSERT_NE(nullptr, optimizer); + TF_ASSERT_OK(optimizer->Init()); + + GraphDef output; + const Status status = optimizer->Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // This is a trivial graph so there is nothing to update. 
+ EXPECT_EQ(item.graph.node_size(), output.node_size()); +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("times_two"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XAddX(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("times_two"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice), + NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(output.node_size(), 5); + for (const NodeDef& node : output.node()) { + if (node.name() == "y1") { + // Make sure the implementation has been swapped to use the GPU version. + EXPECT_EQ("XAddX", node.op()); + } else if (node.name() == "y2") { + // Make sure the implementation is not changed. 
+ EXPECT_EQ("XTimesTwo", node.op()); + } + } +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("random_boost"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XTimesFour(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("random_boost"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice), + NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + const Tensor input = test::AsScalar(1.0f); + item.fetch = {"z"}; + item.feed.emplace_back("x", input); + + const auto four_times_boosted_tensor = EvaluateFetchNodes(item); + test::ExpectTensorEqual(four_times_boosted_tensor[0], + test::AsScalar(4.0f)); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + GrapplerItem optimized(item, std::move(output)); + const auto twice_boosted_tensor = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(twice_boosted_tensor[0], + test::AsScalar(2.0f)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc new file mode 100644 index 0000000000..798e0f6fd5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_api_info.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +FunctionApiInfo::FunctionApiInfo() {} +FunctionApiInfo::~FunctionApiInfo() {} + +Status FunctionApiInfo::Init(const FunctionDef& function_def) { + for (const auto& attr : function_def.attr()) { + if (attr.first == "experimental_api_preferred_device") { + preferred_device_ = attr.second.s(); + } + if (attr.first == "experimental_api_implements") { + interface_name_ = attr.second.s(); + } + } + if (interface_name_.empty() && !preferred_device_.empty()) { + return errors::InvalidArgument( + "Function '", function_def.signature().name(), + "' has a preferred device, but does not implement an interface"); + } + return Status::OK(); +} + +const string& FunctionApiInfo::preferred_device() const { + return preferred_device_; +} + +const string& FunctionApiInfo::interface_name() const { + return interface_name_; +} + +FunctionLibraryApiInfo::FunctionLibraryApiInfo() {} +FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {} + +namespace { +bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) { + if (f1.ret().size() != f2.ret().size()) 
return false; + const auto& sig1 = f1.signature(); + const auto& sig2 = f2.signature(); + // Functions have positional semantics, so we don't check for names. + if (sig1.input_arg_size() != sig2.input_arg_size()) return false; + for (int k = 0; k < sig1.input_arg_size(); ++k) { + const OpDef::ArgDef& arg1 = sig1.input_arg(k); + const OpDef::ArgDef& arg2 = sig2.input_arg(k); + if (arg1.type() != arg2.type()) return false; + if (arg1.type_attr() != arg2.type_attr()) return false; + if (arg1.number_attr() != arg2.number_attr()) return false; + if (arg1.type_list_attr() != arg2.type_list_attr()) return false; + if (arg1.is_ref() != arg2.is_ref()) return false; + } + return true; +} + +Status ValidateSignature(const string& interface_name, + const std::vector& equiv_funcs) { + if (equiv_funcs.size() < 2) return Status::OK(); + for (size_t k = 1; k < equiv_funcs.size(); ++k) { + if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k])) + return errors::InvalidArgument( + "Functions '", equiv_funcs[0]->signature().name(), "' and '", + equiv_funcs[k]->signature().name(), "' both implement '", + interface_name, "' but their signatures do not match."); + } + return Status::OK(); +} + +Status ValidateSignatures( + const std::unordered_map>& + intf_to_func) { + for (const auto& item : intf_to_func) + TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second)); + return Status::OK(); +} +} // namespace + +Status FunctionLibraryApiInfo::Init( + const FunctionDefLibrary& function_library) { + std::unordered_map> intf_to_func; + for (const auto& function : function_library.function()) { + std::unique_ptr func_info(new FunctionApiInfo); + TF_RETURN_IF_ERROR(func_info->Init(function)); + // Ignore the function if it does not implement any interface. 
+ if (func_info->interface_name().empty()) continue; + + const string& function_name = function.signature().name(); + const string& interface_name = func_info->interface_name(); + func_to_intf_[function_name] = interface_name; + intf_to_funcs_[interface_name].emplace_back(function_name); + intf_to_func[interface_name].emplace_back(&function); + func_info_[function_name] = std::move(func_info); + } + TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func)); + return Status::OK(); +} + +void FunctionLibraryApiInfo::GetEquivalentImplementations( + const string& function_name, std::vector* other_names) const { + const auto intf_it = func_to_intf_.find(function_name); + // The function implements no interface, so it has no equivalents to report. + if (intf_it == func_to_intf_.end()) return; + CHECK(!intf_it->second.empty()) << "Function " << function_name + << " should at least implement 1 interface."; + const auto it = intf_to_funcs_.find(intf_it->second); + CHECK(it != intf_to_funcs_.end()) + << "Function " << function_name << " maps to " << intf_it->second + << " but no reverse mapping was found"; + CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty"; + other_names->reserve(it->second.size() - 1); + for (const auto& other_name : it->second) { + if (other_name == function_name) continue; + other_names->emplace_back(other_name); + } +} + +void FunctionLibraryApiInfo::GetBestImplementation( + const string& function_name, const string& device, + string* best_func_name) const { + CHECK(best_func_name != nullptr); + const auto func_it = func_to_intf_.find(function_name); + if (func_it == func_to_intf_.end()) return; + + const auto it = intf_to_funcs_.find(func_it->second); + // No function found for the given interface.
+ if (it == intf_to_funcs_.end()) return; + for (const auto& func_name : it->second) { + const auto func_api_info = func_info_.find(func_name)->second.get(); + if (func_api_info->preferred_device() == device) { + best_func_name->assign(func_name); + return; + } + } + // Didn't find a function with a matching device name, choose the first one + // among all the available functions. + best_func_name->assign(it->second.front()); +} + +const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo( + const string& function_name) const { + const auto it = func_info_.find(function_name); + if (it == func_info_.end()) return nullptr; + return it->second.get(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h new file mode 100644 index 0000000000..412687c58c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +class FunctionApiInfo { + public: + FunctionApiInfo(); + virtual ~FunctionApiInfo(); + + Status Init(const FunctionDef& function_def); + + const string& interface_name() const; + const string& preferred_device() const; + + private: + string interface_name_; + string preferred_device_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo); +}; + +// A collection of information about functions and the interface they implement. +// An interface is a well-defined math operation, e.g. I1 = 2 * x + y. Multiple +// functions could implement the same interface with different behavior based on +// different hardware conditions and limits, +// e.g. F1 = math_ops.add(math_ops.add(x, x), y), or +// F2 = math_ops.add(math_ops.matmul(x, 2), y). +class FunctionLibraryApiInfo { + public: + FunctionLibraryApiInfo(); + virtual ~FunctionLibraryApiInfo(); + // Populates the internal maps from the functions within function_library. + Status Init(const FunctionDefLibrary& function_library); + + void GetEquivalentImplementations(const string& function_name, + std::vector* other_names) const; + + void GetBestImplementation(const string& function_name, const string& device, + string* best_func_name) const; + + const FunctionApiInfo* GetApiInfo(const string& function_name) const; + + private: + // Map from function name to function details. + std::unordered_map> func_info_; + // Map from function name to interface name. + std::unordered_map func_to_intf_; + // Map from interface name to function names.
+ std::unordered_map> intf_to_funcs_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc new file mode 100644 index 0000000000..582890d3e3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_api_info.h" + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { +void SetArg(const string& name, const string& type_name, + OpDef::ArgDef* arg_def) { + arg_def->set_name(name); + arg_def->set_type_attr(type_name); +} + +typedef std::pair ArgSpec; // name, type. 
+ +void SetArgs(const std::vector& args_spec, OpDef* sig) { + for (const auto& arg_spec : args_spec) + SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg()); + SetArg("output", "float32", sig->add_output_arg()); +} + +void PopulateFunction(const string& name, const string& api_interface_name, + const string& preferred_device, + const std::vector& input_args, + FunctionDef* func_def) { + OpDef* sig = func_def->mutable_signature(); + sig->set_name(name); + + SetArgs(input_args, sig); + + if (!api_interface_name.empty() || !preferred_device.empty()) { + auto* func_attr = func_def->mutable_attr(); + if (!api_interface_name.empty()) + (*func_attr)["experimental_api_implements"].set_s(api_interface_name); + if (!preferred_device.empty()) + (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device); + } +} + +void PopulateSampleLibrary(const bool mismatch_args, + FunctionDefLibrary* func_lib) { + const std::vector func_args{{"in1", "float32"}, {"in2", "int32"}}; + const std::vector func_wrong_args{{"in1", "int32"}, + {"in2", "int32"}}; + PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args, + func_lib->add_function()); + PopulateFunction("DoStuffGpu", "DoStuff", "GPU", + mismatch_args ? 
func_wrong_args : func_args, + func_lib->add_function()); + PopulateFunction("DoThings", "DoThings", "", func_args, + func_lib->add_function()); + PopulateFunction("OneOff", "", "", func_args, func_lib->add_function()); + PopulateFunction("AnotherOneOff", "", "", func_args, + func_lib->add_function()); +} + +bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name, + const std::vector& expected_other) { + std::vector other_impl; + lib_api_info.GetEquivalentImplementations(func_name, &other_impl); + const std::unordered_set actual(other_impl.begin(), other_impl.end()); + const std::unordered_set expected(expected_other.begin(), + expected_other.end()); + return actual == expected; +} + +bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info, + const string& function_name, const string& device, + const string& expected_function_name) { + string best_function_name; + lib_api_info.GetBestImplementation(function_name, device, + &best_function_name); + + return best_function_name == expected_function_name; +} + +string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name) { + auto* info = lib_api_info.GetApiInfo(func_name); + CHECK_NOTNULL(info); + return info->interface_name(); +} + +string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name) { + auto* info = lib_api_info.GetApiInfo(func_name); + CHECK_NOTNULL(info); + return info->preferred_device(); +} + +TEST(FunctionApiInfoTest, ParseTags) { + FunctionDefLibrary func_lib; + PopulateSampleLibrary(/* mismatch_args */ false, &func_lib); + FunctionLibraryApiInfo lib_api_info; + TF_ASSERT_OK(lib_api_info.Init(func_lib)); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {})); + 
EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {})); + + EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu")); + EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu")); + EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings")); + + EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu")); + EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu")); + EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings")); + + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu")); + + EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings")); + // TPU impl is not available, choose the first one available which is the CPU. + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu")); +} + +TEST(FunctionApiInfoTest, MismatchedArguments) { + FunctionDefLibrary func_lib; + PopulateSampleLibrary(/* mismatch_args */ true, &func_lib); + FunctionLibraryApiInfo lib_api_info; + const Status ret = lib_api_info.Init(func_lib); + EXPECT_FALSE(ret.ok()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow -- cgit v1.2.3 From bb9c72ae54f3a4a16b851a811a20f93740f5f1d3 Mon Sep 17 00:00:00 2001 From: Shashi Shekhar Date: Mon, 10 Sep 2018 14:01:46 -0700 Subject: Update accuracy numbers without blacklist. 
PiperOrigin-RevId: 212328308 --- tensorflow/contrib/lite/g3doc/models.md | 91 ++++++++++++++++----------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 88f6cda420..a4267eee4c 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -7,65 +7,64 @@ Model Name | Paper_Model_Files^ --------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms -NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 74.2% | 91.7% | 261 ms | 389 ms -NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.8% | 96.2% | 6697 ms | 7940 ms -ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms -ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms -Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 78.2% | 94.0% | 1433 ms | 1522 ms -Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.4% | 95.2% | 2986 ms | 3139 ms -Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.8% | 94.1% | 2731 ms | 2926 ms -Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.6% | 66.6% | 6.2 ms | 13.0 ms -Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.7% | 70.6% | 8.6 ms | 19.5 ms -Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.5% | 72.4% | 12.1 ms | 27.8 ms -Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 50.0% | 74.4% | 16.2 ms | 37.3 ms -Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.5% | 79.5% | 18.1 ms | 29.9 ms 
-Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.3% | 82.1% | 26.8 ms | 45.9 ms -Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 62.0% | 83.7% | 35.6 ms | 65.3 ms -Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.5% | 85.0% | 47.6 ms | 164.2 ms -Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.3% | 84.1% | 34.6 ms | 48.7 ms -Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.5% | 86.1% | 51.3 ms | 75.2 ms -Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.4% | 87.4% | 71.7 ms | 107.0 ms -Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.6% | 88.3% | 95.7 ms | 143.4 ms -Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.5% | 85.9% | 57.4 ms | 76.8 ms -Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.3% | 87.8% | 86.0 ms | 117.7 ms -Mobilenet_V1_1.0_192 | 
[paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.2% | 89.3% | 118.6 ms | 167.3 ms -Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.3% | 90.1% | 160.1 ms | 224.3 ms -Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.9% | 90.1% | 117 ms | +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms +Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms +Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms +Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms +Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms +Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms +Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms +Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms +Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms +Mobilenet_V1_0.75_128 | 
[paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms +Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms +Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms +Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms +Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms +Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms +Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms +Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms +Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms | ^ The model files include both TF Lite FlatBuffer and Tensorflow frozen 
Graph. ^^ The performance numbers are generated in the benchmark on Pixel-2 using single thread large core. -^^ Accuracy numbers were computed using the [TFLite accuracy tool](../tools/accuracy/ilsvrc) -after excluding blacklisted images. +^^ Accuracy numbers were computed using the +[TFLite accuracy tool](../tools/accuracy/ilsvrc) . ## Image classification (Quantized Models) Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance --------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: -Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.8% | 64.8% | 3.7 ms -Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.0% | 68.4% | 5.5 ms -Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms -Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.5% | 73.1% | 10.4 ms -Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 55.2% | 78.4% | 8.8 ms -Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.5% | 80.7% | 13.0 ms -Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.2% | 82.3% | 18.3 ms -Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.5% | 83.5% | 24.7 ms -Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 56.2% | 79.4% | 16.2 ms -Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.7% | 83.9% | 24.3 ms -Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.4% | 86.4% | 33.8 ms -Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.2% | 87.0% | 45.4 ms -Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.6% | 84.3% | 24.9 ms -Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.9% | 37.4 ms -Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.4% | 88.3% | 51.9 ms -Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.2% | 89.1% | 70.2 ms -Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 71.1% | 90.1% | 80.3 ms -Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.6% | 637 ms +Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms +Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms +Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.7% | 70.8% | 7.9 ms +Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 72.8% | 10.4 ms +Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.1% | 8.8 ms +Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.2% | 80.5% | 13.0 ms +Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.9% | 82.1% | 18.3 ms +Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.2% | 83.2% | 24.7 ms +Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.9% | 79.1% | 16.2 ms +Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.4% | 83.7% | 24.3 ms +Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.2% | 33.8 ms +Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.9% | 86.9% | 45.4 ms +Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.3% | 84.1% | 24.9 ms +Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.9% | 86.7% | 37.4 ms +Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.1% | 88.1% | 51.9 ms +Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.0% | 89.0% | 70.2 ms +Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 70.8% | 89.9% | 80.3 ms +Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.7% | 637 ms ## Other models -Lite FlatBuffer ----------------------- | :----------------: Smart Reply 1.0 -Android | +Model | TF Lite FlatBuffer +----------------------- | :----------------: [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) -- cgit v1.2.3 From 1de8e4400b286e359e4369d41038eca8e18ad261 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:25:58 -0700 Subject: Adding forgotten exports. 
PiperOrigin-RevId: 212333784 --- tensorflow/python/keras/utils/data_utils.py | 1 + tensorflow/python/keras/utils/layer_utils.py | 1 + .../tensorflow.keras.utils.-ordered-enqueuer.pbtxt | 26 ++++++++++++++++++++++ .../api/golden/v1/tensorflow.keras.utils.pbtxt | 8 +++++++ .../tensorflow.keras.utils.-ordered-enqueuer.pbtxt | 26 ++++++++++++++++++++++ .../api/golden/v2/tensorflow.keras.utils.pbtxt | 8 +++++++ 6 files changed, 70 insertions(+) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index c1ee34ae46..d93a7b6afc 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -494,6 +494,7 @@ class SequenceEnqueuer(object): raise NotImplementedError +@tf_export('keras.utils.OrderedEnqueuer') class OrderedEnqueuer(SequenceEnqueuer): """Builds a Enqueuer from a Sequence. diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 1f28c59ea4..158a9a5e76 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -26,6 +26,7 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.utils.get_source_inputs') def get_source_inputs(tensor, layer=None, node_index=None): """Returns the list of input tensors necessary to compute `tensor`. 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt new file mode 100644 index 0000000000..e7e7d2839b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.keras.utils.OrderedEnqueuer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], " + } + member_method { + name: "get" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_running" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], " + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt index 4d7a1519ce..81b91d2780 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "HDF5Matrix" mtype: "" } + member { + name: "OrderedEnqueuer" + mtype: "" + } member { name: "Progbar" mtype: "" @@ -44,6 +48,10 @@ tf_module { name: "get_file" argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " } + member_method { + name: "get_source_inputs" + 
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "multi_gpu_model" argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt new file mode 100644 index 0000000000..e7e7d2839b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.keras.utils.OrderedEnqueuer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], " + } + member_method { + name: "get" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_running" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], " + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt index 4d7a1519ce..81b91d2780 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "HDF5Matrix" mtype: "" } + member { + name: "OrderedEnqueuer" + mtype: "" + } member { name: "Progbar" mtype: "" @@ -44,6 +48,10 @@ tf_module { name: "get_file" argspec: "args=[\'fname\', \'origin\', 
\'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " } + member_method { + name: "get_source_inputs" + argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "multi_gpu_model" argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], " -- cgit v1.2.3 From 4cbe494e87437213a7cb464ec23c12cb5788eb66 Mon Sep 17 00:00:00 2001 From: Sung Jin Hwang Date: Mon, 10 Sep 2018 14:33:53 -0700 Subject: Register gradient for EnsureShape op. Currently this op cannot be used within backprop path because it lacks gradient registry. PiperOrigin-RevId: 212335632 --- tensorflow/python/kernel_tests/check_ops_test.py | 13 +++++++++++++ tensorflow/python/ops/check_ops.py | 6 ++++++ 2 files changed, 19 insertions(+) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 680d0c97cc..27a674e223 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -819,6 +820,18 @@ class EnsureShapeTest(test.TestCase): with self.test_session() as sess: sess.run(derived, feed_dict={placeholder: feed_val}) + def testGradient(self): + placeholder = array_ops.placeholder(dtypes.float32) + derived = check_ops.ensure_shape(placeholder, (None, None)) 
+ gradient = gradients.gradients(derived, placeholder) + + feed_val = [[4.0], [-1.0]] + with self.test_session() as sess: + gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val}) + + expected = [[1.0], [1.0]] + self.assertAllEqual(gradient_values, expected) + class EnsureShapeBenchmark(test.Benchmark): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 6528062f3c..c3cf6e61f2 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1292,3 +1292,9 @@ def ensure_shape(x, shape, name=None): shape = tensor_shape.TensorShape(shape) return array_ops.ensure_shape(x, shape, name=name) + + +@ops.RegisterGradient('EnsureShape') +def _ensure_shape_grad(op, grad): + del op # Unused. + return grad -- cgit v1.2.3 From 55ad6406b8e0e1f50d27f619aa150cc2f827311a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:36:05 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212336206 --- .../layers/python/layers/embedding_ops_test.py | 54 ++-- .../contrib/layers/python/layers/encoders_test.py | 20 +- .../python/layers/feature_column_ops_test.py | 206 +++++++------- .../layers/python/layers/feature_column_test.py | 26 +- .../contrib/layers/python/layers/layers_test.py | 316 ++++++++++----------- .../layers/python/layers/normalization_test.py | 8 +- .../layers/python/layers/optimizers_test.py | 14 +- .../layers/python/layers/regularizers_test.py | 14 +- .../layers/python/layers/rev_block_lib_test.py | 10 +- .../contrib/layers/python/layers/summaries_test.py | 12 +- .../contrib/layers/python/layers/utils_test.py | 24 +- 11 files changed, 352 insertions(+), 352 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 7ede193029..124515e5a6 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -109,7 +109,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): return sparse_ids, sparse_weights def test_safe_embedding_lookup_sparse_return_zero_vector(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() sparse_ids, sparse_weights = self._ids_and_weights_2d() @@ -122,7 +122,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): 3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4]) def test_safe_embedding_lookup_sparse_return_special_vector(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() sparse_ids, sparse_weights = self._ids_and_weights_2d() @@ -136,7 +136,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights[0][2], embedding_weights[0][3]]) def test_safe_embedding_lookup_sparse_no_weights(self): - with self.test_session(): + with self.cached_session(): embedding_weights = 
self._random_weights() sparse_ids, _ = self._ids_and_weights_2d() @@ -150,7 +150,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) def test_safe_embedding_lookup_sparse_partitioned(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights(num_shards=3) sparse_ids, _ = self._ids_and_weights_2d() @@ -164,7 +164,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): (embedding_weights[0] + embedding_weights[1]) / 2.0]) def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights(num_shards=3) sparse_ids, sparse_weights = self._ids_and_weights_2d() @@ -179,7 +179,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights, sparse_ids, sparse_weights) def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() sparse_ids, sparse_weights = self._ids_and_weights_3d() @@ -192,7 +192,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): ], [embedding_weights[0][2], [0] * 4, [0] * 4]]) def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() sparse_ids, sparse_weights = self._ids_and_weights_3d() @@ -208,7 +208,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): ]]) def test_safe_embedding_lookup_sparse_3d_no_weights(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() sparse_ids, _ = self._ids_and_weights_3d() @@ -224,7 +224,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): ]]) def test_safe_embedding_lookup_sparse_3d_partitioned(self): - with self.test_session(): + with self.cached_session(): embedding_weights = 
self._random_weights(num_shards=3) sparse_ids, _ = self._ids_and_weights_3d() @@ -241,7 +241,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights( self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights(num_shards=3) sparse_ids, sparse_weights = self._ids_and_weights_3d() @@ -276,7 +276,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): return embedding_weights def test_scattered_embedding_consistency(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) @@ -288,7 +288,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1]) def test_scattered_embedding_multiple_partition(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights(num_shards=7) values = constant_op.constant([4, 4, 5]) @@ -304,7 +304,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertGreater(embedding_diff, 0) def test_scattered_embedding_coverage(self): - with self.test_session(): + with self.cached_session(): size = 8 embedding_weights = self._random_weights(size=size, num_shards=3) values = constant_op.constant(["foo"]) @@ -316,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertEqual(len(np.unique(embedding_lookup_result[0])), size) def test_scattered_embedding_multi_dimension(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -329,7 +329,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][2]) def test_scattered_embedding_lookup_sparse(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights(num_shards=3) sparse_tensor = 
sparse_tensor_lib.SparseTensor( values=["foo", "bar", "foo", "bar"], @@ -358,7 +358,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embeds = np.random.randn(n_embed, d_embed) idx = np.random.randint(0, n_embed, idx_shape) - with self.test_session(): + with self.cached_session(): embedded_np = embeds[idx] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -370,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): idx = np.random.randint(0, 5, 10) idx2d = np.random.randint(0, 5, (10, 2)) - with self.test_session(): + with self.cached_session(): embedded_np = embeds[idx] embedded_np2d = embeds[idx2d] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -408,7 +408,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): return embedding_weights def test_hashed_embedding_consistency(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) # The first three sampled_candidates are equal, so the first three @@ -429,7 +429,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][3]) def test_hashed_embedding_multi_dimension(self): - with self.test_session(): + with self.cached_session(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -467,7 +467,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): def test_output_shape(self): """Verifies the shape of the output tensor.""" - with self.test_session(): + with self.cached_session(): sp_values = sparse_tensor_lib.SparseTensor( values=["a", "a", "b", "c", "d", "e", "f"], indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]], @@ -481,7 +481,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): def test_output_values(self): """Verifies the values in a trivial case.""" - with self.test_session(): + with 
self.cached_session(): sp_values = sparse_tensor_lib.SparseTensor( values=["a"], indices=[[1, 0]], dense_shape=[3, 1]) params = constant_op.constant([.1, .2, .3]) @@ -495,7 +495,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): def test_output_values_with_sampled_candidates(self): """Verifies the values for given sampled_candidates.""" - with self.test_session(): + with self.cached_session(): sp_values = sparse_tensor_lib.SparseTensor( values=["a", "a", "b", "c", "d", "e", "f"], indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]], @@ -520,7 +520,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): def test_output_values_with_sign_hash(self): """Verifies the values in a trivial case with hash_signs=True.""" - with self.test_session(): + with self.cached_session(): sp_values = sparse_tensor_lib.SparseTensor( values=["a"], indices=[[1, 0]], dense_shape=[3, 1]) params = constant_op.constant([.1, .1, .1]) @@ -537,7 +537,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): def test_distributive_property(self): """Verifies the distributive property of matrix multiplication.""" - with self.test_session(): + with self.cached_session(): params = constant_op.constant([.1, .2, .3]) sp_values_a = sparse_tensor_lib.SparseTensor( values=["a"], indices=[[0, 0]], dense_shape=[3, 1]) @@ -710,7 +710,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], [True, False]): - with self.test_session(): + with self.cached_session(): p, params, feed_dict = _EmbeddingParams( num_shards, vocab_size, shape=param_shape, dtype=dtype) embedding_sum = \ @@ -749,7 +749,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): for num_shards, combiner, dtype, ignore_weights in itertools.product( [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64], [True, False]): - with self.test_session(): + with 
self.cached_session(): x, params, _ = _EmbeddingParams( num_shards, vocab_size, shape=param_shape, dtype=dtype) @@ -767,7 +767,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3) def testIncompatibleShapes(self): - with self.test_session(): + with self.cached_session(): x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32) sp_ids = sparse_tensor_lib.SparseTensor( constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64), diff --git a/tensorflow/contrib/layers/python/layers/encoders_test.py b/tensorflow/contrib/layers/python/layers/encoders_test.py index e8528e9890..1a2aa710d5 100644 --- a/tensorflow/contrib/layers/python/layers/encoders_test.py +++ b/tensorflow/contrib/layers/python/layers/encoders_test.py @@ -34,14 +34,14 @@ def _get_const_var(name, shape, value): class EncodersTest(test.TestCase): def testBowEncoderSparse(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3]] enc = encoders.bow_encoder(docs, 4, 3) sess.run(variables.global_variables_initializer()) self.assertAllEqual([2, 3], enc.eval().shape) def testBowEncoderSparseTensor(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3]] sparse_docs = sparse_ops.dense_to_sparse_tensor(docs) enc = encoders.bow_encoder(sparse_docs, 4, 3) @@ -49,28 +49,28 @@ class EncodersTest(test.TestCase): self.assertAllEqual([2, 3], enc.eval().shape) def testBowEncoderSparseEmptyRow(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3], [0, 0]] enc = encoders.bow_encoder(docs, 4, 5) sess.run(variables.global_variables_initializer()) self.assertAllEqual([3, 5], enc.eval().shape) def testBowEncoderDense(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3], [0, 0], [0, 0]] enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False) 
sess.run(variables.global_variables_initializer()) self.assertAllEqual([4, 3], enc.eval().shape) def testBowEncoderSparseTensorDenseLookup(self): - with self.test_session(): + with self.cached_session(): docs = [[0, 1]] sparse_docs = sparse_ops.dense_to_sparse_tensor(docs) with self.assertRaises(TypeError): encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False) def testBowEncodersSharingEmbeddings(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3]] enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test') enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True) @@ -79,7 +79,7 @@ class EncodersTest(test.TestCase): self.assertAllEqual(avg_1, avg_2) def testBowEncodersSharingEmbeddingsInheritedScopes(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3]] with variable_scope.variable_scope('test'): enc_1 = encoders.bow_encoder(docs, 4, 3) @@ -90,7 +90,7 @@ class EncodersTest(test.TestCase): self.assertAllEqual(avg_1, avg_2) def testBowEncodersSharingEmbeddingsSharedScope(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[0, 1], [2, 3]] enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow') variable_scope.get_variable_scope().reuse_variables() @@ -100,7 +100,7 @@ class EncodersTest(test.TestCase): self.assertAllEqual(avg_1, avg_2) def testBowEncoderReuseEmbeddingsVariable(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[1, 1], [2, 3]] with variable_scope.variable_scope('test'): v = _get_const_var('embeddings', (4, 3), @@ -111,7 +111,7 @@ class EncodersTest(test.TestCase): self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval()) def testEmbedSequence(self): - with self.test_session() as sess: + with self.cached_session() as sess: docs = [[1, 1], [2, 3]] with variable_scope.variable_scope('test'): v = _get_const_var('embeddings', (4, 3), diff --git 
a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index e6bbd86ab7..6fb4b9ff35 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -49,7 +49,7 @@ class TransformerTest(test.TestCase): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} output = feature_column_ops._Transformer(features).transform(real_valued) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(output.eval(), [[20.], [110], [-3]]) def testSparseRealValuedColumnIdentityTransformation(self): @@ -60,7 +60,7 @@ class TransformerTest(test.TestCase): features = {"rating": rating_tensor} output = feature_column_ops._Transformer(features).transform( sparse_real_valued) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(output.values.eval(), rating_tensor.values.eval()) self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval()) self.assertAllEqual(output.dense_shape.eval(), @@ -80,7 +80,7 @@ class TransformerTest(test.TestCase): [sparse_real_valued]) self.assertTrue(sparse_real_valued in output_dict) output = output_dict[sparse_real_valued] - with self.test_session(): + with self.cached_session(): self.assertArrayNear(output.values.eval(), [4.0, 25.0], 1e-5) self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval()) self.assertAllEqual(output.dense_shape.eval(), @@ -97,7 +97,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[bucket]) self.assertEqual(len(output), 1) self.assertIn(bucket, output) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(output[bucket].eval(), [[2], [3], [0]]) def testBucketizedColumnWithMultiDimensions(self): @@ -109,7 +109,7 @@ class TransformerTest(test.TestCase): "price": 
constant_op.constant([[20., 110], [110., 20], [-3, -3]]) } output = feature_column_ops._Transformer(features).transform(bucket) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(output.eval(), [[2, 3], [3, 2], [0, 0]]) def testCachedTransformation(self): @@ -118,7 +118,7 @@ class TransformerTest(test.TestCase): # buckets 2, 3, 0 features = {"price": constant_op.constant([[20.], [110], [-3]])} transformer = feature_column_ops._Transformer(features) - with self.test_session() as sess: + with self.cached_session() as sess: transformer.transform(bucket) num_of_ops = len(sess.graph.get_operations()) # Verify that the second call to transform the same feature @@ -138,7 +138,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[hashed_sparse]) self.assertEqual(len(output), 1) self.assertIn(hashed_sparse, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64) self.assertTrue( all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval())) @@ -161,7 +161,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[hashed_sparse]) self.assertEqual(len(output), 1) self.assertIn(hashed_sparse, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64) self.assertTrue( all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval())) @@ -177,7 +177,7 @@ class TransformerTest(test.TestCase): features = {"wire": wire_tensor} output = feature_column_ops._Transformer(features).transform(hashed_sparse) - with self.test_session(): + with self.cached_session(): # While the input is a dense Tensor, the output should be a SparseTensor. 
self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.values.dtype, dtypes.int64) @@ -203,7 +203,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 2) self.assertIn(hashed_sparse, output) self.assertIn(wire_embedding, output) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(output[wire_embedding].indices.eval(), wire_tensor.indices.eval()) self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2]) @@ -223,7 +223,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[keys_sparse]) self.assertEqual(len(output), 1) self.assertIn(keys_sparse, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) @@ -241,7 +241,7 @@ class TransformerTest(test.TestCase): features = {"wire": wire_tensor} output = feature_column_ops._Transformer(features).transform(keys_sparse) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() # While the input is a dense Tensor, the output should be a SparseTensor. 
self.assertIsInstance(output, sparse_tensor.SparseTensor) @@ -264,7 +264,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[hashed_sparse]) self.assertEqual(len(output), 1) self.assertIn(hashed_sparse, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int32) self.assertTrue( all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval())) @@ -282,7 +282,7 @@ class TransformerTest(test.TestCase): wire_tensor = constant_op.constant([[100, 0], [1, 25]]) features = {"wire": wire_tensor} output = feature_column_ops._Transformer(features).transform(hashed_sparse) - with self.test_session(): + with self.cached_session(): # While the input is a dense Tensor, the output should be a SparseTensor. self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.values.dtype, dtypes.int32) @@ -310,7 +310,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(weighted_ids, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), ids_tensor.dense_shape.eval()) @@ -340,7 +340,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[vocab_sparse]) self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) @@ -362,7 +362,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[vocab_sparse]) self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) 
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) @@ -386,7 +386,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[vocab_sparse]) self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) @@ -408,7 +408,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[vocab_sparse]) self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) @@ -440,7 +440,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[country_language]) self.assertEqual(len(output), 1) self.assertIn(country_language, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[country_language].values.dtype, dtypes.int64) self.assertTrue( all(x < 15 and x >= 0 for x in output[country_language].values.eval( @@ -467,7 +467,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[country_price]) self.assertEqual(len(output), 1) self.assertIn(country_price, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[country_price].values.dtype, dtypes.int64) self.assertTrue( all(x < 15 and x >= 0 for x in output[country_price].values.eval())) @@ -498,7 +498,7 @@ class TransformerTest(test.TestCase): weights = column_to_variable[country_price][0] grad = array_ops.squeeze( gradients_impl.gradients(output, weights)[0].values) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() 
self.assertEqual(len(grad.eval()), 6) @@ -537,7 +537,7 @@ class TransformerTest(test.TestCase): features=features, feature_columns=[wire_country_price]) self.assertEqual(len(output), 1) self.assertIn(wire_country_price, output) - with self.test_session(): + with self.cached_session(): self.assertEqual(output[wire_country_price].values.dtype, dtypes.int64) self.assertTrue( all(x < 15 and x >= 0 for x in output[wire_country_price].values.eval( @@ -600,7 +600,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): columns = [one_hot_column, embedding_column, real_valued_column] output = feature_column_ops.input_from_feature_columns(features, columns) output_core = fc_core.input_layer(features, columns) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) @@ -626,7 +626,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): cols_to_outs = {} feature_column_ops.input_from_feature_columns( features, columns, cols_to_outs=cols_to_outs) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() for column in columns: @@ -637,7 +637,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): features = {"price": constant_op.constant([[20.], [110], [-3]])} output = feature_column_ops.input_from_feature_columns(features, [real_valued]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), features["price"].eval()) # Verify cross compatibility: Core builder output should equal to contrib. 
self.assertAllClose(output.eval(), @@ -650,7 +650,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): } output = feature_column_ops.input_from_feature_columns(features, [real_valued]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), features["price"].eval()) # Verify cross compatibility: Core builder output should equal to contrib. self.assertAllClose(output.eval(), @@ -662,7 +662,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): rating = np.array([[0., 1., 2., -1.], [3., 4., 5., 6.]]) features = {"rating": constant_op.constant(rating)} - with self.test_session() as sess: + with self.cached_session() as sess: output = sess.run(feature_column_ops.input_from_feature_columns( features, [var_len_real_valued])) self.assertAllClose(rating, output) @@ -673,7 +673,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): rating = np.array([[0, 1, 2, -1], [3, 4, 5, 6]]) features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)} - with self.test_session() as sess: + with self.cached_session() as sess: output = sess.run(feature_column_ops.input_from_feature_columns( features, [var_len_real_valued])) self.assertAllClose(rating.astype(np.float32), output) @@ -684,7 +684,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): features = {"price": constant_op.constant([[20.], [110], [-3]])} output = feature_column_ops.input_from_feature_columns(features, [real_valued]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) # Verify cross compatibility: Core builder output should equal to contrib. 
self.assertAllClose(output.eval(), @@ -698,7 +698,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): } output = feature_column_ops.input_from_feature_columns(features, [real_valued]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) # Verify cross compatibility: Core builder output should equal to contrib. self.assertAllClose(output.eval(), @@ -713,7 +713,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): features = {"price": constant_op.constant([[20.], [110], [-3]])} output = feature_column_ops.input_from_feature_columns(features, [bucket]) expected = [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]] - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), expected) self.assertAllClose(output.eval(), fc_core.input_layer(features, [bucket]).eval()) @@ -729,7 +729,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): output = feature_column_ops.input_from_feature_columns(features, [bucket]) expected = [[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0], [1, 0, 0, 0, 1, 0, 0, 0]] - with self.test_session(): + with self.cached_session(): self.assertAllClose(output.eval(), expected) self.assertAllClose(output.eval(), fc_core.input_layer(features, [bucket]).eval()) @@ -752,7 +752,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): output = feature_column_ops.input_from_feature_columns(features, [one_hot_column]) output_core = fc_core.input_layer(features, [one_hot_column]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], @@ -773,7 +773,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) output_core = fc_core.input_layer(features, [one_hot_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() 
lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], @@ -794,7 +794,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) output_core = fc_core.input_layer(features, [one_hot_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], @@ -816,7 +816,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): output = feature_column_ops.input_from_feature_columns(features, [one_hot_sparse]) output_core = fc_core.input_layer(features, [one_hot_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], output.eval()) @@ -834,7 +834,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): output = feature_column_ops.input_from_feature_columns(features, [one_hot_sparse]) output_core = fc_core.input_layer(features, [one_hot_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual([3, 10], output.eval().shape) @@ -852,7 +852,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): output = feature_column_ops.input_from_feature_columns(features, [embeded_sparse]) output_core = fc_core.input_layer(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(output.eval().shape, [4, 10]) # Verify cross compatibility: Core builder output should equal to contrib. 
@@ -878,7 +878,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): features, [embedded_sparse], weight_collections=["my_collection_core"]) weights_core = ops.get_collection("my_collection_core") grad_core = gradients_impl.gradients(output_core, weights_core) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() gradient_values = [] gradient_values_core = [] @@ -907,7 +907,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) output_core = fc_core.input_layer(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() output_eval = output.eval() self.assertAllEqual(output_eval.shape, [2, 10]) @@ -935,7 +935,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Makes sure that trying to use different initializers with the same # embedding column explicitly fails. - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, "Duplicate feature column key found for column: wire_embedding"): @@ -961,7 +961,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) output_core = fc_core.input_layer(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) @@ -986,7 +986,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): embeded_sparse = feature_column.embedding_column(weighted_ids, 10) output = feature_column_ops.input_from_feature_columns(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) @@ -1005,7 +1005,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): embeded_sparse = feature_column.embedding_column(crossed, 10) 
output = feature_column_ops.input_from_feature_columns(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) @@ -1016,7 +1016,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2]) features = {"wire": wire_tensor} - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: wire"): variables_lib.global_variables_initializer().run() @@ -1035,7 +1035,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2]) features = {"ids": ids_tensor, "weights": weights_tensor} - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: ids_weighted_by_weights"): @@ -1053,7 +1053,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2]) features = {"aaa": wire_tensor, "bbb": wire_tensor} - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: aaa_X_bbb"): variables_lib.global_variables_initializer().run() @@ -1080,7 +1080,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): hashed_sparse, 10, initializer=init_ops.constant_initializer(133.7)) output = feature_column_ops.input_from_feature_columns( features, [real_valued, bucket, embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() # size of output = 3 (real_valued) + 2 * 4 (bucket) + 10 (embedding) = 21 self.assertAllEqual(output.eval().shape, [3, 21]) @@ -1099,7 +1099,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): initializer=init_ops.ones_initializer()) output = feature_column_ops.input_from_feature_columns(features, 
[embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() # score: (number of values) self.assertAllEqual(output.eval(), [[1.], [2.], [0.]]) @@ -1119,7 +1119,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): max_norm=0.5) output = feature_column_ops.input_from_feature_columns(features, [embedded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() # score: (number of values * 0.5) self.assertAllClose(output.eval(), [[0.5], [1.], [0.]]) @@ -1144,7 +1144,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): initializer=init_ops.ones_initializer()) output = feature_column_ops.input_from_feature_columns(features, [embeded_sparse]) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() # score: (sum of weights) @@ -1236,7 +1236,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): # There should be one trainable variables for sparse_2 self.assertEqual(1, len(variables_lib.trainable_variables())) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() output_1_eval = output_1.eval() output_2_eval = output_2.eval() @@ -1295,7 +1295,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [measurement_column]) - with self.test_session() as sess: + with self.cached_session() as sess: model_inputs = sess.run(model_input_tensor) self.assertAllClose(measurement_input, model_inputs) @@ -1305,7 +1305,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): rating = np.array([[0., 1., 2., -1.], [3., 4., 5., 6.]]) features = {"rating": constant_op.constant(rating)} - with self.test_session() as sess: + with self.cached_session() as sess: output = sess.run( 
feature_column_ops.sequence_input_from_feature_columns( features, [var_len_real_valued])) @@ -1329,7 +1329,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): expected_shape = [batch_size, sequence_length, np.prod(dimensions)] reshaped_measurements = np.reshape(measurement_input, expected_shape) - with self.test_session() as sess: + with self.cached_session() as sess: model_inputs = sess.run(model_input_tensor) self.assertAllClose(reshaped_measurements, model_inputs) @@ -1350,7 +1350,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [measurement_column]) - with self.test_session() as sess: + with self.cached_session() as sess: model_inputs = sess.run(model_input_tensor) self.assertAllClose(normalizer(measurement_input), model_inputs) @@ -1373,7 +1373,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): expected_shape = [batch_size, sequence_length, np.prod(dimensions)] reshaped_measurements = np.reshape(measurement_input, expected_shape) - with self.test_session() as sess: + with self.cached_session() as sess: model_inputs = sess.run(model_input_tensor) self.assertAllClose(normalizer(reshaped_measurements), model_inputs) @@ -1395,7 +1395,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [one_hot_column]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) @@ -1429,7 +1429,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [one_hot_column]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() 
lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) @@ -1459,7 +1459,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [embedded_column]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) @@ -1488,7 +1488,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_input_tensor = feature_column_ops.sequence_input_from_feature_columns( columns_to_tensors, [embedded_column]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) @@ -1518,7 +1518,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): embedding_weights = ops.get_collection("my_collection") gradient_tensor = gradients_impl.gradients(model_input_tensor, embedding_weights) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) @@ -1585,7 +1585,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): columns_to_tensors, model_input_columns) self.assertEqual(dtypes.float32, model_input_tensor.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) @@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase): logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [hashed_sparse], num_outputs=5) logits_core = fc_core.linear_model(features, [hashed_sparse], units=5) - with self.test_session(): + with 
self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) # Verify cross compatibility: Core builder output should equal to contrib. @@ -1640,7 +1640,7 @@ class WeightedSumTest(test.TestCase): logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [hashed_sparse], num_outputs=5) logits_core = fc_core.linear_model(features, [hashed_sparse], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) # Verify cross compatibility: Core builder output should equal to contrib. @@ -1654,7 +1654,7 @@ class WeightedSumTest(test.TestCase): logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [hashed_sparse], num_outputs=5) logits_core = fc_core.linear_model(features, [hashed_sparse], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) # Verify cross compatibility: Core builder output should equal to contrib. 
@@ -1676,7 +1676,7 @@ class WeightedSumTest(test.TestCase): logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [weighted_ids], num_outputs=5) logits_core = fc_core.linear_model(features, [weighted_ids], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) @@ -1695,7 +1695,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_ids], num_outputs=5) logits_core = fc_core.linear_model(features, [weighted_ids], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) @@ -1716,7 +1716,7 @@ class WeightedSumTest(test.TestCase): logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [crossed], num_outputs=5) logits_core = fc_core.linear_model(features, [crossed], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) # Verify cross compatibility: Core builder output should equal to contrib. 
@@ -1730,7 +1730,7 @@ class WeightedSumTest(test.TestCase): dense_shape=[2, 2]) features = {"wire": wire_tensor} embeded_sparse = feature_column.embedding_column(hashed_sparse, 10) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, "Error creating weighted sum for column: wire_embedding"): variables_lib.global_variables_initializer().run() @@ -1756,7 +1756,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) logits_core = fc_core.linear_model(features, [movies]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.initialize_all_variables().run() lookup_ops.tables_initializer().run() @@ -1776,7 +1776,7 @@ class WeightedSumTest(test.TestCase): } logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [real_valued], num_outputs=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [3, 5]) @@ -1789,7 +1789,7 @@ class WeightedSumTest(test.TestCase): } logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( features, [bucket], num_outputs=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(logits.eval().shape, [3, 5]) @@ -1814,7 +1814,7 @@ class WeightedSumTest(test.TestCase): features, [real_valued, bucket, hashed_sparse, crossed], num_outputs=5) output_core = fc_core.linear_model( features, [real_valued, bucket, hashed_sparse, crossed], units=5) - with self.test_session(): + with self.cached_session(): variables_lib.global_variables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 5]) # Verify cross compatibility: Core builder output should equal to contrib. 
@@ -1837,7 +1837,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, bias = ( feature_column_ops.weighted_sum_from_feature_columns( features, [age, language], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -1877,7 +1877,7 @@ class WeightedSumTest(test.TestCase): features, [country, language], num_outputs=1)) # Assert that only a single weight is created. self.assertEqual(len(variables), 1) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -1941,7 +1941,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, bias = ( feature_column_ops.weighted_sum_from_feature_columns( features, [weighted_language], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -1969,7 +1969,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, bias = ( feature_column_ops.weighted_sum_from_feature_columns( features, [language], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -1992,7 +1992,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [movies], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2026,7 +2026,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country_language], num_outputs=1)) - with self.test_session() as sess: 
+ with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [language_language], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2083,7 +2083,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country_language], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2124,7 +2124,7 @@ class WeightedSumTest(test.TestCase): features, [country, language, country_language], num_outputs=1, scope=scope)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2161,7 +2161,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country, age, incomes], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2197,7 +2197,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country, age, height, incomes], num_outputs=5)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2228,7 +2228,7 @@ class WeightedSumTest(test.TestCase): feature_column_ops.weighted_sum_from_feature_columns( features, 
[bucket], num_outputs=1)) output_core = fc_core.linear_model(features, [bucket]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() # Cross compatibility: Core builder output should equal to contrib. @@ -2259,7 +2259,7 @@ class WeightedSumTest(test.TestCase): feature_column_ops.weighted_sum_from_feature_columns( features, [bucket, country], num_outputs=1)) output_core = fc_core.linear_model(features, [bucket, country]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() # Cross compatibility: Core builder output should equal to contrib. @@ -2290,7 +2290,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [bucket, country], num_outputs=5)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2326,7 +2326,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country_price], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2365,7 +2365,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [country_language_price], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2389,7 +2389,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, 
[product], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] @@ -2404,7 +2404,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [product], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] @@ -2419,7 +2419,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [product], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] @@ -2440,7 +2440,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [product], num_outputs=1)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] @@ -2452,7 +2452,7 @@ class WeightedSumTest(test.TestCase): features = {"age": constant_op.constant([[10.], [20.], [30.], [40.]])} output, _, bias = feature_column_ops.weighted_sum_from_feature_columns( features, [feature_column.real_valued_column("age")], num_outputs=3) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() sess.run(bias.assign([0.1, 0.2, 0.3])) @@ -2466,7 +2466,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( 
feature_column_ops.weighted_sum_from_feature_columns( features, [column], num_outputs=3)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] @@ -2490,7 +2490,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [column], num_outputs=3)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] @@ -2516,7 +2516,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [column], num_outputs=3)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2556,7 +2556,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [column], num_outputs=3)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2585,7 +2585,7 @@ class WeightedSumTest(test.TestCase): output, column_to_variable, _ = ( feature_column_ops.weighted_sum_from_feature_columns( features, [column], num_outputs=3)) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() @@ -2651,7 +2651,7 @@ class ParseExampleTest(test.TestCase): feature_columns=[bucket, wire_cast]) self.assertIn(bucket, output) self.assertIn(wire_cast, output) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() 
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) @@ -2713,7 +2713,7 @@ class ParseExampleTest(test.TestCase): self.assertIn("measurements", seq) self.assertIsInstance(seq["measurements"], ops.Tensor) - with self.test_session() as sess: + with self.cached_session() as sess: location_val, wire_cast_val, measurement_val = sess.run( [ctx["location"], seq["wire_cast"], seq["measurements"]]) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index eaaf9f8d5f..d90d6ecf7f 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -201,7 +201,7 @@ class FeatureColumnTest(test.TestCase): b2 = feature_column_ops.input_from_feature_columns({ b[1]: input_tensor_c2 }, [b[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) b1_value = b1.eval() b2_value = b2.eval() @@ -230,7 +230,7 @@ class FeatureColumnTest(test.TestCase): e1 = feature_column_ops.input_from_feature_columns({ e[0]: input_tensor_c1 }, [e[0]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) d1_value = d1.eval() e1_value = e1.eval() @@ -340,7 +340,7 @@ class FeatureColumnTest(test.TestCase): with variable_scope.variable_scope("output_rank_{}".format(output_rank)): one_hot_output = one_hot._to_dnn_input_layer( id_tensor, output_rank=output_rank) - with self.test_session() as sess: + with self.cached_session() as sess: one_hot_value = sess.run(one_hot_output) expected_shape = (id_tensor_shape[:output_rank - 1] + [vocab_size]) self.assertEquals(expected_shape, list(one_hot_value.shape)) @@ -376,7 +376,7 @@ class FeatureColumnTest(test.TestCase): one_hot_output_shape = one_hot_output.get_shape().as_list() 
expected_shape = id_tensor_shape[:-1] + [vocab_size] self.assertEquals(expected_shape, one_hot_output_shape) - with self.test_session() as sess: + with self.cached_session() as sess: one_hot_value = sess.run(one_hot_output) self.assertEquals(expected_shape, list(one_hot_value.shape)) @@ -399,7 +399,7 @@ class FeatureColumnTest(test.TestCase): expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.], [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]) - with self.test_session() as sess: + with self.cached_session() as sess: one_hot_value = sess.run(one_hot_output) self.assertTrue(np.array_equal(one_hot_value, expected)) @@ -440,7 +440,7 @@ class FeatureColumnTest(test.TestCase): } one_hot_tensor = feature_column_ops.input_from_feature_columns( features, [one_hot]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) self.assertAllEqual([[2., 6., 0.]], one_hot_tensor.eval()) @@ -451,7 +451,7 @@ class FeatureColumnTest(test.TestCase): features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])} one_hot_tensor = feature_column_ops.input_from_feature_columns( features, [one_hot]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) self.assertAllEqual([[1., 1., 0.]], one_hot_tensor.eval()) @@ -603,7 +603,7 @@ class FeatureColumnTest(test.TestCase): real_valued_output = real_valued_column._to_dnn_input_layer( constant_op.constant(real_valued_input, dtype=dtypes.float32), output_rank=output_rank) - with self.test_session() as sess: + with self.cached_session() as sess: real_valued_eval = sess.run(real_valued_output) expected_shape = ( input_shape[:output_rank - 1] + @@ -797,7 +797,7 @@ class FeatureColumnTest(test.TestCase): sparse_column.insert_transformed_feature(features) sparse_output = 
features[sparse_column] expected_shape = [batch_size, 1] - with self.test_session() as sess: + with self.cached_session() as sess: sparse_result = sess.run(sparse_output) self.assertEquals(expected_shape, list(sparse_result.dense_shape)) @@ -1110,7 +1110,7 @@ class FeatureColumnTest(test.TestCase): ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix) checkpoint_path = os.path.join(ckpt_dir, "model.ckpt") - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) saved_embedding = embeddings.eval() save.save(sess, checkpoint_path) @@ -1131,7 +1131,7 @@ class FeatureColumnTest(test.TestCase): embedding_col_initialized: input_tensor }, [embedding_col_initialized]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loaded_embedding = pretrained_embeddings.eval() @@ -1176,7 +1176,7 @@ class FeatureColumnTest(test.TestCase): ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix) checkpoint_path = os.path.join(ckpt_dir, "model.ckpt") - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(assign_op) saved_col_weights = col_weights[crossed_col][0].eval() @@ -1201,7 +1201,7 @@ class FeatureColumnTest(test.TestCase): }, [crossed_col_initialized], 1)) col_weights_from_ckpt = col_weights[crossed_col_initialized][0] - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loaded_col_weights = col_weights_from_ckpt.eval() diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 52c9c4f3be..85af9de4e4 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -281,7 +281,7 @@ class BiasAddTest(test.TestCase): def testCreate(self): height, width = 3, 3 - with 
self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.bias_add(images) self.assertEqual(output.op.name, 'BiasAdd/BiasAdd') @@ -289,7 +289,7 @@ class BiasAddTest(test.TestCase): def testCreateWithActivation(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = _layers.bias_add(images, activation_fn=nn_ops.relu) self.assertEqual(output.op.name, 'BiasAdd/Relu') @@ -298,7 +298,7 @@ class BiasAddTest(test.TestCase): def testCreateDimensions(self): dims = (2, 3, 4) shape = [5, 2, 3, 4] - with self.test_session(): + with self.cached_session(): for d in dims: input_shape = shape[:d] inputs = random_ops.random_uniform(input_shape, seed=1) @@ -311,7 +311,7 @@ class BiasAddTest(test.TestCase): class ConvolutionTest(test.TestCase): def testInvalidShape(self): - with self.test_session(): + with self.cached_session(): images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1) with self.assertRaisesRegexp( ValueError, 'Convolution expects input with rank 5, got 4'): @@ -323,14 +323,14 @@ class ConvolutionTest(test.TestCase): def testInvalidDataFormat(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) with self.assertRaisesRegexp(ValueError, 'data_format'): layers_lib.convolution2d(images, 32, 3, data_format='CHWN') def testCreateConv(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) output = layers_lib.convolution2d(images, 32, [3, 3]) self.assertEqual(output.op.name, 'Conv/Relu') @@ -342,7 +342,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvNCHW(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, 4, height, 
width)).astype(np.float32) output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW') self.assertEqual(output.op.name, 'Conv/Relu') @@ -354,7 +354,7 @@ class ConvolutionTest(test.TestCase): def testCreateSquareConv(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, 3) self.assertEqual(output.op.name, 'Conv/Relu') @@ -362,7 +362,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvWithTensorShape(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, images.get_shape()[1:3]) self.assertEqual(output.op.name, 'Conv/Relu') @@ -370,7 +370,7 @@ class ConvolutionTest(test.TestCase): def testCreateFullyConv(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 32), seed=1) output = layers_lib.convolution2d( images, 64, images.get_shape()[1:3], padding='VALID') @@ -381,7 +381,7 @@ class ConvolutionTest(test.TestCase): def testFullyConvWithCustomGetter(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): called = [0] def custom_getter(getter, *args, **kwargs): @@ -395,7 +395,7 @@ class ConvolutionTest(test.TestCase): def testCreateVerticalConv(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 4), seed=1) output = layers_lib.convolution2d(images, 32, [3, 1]) self.assertEqual(output.op.name, 'Conv/Relu') @@ -407,7 +407,7 @@ class ConvolutionTest(test.TestCase): def testCreateHorizontalConv(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 4), seed=1) output = 
layers_lib.convolution2d(images, 32, [1, 3]) self.assertEqual(output.op.name, 'Conv/Relu') @@ -417,7 +417,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvWithStride(self): height, width = 6, 8 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, [3, 3], stride=2) self.assertEqual(output.op.name, 'Conv/Relu') @@ -427,7 +427,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvCreatesWeightsAndBiasesVars(self): height, width = 7, 9 images = random_ops.random_uniform((5, height, width, 3), seed=1) - with self.test_session(): + with self.cached_session(): self.assertFalse(variables.get_variables('conv1/weights')) self.assertFalse(variables.get_variables('conv1/biases')) layers_lib.convolution2d(images, 32, [3, 3], scope='conv1') @@ -436,7 +436,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvWithScope(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, [3, 3], scope='conv1') self.assertEqual(output.op.name, 'conv1/Relu') @@ -453,14 +453,14 @@ class ConvolutionTest(test.TestCase): def testCreateConvWithoutActivation(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, [3, 3], activation_fn=None) self.assertEqual(output.op.name, 'Conv/BiasAdd') def testCreateConvValid(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.convolution2d(images, 32, [3, 3], padding='VALID') self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32]) @@ -468,7 +468,7 @@ class ConvolutionTest(test.TestCase): def 
testCreateConvWithWD(self): height, width = 7, 9 weight_decay = 0.01 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform((5, height, width, 3), seed=1) regularizer = regularizers.l2_regularizer(weight_decay) layers_lib.convolution2d( @@ -481,7 +481,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvNoRegularizers(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) layers_lib.convolution2d(images, 32, [3, 3]) self.assertEqual( @@ -489,7 +489,7 @@ class ConvolutionTest(test.TestCase): def testReuseVars(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) layers_lib.convolution2d(images, 32, [3, 3], scope='conv1') self.assertEqual(len(variables.get_variables()), 2) @@ -498,7 +498,7 @@ class ConvolutionTest(test.TestCase): def testNonReuseVars(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) layers_lib.convolution2d(images, 32, [3, 3]) self.assertEqual(len(variables.get_variables()), 2) @@ -507,7 +507,7 @@ class ConvolutionTest(test.TestCase): def testReuseConvWithWD(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) weight_decay = regularizers.l2_regularizer(0.01) with arg_scope( @@ -523,7 +523,7 @@ class ConvolutionTest(test.TestCase): def testConvWithBatchNorm(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 32), seed=1) with arg_scope( [layers_lib.convolution2d], @@ -539,7 +539,7 @@ class ConvolutionTest(test.TestCase): def testReuseConvWithBatchNorm(self): height, width = 7, 9 - with self.test_session(): + with 
self.cached_session(): images = random_ops.random_uniform((5, height, width, 32), seed=1) with arg_scope( [layers_lib.convolution2d], @@ -557,7 +557,7 @@ class ConvolutionTest(test.TestCase): def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self): height, width = 7, 9 images = random_ops.random_uniform((5, height, width, 3), seed=1) - with self.test_session(): + with self.cached_session(): self.assertFalse(variables.get_variables('conv1/weights')) self.assertFalse(variables.get_variables('conv1/biases')) layers_lib.convolution2d(images, 32, [3, 3], rate=2, scope='conv1') @@ -573,7 +573,7 @@ class ConvolutionTest(test.TestCase): output = layers_lib.convolution2d( images, num_filters, [3, 3], rate=2, padding='SAME') self.assertListEqual(list(output.get_shape().as_list()), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -587,7 +587,7 @@ class ConvolutionTest(test.TestCase): output = layers_lib.convolution2d( images, num_filters, [3, 3], rate=2, padding='VALID') self.assertListEqual(list(output.get_shape().as_list()), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -601,7 +601,7 @@ class ConvolutionTest(test.TestCase): output = layers_lib.convolution2d( images, num_filters, [3, 3], rate=[2, 3], padding='VALID') self.assertListEqual(list(output.get_shape().as_list()), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEquals(output.op.name, 'Conv/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -612,7 +612,7 @@ class 
ConvolutionTest(test.TestCase): expected_size = [None, None, None, num_filters] expected_size_dynamic = [5, 7, 9, num_filters] - with self.test_session(): + with self.cached_session(): images = array_ops.placeholder(np.float32, [None, None, None, input_size[3]]) output = layers_lib.convolution2d( @@ -651,7 +651,7 @@ class ConvolutionTest(test.TestCase): expected_size = [None, None, None, num_filters] expected_size_dynamic = [5, 5, 7, num_filters] - with self.test_session(): + with self.cached_session(): images = array_ops.placeholder(np.float32, [None, None, None, input_size[3]]) output = layers_lib.convolution2d( @@ -670,7 +670,7 @@ class ConvolutionTest(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.convolution2d( images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'conv7/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -688,7 +688,7 @@ class ConvolutionTest(test.TestCase): padding='VALID', activation_fn=None, scope='conv7') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'conv7/BiasAdd') self.assertListEqual(list(output.eval().shape), expected_size) @@ -712,7 +712,7 @@ class Convolution2dTransposeTests(test.TestCase): def testInvalidDataFormat(self): height, width = 7, 9 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) with self.assertRaisesRegexp( ValueError, 'data_format has to be either NCHW or NHWC.'): @@ -915,7 +915,7 @@ class Convolution2dTransposeTests(test.TestCase): images, num_filters, [3, 3], stride=1, padding='SAME') self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertListEqual(list(output.eval().shape), expected_size) @@ -929,7 +929,7 @@ class Convolution2dTransposeTests(test.TestCase): images, num_filters, [3, 3], stride=1, padding='VALID') self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertListEqual(list(output.eval().shape), expected_size) @@ -944,7 +944,7 @@ class Convolution2dTransposeTests(test.TestCase): self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.get_shape().as_list()), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertListEqual(list(output.eval().shape), expected_size) @@ -958,7 +958,7 @@ class Convolution2dTransposeTests(test.TestCase): images, num_filters, [2, 2], stride=[2, 2], padding='SAME') self.assertListEqual(list(output.get_shape().as_list()), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -971,7 +971,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.conv2d_transpose( images, num_filters, [2, 2], stride=[2, 2], padding='VALID') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -984,7 +984,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = 
layers_lib.conv2d_transpose( images, num_filters, [2, 2], stride=[2, 2], padding='SAME') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -997,7 +997,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.conv2d_transpose( images, num_filters, [2, 2], stride=[2, 2], padding='VALID') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -1010,7 +1010,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.conv2d_transpose( images, num_filters, [2, 4], stride=[2, 1], padding='VALID') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -1023,7 +1023,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.conv2d_transpose( images, num_filters, [2, 4], stride=[2, 4], padding='VALID') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -1036,7 +1036,7 @@ class Convolution2dTransposeTests(test.TestCase): images = random_ops.random_uniform(input_size, seed=1) output = layers_lib.conv2d_transpose( images, num_filters, [2, 4], stride=[2, 5], padding='VALID') - with 
self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') self.assertListEqual(list(output.eval().shape), expected_size) @@ -1083,7 +1083,7 @@ class Convolution2dTransposeTests(test.TestCase): images, num_filters, [3, 3], stride=[2, 2], padding='VALID') self.assertListEqual(output.get_shape().as_list(), expected_size) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertEqual(output.op.name, 'Conv2d_transpose/Relu') eval_output = output.eval({images: np.zeros(input_size, np.float32)}) @@ -1095,7 +1095,7 @@ class Convolution2dTransposeTests(test.TestCase): expected_size = [None, None, None, num_filters] expected_size_dynamic = [5, 18, 22, num_filters] - with self.test_session(): + with self.cached_session(): images = array_ops.placeholder(np.float32, [None, None, None, input_size[3]]) output = layers_lib.conv2d_transpose( @@ -1116,7 +1116,7 @@ class Convolution2dTransposeTests(test.TestCase): images, num_filters, [3, 3], stride=2, padding='VALID', scope='conv7') self.assertEqual(output.op.name, 'conv7/Relu') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertListEqual(list(output.eval().shape), expected_size) @@ -1135,7 +1135,7 @@ class Convolution2dTransposeTests(test.TestCase): scope='conv7') self.assertEqual(output.op.name, 'conv7/BiasAdd') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertListEqual(list(output.eval().shape), expected_size) @@ -1146,7 +1146,7 @@ class Convolution2dTransposeTests(test.TestCase): stride = 2 padding = 'VALID' - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform(input_size, seed=1) output_deconv = 
layers_lib.conv2d_transpose( images, @@ -1184,7 +1184,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(horz_gradients) expected = np.zeros((1, 10, 9, 1)) @@ -1201,7 +1201,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run( horz_gradients, feed_dict={ @@ -1225,7 +1225,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(horz_gradients) @@ -1245,7 +1245,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(horz_gradients) @@ -1267,7 +1267,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(horz_gradients) @@ -1283,7 +1283,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(vert_gradients) expected = np.zeros((1, 9, 10, 1)) @@ -1306,7 +1306,7 @@ class ConvolutionInPlaneTest(test.TestCase): activation_fn=None) init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) result = sess.run(vert_gradients) @@ 
-1314,7 +1314,7 @@ class ConvolutionInPlaneTest(test.TestCase): def testConv1dShape(self): width = 7 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, width, 3), seed=1) output = layers_lib.convolution1d(images, 32, 3) self.assertEqual(output.op.name, 'Conv/Relu') @@ -1322,7 +1322,7 @@ class ConvolutionInPlaneTest(test.TestCase): def testConvInferSpatialDims(self): depth, height, width = 7, 9, 11 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, width, 4)).astype(np.float32) output = layers_lib.convolution(images, 32, [3]) self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) @@ -1344,7 +1344,7 @@ class DenseToSparseTest(test.TestCase): sparse = _layers.dense_to_sparse(tensor) dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, sparse.values) - with self.test_session() as sess: + with self.cached_session() as sess: constant = sess.run(dense) self.assertAllEqual(expected_constant, constant) @@ -1353,7 +1353,7 @@ class DropoutTest(test.TestCase): def testCreateDropout(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') @@ -1362,7 +1362,7 @@ class DropoutTest(test.TestCase): def testCreateDropoutWithConstantTrue(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): is_training = constant_op.constant(True) images = random_ops.random_uniform((5, height, width, 3), seed=1) output = _layers.dropout(images, is_training=is_training) @@ -1370,7 +1370,7 @@ class DropoutTest(test.TestCase): def testCreateDropoutWithConstantFalse(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): is_training = constant_op.constant(False) images = random_ops.random_uniform((5, height, width, 3), seed=1) output = 
_layers.dropout(images, is_training=is_training) @@ -1378,7 +1378,7 @@ class DropoutTest(test.TestCase): def testCreateDropoutWithPlaceholder(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): is_training = array_ops.placeholder(dtype=dtypes.bool, shape=[]) images = random_ops.random_uniform((5, height, width, 3), seed=1) output = _layers.dropout(images, is_training=is_training) @@ -1387,7 +1387,7 @@ class DropoutTest(test.TestCase): def testCollectOutputs(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = _layers.dropout(images, outputs_collections='outputs') c_output = ops.get_collection('outputs')[0] @@ -1396,7 +1396,7 @@ class DropoutTest(test.TestCase): def testDropout(self): height, width = 10, 10 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) @@ -1409,7 +1409,7 @@ class DropoutTest(test.TestCase): def testDropoutSeed(self): """Test that providing the same seed produces the same result.""" height, width = 10, 10 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output1 = _layers.dropout(images, seed=1) @@ -1418,7 +1418,7 @@ class DropoutTest(test.TestCase): def testCreateDropoutNoTraining(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) @@ -1431,7 +1431,7 @@ class DropoutTest(test.TestCase): def testCreateFCFollowByDropout(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = 
random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected(images, 50) @@ -1445,7 +1445,7 @@ class DropoutTest(test.TestCase): def testCreateFCWithDropout(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected( @@ -1475,7 +1475,7 @@ class FlattenTest(test.TestCase): def testCollectOutputs(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.flatten(images, outputs_collections='outputs') c_output = ops.get_collection('outputs')[0] @@ -1484,7 +1484,7 @@ class FlattenTest(test.TestCase): def testFlatten4D(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.flatten(images) @@ -1494,7 +1494,7 @@ class FlattenTest(test.TestCase): def testFlatten3D(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width), seed=1, name='images') output = _layers.flatten(images) @@ -1504,7 +1504,7 @@ class FlattenTest(test.TestCase): def testFlattenBatchSize(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') inputs = array_ops.placeholder(dtypes.int32, (None, height, width, 3)) @@ -1516,7 +1516,7 @@ class FlattenTest(test.TestCase): def testUnknownDims(self): height = width = depth = 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, depth), seed=1, name='images') inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None)) @@ -1551,7 +1551,7 @@ class 
PartialFlattenTest(test.TestCase): flattened_t = _layers._inner_flatten(inputs, new_rank) static_shape = flattened_t.get_shape().as_list() self.assertEqual(static_shape, expected_new_shape) - with self.test_session() as sess: + with self.cached_session() as sess: flattened = sess.run(flattened_t) np.testing.assert_array_equal(expected_flattened, flattened) @@ -1571,7 +1571,7 @@ class PartialFlattenTest(test.TestCase): flattened_t = _layers._inner_flatten(inputs_t, new_rank) - with self.test_session() as sess: + with self.cached_session() as sess: flattened = sess.run(flattened_t) np.testing.assert_array_equal(expected_indices, flattened.indices) @@ -1641,7 +1641,7 @@ class FCTest(test.TestCase): def testCreateFCWithScope(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): inputs = random_ops.random_uniform((5, height * width * 3), seed=1) output = _layers.fully_connected(inputs, 32, scope='fc1') self.assertEqual(output.op.name, 'fc1/Relu') @@ -1659,7 +1659,7 @@ class FCTest(test.TestCase): def testCreateFcCreatesWeightsAndBiasesVars(self): height, width = 3, 3 inputs = random_ops.random_uniform((5, height * width * 3), seed=1) - with self.test_session(): + with self.cached_session(): self.assertFalse(variables.get_variables('fc1/weights')) self.assertFalse(variables.get_variables('fc1/biases')) _layers.fully_connected(inputs, 32, scope='fc1') @@ -1669,7 +1669,7 @@ class FCTest(test.TestCase): def testReuseVars(self): height, width = 3, 3 inputs = random_ops.random_uniform((5, height * width * 3), seed=1) - with self.test_session(): + with self.cached_session(): _layers.fully_connected(inputs, 32, scope='fc1') self.assertEqual(len(variables.get_variables('fc1')), 2) _layers.fully_connected(inputs, 32, scope='fc1', reuse=True) @@ -1678,7 +1678,7 @@ class FCTest(test.TestCase): def testNonReuseVars(self): height, width = 3, 3 inputs = random_ops.random_uniform((5, height * width * 3), seed=1) - with self.test_session(): + with 
self.cached_session(): _layers.fully_connected(inputs, 32) self.assertEqual(len(variables.get_variables('fully_connected')), 2) _layers.fully_connected(inputs, 32) @@ -1713,14 +1713,14 @@ class FCTest(test.TestCase): def testCreateFCWithoutActivation(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): inputs = random_ops.random_uniform((5, height * width * 3), seed=1) output = _layers.fully_connected(inputs, 32, activation_fn=None) self.assertEqual(output.op.name, 'fully_connected/BiasAdd') def testCreateFCWithWD(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: inputs = random_ops.random_uniform((5, height * width * 3), seed=1) weight_decay = regularizers.l2_regularizer(0.01) _layers.fully_connected(inputs, 32, weights_regularizer=weight_decay) @@ -1732,7 +1732,7 @@ class FCTest(test.TestCase): def testCreateFCWithBD(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: inputs = random_ops.random_uniform((5, height * width * 3), seed=1) bias_decay = regularizers.l2_regularizer(0.01) _layers.fully_connected(inputs, 32, biases_regularizer=bias_decay) @@ -1744,7 +1744,7 @@ class FCTest(test.TestCase): def testCreateNoRegularizers(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): inputs = random_ops.random_uniform((5, height * width * 3), seed=1) _layers.fully_connected(inputs, 32) self.assertEqual( @@ -1752,7 +1752,7 @@ class FCTest(test.TestCase): def testReuseFCWithWD(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): inputs = random_ops.random_uniform((5, height * width * 3), seed=1) weight_decay = regularizers.l2_regularizer(0.01) _layers.fully_connected( @@ -1768,7 +1768,7 @@ class FCTest(test.TestCase): def testFCWithBatchNorm(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height * 
width * 3), seed=1) with arg_scope( [_layers.fully_connected], @@ -1786,7 +1786,7 @@ class FCTest(test.TestCase): def testReuseFCWithBatchNorm(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height * width * 3), seed=1) with arg_scope( [_layers.fully_connected], @@ -1844,7 +1844,7 @@ class BatchNormTest(test.TestCase): if dtype is None: dtype = dtypes.float32 height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)).astype( dtype.as_numpy_dtype) output = _layers.batch_norm(images, fused=fused) @@ -1866,7 +1866,7 @@ class BatchNormTest(test.TestCase): def _testCreateOpBetaRegularizer(self, fused=True): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused) @@ -1883,7 +1883,7 @@ class BatchNormTest(test.TestCase): def _testCreateOpGammaRegularizer(self, fused=True): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') _layers.batch_norm( @@ -1901,7 +1901,7 @@ class BatchNormTest(test.TestCase): def testCreateVariables(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.batch_norm(images, scale=True) beta = variables.get_variables_by_name('beta')[0] @@ -1915,7 +1915,7 @@ class BatchNormTest(test.TestCase): def testMovingAverageVariables(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.batch_norm(images, scale=True) 
self.assertEqual(len(variables.get_model_variables()), 4) @@ -1926,7 +1926,7 @@ class BatchNormTest(test.TestCase): def testMovingAverageVariablesZeroDebias(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.batch_norm( images, scale=True, zero_debias_moving_mean=True, fused=False) @@ -1943,7 +1943,7 @@ class BatchNormTest(test.TestCase): def testUpdatesCollection(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.batch_norm(images, updates_collections='my_update_ops') update_layers = ops.get_collection('my_update_ops') @@ -1971,7 +1971,7 @@ class BatchNormTest(test.TestCase): def testReuseVariables(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.batch_norm(images, scale=True, scope='bn') _layers.batch_norm(images, scale=True, scope='bn', reuse=True) @@ -1986,7 +1986,7 @@ class BatchNormTest(test.TestCase): def testReuseUpdateOps(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) with arg_scope([_layers.batch_norm], updates_collections='update_ops'): _layers.batch_norm(images, scope='bn') @@ -1996,7 +1996,7 @@ class BatchNormTest(test.TestCase): def testCreateMovingVars(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _ = _layers.batch_norm(images) moving_mean = variables.get_variables('BatchNorm/moving_mean') @@ -2029,7 +2029,7 @@ class BatchNormTest(test.TestCase): moving_variance = variables.get_variables_by_name('moving_variance')[0] biased = variables.get_variables_by_name('biased')[0] local_step = 
variables.get_variables_by_name('local_step')[0] - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) self.assertAllClose(local_step.eval(), 0) self.assertAllClose(moving_mean.eval(), [0] * channels) @@ -2213,7 +2213,7 @@ class BatchNormTest(test.TestCase): def _testEvalMovingVars(self, zero_debias_moving_mean=False): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: image_shape = (10, height, width, 3) image_values = np.random.rand(*image_shape) expected_mean = np.mean(image_values, axis=(0, 1, 2)) @@ -2264,7 +2264,7 @@ class BatchNormTest(test.TestCase): height, width = 3, 3 batch_size = 10 channels = 3 - with self.test_session() as sess: + with self.cached_session() as sess: image_shape = (batch_size, height, width, channels) image_values = np.random.rand(*image_shape) expected_mean = np.mean(image_values, axis=(0, 1, 2)) @@ -2435,7 +2435,7 @@ class BatchNormTest(test.TestCase): def testNoUpdatesWhenIsTrainingFalse(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: image_shape = (10, height, width, 3) image_values = np.random.rand(*image_shape) images = constant_op.constant( @@ -2460,7 +2460,7 @@ class BatchNormTest(test.TestCase): def testNoneUpdatesCollectionNoTraining(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: image_shape = (10, height, width, 3) image_values = np.random.rand(*image_shape) images = constant_op.constant( @@ -2647,7 +2647,7 @@ class BatchNormTest(test.TestCase): def testCustomInitializer(self): height, width = 3, 3 channels = 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = (np.ones((5, height, width, channels)) * 9.0).astype('f') beta = init_ops.constant_initializer( (np.ones(channels) * 5.0).astype('f')) @@ -2728,7 +2728,7 @@ class BatchNormTest(test.TestCase): def 
testBatchNormBeta(self): # Test case for 11673 - with self.test_session() as sess: + with self.cached_session() as sess: a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) _layers.batch_norm( a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True) @@ -2739,7 +2739,7 @@ class BatchNormTest(test.TestCase): def testVariablesAreFloat32(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, dtype=dtypes.float16) _layers.batch_norm(images, scale=True) @@ -2824,7 +2824,7 @@ class LayerNormTest(test.TestCase): def testCreateOp(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.layer_norm(images) self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm')) @@ -2832,7 +2832,7 @@ class LayerNormTest(test.TestCase): def testCreateVariables(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.layer_norm(images) beta = variables.get_variables_by_name('beta')[0] @@ -2842,7 +2842,7 @@ class LayerNormTest(test.TestCase): def testReuseVariables(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) _layers.layer_norm(images, scope='ln') _layers.layer_norm(images, scope='ln', reuse=True) @@ -2853,7 +2853,7 @@ class LayerNormTest(test.TestCase): def testReuseVars(self): height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: image_shape = (10, height, width, 3) image_values = np.random.rand(*image_shape) images = constant_op.constant( @@ -2940,7 +2940,7 @@ class GDNTest(test.TestCase): def _runGDN(self, x, shape, inverse, data_format): inputs = array_ops.placeholder(dtypes.float32, shape) outputs = 
_layers.gdn(inputs, inverse=inverse, data_format=data_format) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() y, = sess.run([outputs], {inputs: x}) return y @@ -3152,14 +3152,14 @@ class MaxPool3DTest(test.TestCase): class OneHotEncodingTest(test.TestCase): def testOneHotEncodingCreate(self): - with self.test_session(): + with self.cached_session(): labels = np.array([0, 1, 2]) output = _layers.one_hot_encoding(labels, num_classes=3) self.assertEqual(output.op.name, 'OneHotEncoding/one_hot') self.assertListEqual(output.get_shape().as_list(), [3, 3]) def testCollectOutputs(self): - with self.test_session(): + with self.cached_session(): labels = constant_op.constant([0, 1, 2]) output = _layers.one_hot_encoding( labels, num_classes=3, outputs_collections='outputs') @@ -3168,14 +3168,14 @@ class OneHotEncodingTest(test.TestCase): self.assertEqual(c_output, output) def testOneHotEncoding(self): - with self.test_session(): + with self.cached_session(): labels = constant_op.constant([0, 1, 2]) one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) output = _layers.one_hot_encoding(labels, num_classes=3) self.assertAllClose(output.eval(), one_hot_labels.eval()) def testOneHotEncodingInt32(self): - with self.test_session(): + with self.cached_session(): labels = constant_op.constant([0, 1, 2], dtype=dtypes.int32) one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) output = _layers.one_hot_encoding(labels, num_classes=3) @@ -3186,7 +3186,7 @@ class RepeatTests(test.TestCase): def testRepeat(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32) output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3]) self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu') @@ -3194,7 +3194,7 @@ class RepeatTests(test.TestCase): def testRepeatWithScope(self): 
height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.repeat( @@ -3207,7 +3207,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvInt32(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, dtype=dtypes.int32, maxval=12345) with self.assertRaisesRegexp(TypeError, 'non-floating point type'): @@ -3215,7 +3215,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvFloat32(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, dtype=dtypes.float32) output = layers_lib.separable_conv2d(images, 32, [3, 3], 2) @@ -3224,7 +3224,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateDepthwiseConv(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d(images, None, [3, 3], 2) self.assertEqual(output.op.name, 'SeparableConv2d/Relu') @@ -3233,7 +3233,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvCreatesWeightsAndBiasesVars(self): height, width = 3, 3 images = random_ops.random_uniform((5, height, width, 3), seed=1) - with self.test_session(): + with self.cached_session(): self.assertFalse(variables.get_variables('conv1/depthwise_weights')) self.assertFalse(variables.get_variables('conv1/pointwise_weights')) self.assertFalse(variables.get_variables('conv1/biases')) @@ -3245,7 +3245,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateAtrousConvCreatesWeightsAndBiasesVars(self): height, width = 3, 3 images = random_ops.random_uniform((5, height, width, 3), seed=1) - with self.test_session(): + with self.cached_session(): 
self.assertFalse(variables.get_variables('conv1/depthwise_weights')) self.assertFalse(variables.get_variables('conv1/pointwise_weights')) self.assertFalse(variables.get_variables('conv1/biases')) @@ -3257,7 +3257,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateDepthwiseConvCreatesWeightsAndBiasesVars(self): height, width = 3, 3 images = random_ops.random_uniform((5, height, width, 3), seed=1) - with self.test_session(): + with self.cached_session(): self.assertFalse(variables.get_variables('conv1/depthwise_weights')) self.assertFalse(variables.get_variables('conv1/pointwise_weights')) self.assertFalse(variables.get_variables('conv1/biases')) @@ -3268,14 +3268,14 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvWithScope(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d(images, 32, [3, 3], 6, scope='conv1') self.assertEqual(output.op.name, 'conv1/Relu') def testCreateConvWithoutActivation(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d( images, 32, [3, 3], 8, activation_fn=None) @@ -3283,7 +3283,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvValid(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d( images, 32, [3, 3], 2, padding='VALID') @@ -3291,7 +3291,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateAtrousConvValid(self): height, width = 5, 5 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d( images, 32, [3, 3], 2, padding='VALID', rate=2) @@ -3299,7 +3299,7 @@ class 
SeparableConv2dTest(test.TestCase): def testCreateDepthwiseConvValid(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d( images, None, [3, 3], 2, padding='VALID') @@ -3307,7 +3307,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateAtrousDepthwiseConvValid(self): height, width = 5, 5 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) output = layers_lib.separable_conv2d( images, None, [3, 3], 2, padding='VALID', rate=2) @@ -3316,7 +3316,7 @@ class SeparableConv2dTest(test.TestCase): def testCreateConvWithWeightDecay(self): random_seed.set_random_seed(0) height, width = 3, 3 - with self.test_session() as sess: + with self.cached_session() as sess: images = random_ops.random_uniform((5, height, width, 3), seed=1) regularizer = regularizers.l2_regularizer(0.01) layers_lib.separable_conv2d( @@ -3360,7 +3360,7 @@ class SeparableConv2dTest(test.TestCase): def testReuseConvWithWeightDecay(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) regularizer = regularizers.l2_regularizer(0.01) layers_lib.separable_conv2d( @@ -3419,7 +3419,7 @@ class SeparableConv2dTest(test.TestCase): normalizer_params={}, scope='conv1') init_op = variables_lib.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: images = np.random.rand(5, height, width, 3) sess.run(init_op) sess.run(net, feed_dict={images_placeholder: images}) @@ -3440,7 +3440,7 @@ class SeparableConv2dTest(test.TestCase): def testSepConvNCHW(self): for num_filters, correct_output_filters in zip((None, 5), (6, 5)): - with self.test_session(): + with self.cached_session(): batch, height, width = 4, 10, 12 kernel_dim, stride = 3, 2 images = 
random_ops.random_uniform((batch, 3, height, width), seed=1) @@ -3462,7 +3462,7 @@ class ScaleGradientTests(test.TestCase): """Simple tests of the scale_gradient function.""" def testBasic(self): - with self.test_session(): + with self.cached_session(): x = np.array([42], np.float32) gradient_scale = np.array([2], np.float32) @@ -3513,7 +3513,7 @@ class SoftmaxTests(test.TestCase): exp_prediction = np.array([[self.low, self.high], [0.5, 0.5], [self.high, self.low]]) - with self.test_session() as sess: + with self.cached_session() as sess: prediction = sess.run(prediction) self.assertAllClose(exp_prediction, prediction) @@ -3529,7 +3529,7 @@ class SoftmaxTests(test.TestCase): exp_prediction[1, 1, 1] = self.low prediction = _layers.softmax(logits) - with self.test_session() as sess: + with self.cached_session() as sess: prediction = sess.run(prediction) self.assertAllClose(exp_prediction, prediction) @@ -3547,7 +3547,7 @@ class SoftmaxTests(test.TestCase): exp_prediction[1, 1, 1] = self.low prediction = _layers.softmax(logit_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: prediction = sess.run(prediction, feed_dict=feed_dict) self.assertAllClose(exp_prediction, prediction) @@ -3575,7 +3575,7 @@ class SpatialSoftmaxTests(test.TestCase): features = array_ops.placeholder(dtypes.float32, shape=batch_shape) np_features = np.zeros(batch_shape, dtype=np.float32) spatial_softmax = _layers.spatial_softmax(features) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3586,7 +3586,7 @@ class SpatialSoftmaxTests(test.TestCase): features = array_ops.placeholder(dtypes.float32, shape=batch_shape) np_features = np.zeros(batch_shape, dtype=np.float32) spatial_softmax = _layers.spatial_softmax(features, data_format='NCHW') - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3613,7 +3613,7 @@ class SpatialSoftmaxTests(test.TestCase): nchannels) # Make sure expected location keypoints matches actual location keypoints. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3637,7 +3637,7 @@ class SpatialSoftmaxTests(test.TestCase): nchannels) # Make sure expected location keypoints matches actual location keypoints. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3669,7 +3669,7 @@ class SpatialSoftmaxTests(test.TestCase): batch_size, nchannels) # Make sure expected location keypoints matches actual location keypoints. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features1} tf_keypoints1 = sess.run(spatial_softmax, feed_dict) @@ -3696,7 +3696,7 @@ class SpatialSoftmaxTests(test.TestCase): nchannels) # Make sure expected location keypoints matches actual location keypoints. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3719,7 +3719,7 @@ class SpatialSoftmaxTests(test.TestCase): nchannels) # Make sure expected location keypoints matches actual location keypoints. 
- with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) @@ -3731,7 +3731,7 @@ class SpatialSoftmaxTests(test.TestCase): spatial_softmax = _layers.spatial_softmax(features) net = _layers.fully_connected(spatial_softmax, 10) np_features = np.zeros(batch_shape, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} sess.run(net, feed_dict) @@ -3741,7 +3741,7 @@ class StackTests(test.TestCase): def testStackFullyConnected(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = np.random.uniform(size=(5, height * width * 3)) output = _layers.stack(images, _layers.fully_connected, [10, 20, 30]) self.assertEqual(output.op.name, 'Stack/fully_connected_3/Relu') @@ -3749,7 +3749,7 @@ class StackTests(test.TestCase): def testStackFullyConnectedFailOnReuse(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope('test', reuse=True): images = np.random.uniform(size=(5, height * width * 3)) with self.assertRaises(ValueError): @@ -3757,7 +3757,7 @@ class StackTests(test.TestCase): def testStackRelu(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height * width * 3), seed=1, name='images') output = _layers.stack(images, layers_lib.relu, [10, 20, 30]) @@ -3766,7 +3766,7 @@ class StackTests(test.TestCase): def testStackElu(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height * width * 3), seed=1, name='images') output = _layers.stack(images, layers_lib.elu, [10, 20, 30]) @@ -3775,7 +3775,7 @@ class StackTests(test.TestCase): def 
testStackConvolution2d(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.stack( @@ -3788,7 +3788,7 @@ class StackTests(test.TestCase): def testStackWithScope(self): height, width = 3, 3 - with self.test_session(): + with self.cached_session(): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.stack( @@ -3817,7 +3817,7 @@ class UnitNormTests(test.TestCase): del shape[dim] expected = np.ones(shape) - with self.test_session(): + with self.cached_session(): actual = norms.eval() self.assertAllClose(expected, actual, 1e-4, 1e-4) @@ -3849,7 +3849,7 @@ class UnitNormTests(test.TestCase): norms = math_ops.sqrt( math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) - with self.test_session(): + with self.cached_session(): actual = norms.eval({image: placeholder_value}) self.assertAllClose(expected, actual, 1e-4, 1e-4) @@ -3875,7 +3875,7 @@ class PoincareNormalizeTest(test.TestCase): x_np = np.random.random_sample(x_shape).astype(np.float32) for dim in range(len(x_shape)): y_np = self._PoincareNormalize(x_np, dim, epsilon) - with self.test_session(): + with self.cached_session(): x_tf = constant_op.constant(x_np, name='x') y_tf = _layers.poincare_normalize(x_tf, dim, epsilon) y_tf_eval = y_tf.eval() @@ -3893,7 +3893,7 @@ class PoincareNormalizeTest(test.TestCase): x_np = np.random.random_sample(x_shape).astype(np.float32) dim = [1, 2] y_np = self._PoincareNormalize(x_np, dim, epsilon) - with self.test_session(): + with self.cached_session(): x_tf = constant_op.constant(x_np, name='x') y_tf = _layers.poincare_normalize(x_tf, dim, epsilon) y_tf_eval = y_tf.eval() @@ -3908,7 +3908,7 @@ class PoincareNormalizeTest(test.TestCase): np.random.seed(1) x_np = np.random.random_sample(x_shape).astype(np.float64) for dim in range(len(x_shape)): - with self.test_session(): + with 
self.cached_session(): x_tf = constant_op.constant(x_np, name='x') y_tf = _layers.poincare_normalize(x_tf, dim) err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, @@ -4117,7 +4117,7 @@ class LegacyFullyConnectedTest(test.TestCase): # Empty x is common if someone masks their input with tf.boolean_mask in # order to drop missing entries, and in a particular batch all entries are # missing. - with self.test_session(): + with self.cached_session(): x = np.array([]).reshape(0, 3) self.assertEqual(0, array_ops.size(x).eval()) y = _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax) @@ -4131,7 +4131,7 @@ class LegacyFullyConnectedTest(test.TestCase): y = _layers.legacy_fully_connected(x, 1) # in the output we still only know the 2nd and 3rd dimensions statically. self.assertEqual(y.get_shape().as_list(), [None, 4, 1]) - with self.test_session() as sess: + with self.cached_session() as sess: variables_lib.global_variables_initializer().run() # we can feed in input with first dimension 2 shape_value = sess.run( @@ -4162,7 +4162,7 @@ class LegacyFullyConnectedTest(test.TestCase): self._unknown_dim_invalid_input(last_dim=None) def test_1d_invalid_input(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, 'rank of x must be at least 2 not: 1'): x = constant_op.constant([[]], shape=[0]) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index 55272e5fd1..c8d3c91b10 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -106,7 +106,7 @@ class InstanceNormTest(test.TestCase): images = random_ops.random_uniform(image_shape, seed=1) output_train = normalization.instance_norm(images, scope='IN') output_eval = normalization.instance_norm(images, scope='IN', reuse=True) - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(variables.global_variables_initializer()) # output_train and output_eval should be the same. train_np, eval_np = sess.run([output_train, output_eval]) @@ -130,7 +130,7 @@ class InstanceNormTest(test.TestCase): inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu output_op = normalization.instance_norm( inputs, center=False, scale=False, data_format=data_format) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) outputs = sess.run(output_op) # Make sure that there are no NaNs @@ -287,7 +287,7 @@ class GroupNormTest(test.TestCase): output_train = normalization.group_norm(images, groups=2, scope='IN') output_eval = normalization.group_norm(images, groups=2, scope='IN', reuse=True) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) # output_train and output_eval should be the same. train_np, eval_np = sess.run([output_train, output_eval]) @@ -349,7 +349,7 @@ class GroupNormTest(test.TestCase): channels_axis=channels_axis, reduction_axes=reduction_axes, mean_close_to_zero=mean_close_to_zero) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) outputs = sess.run(output_op) # Make sure that there are no NaNs diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 0f037e24ad..29dede2a49 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -165,7 +165,7 @@ class OptimizersTest(test.TestCase): def testGradientNoise(self): random_seed.set_random_seed(42) - with self.test_session() as session: + with self.cached_session() as session: x, var, loss, global_step = _setup_model() train = optimizers_lib.optimize_loss( loss, @@ -182,7 +182,7 @@ 
class OptimizersTest(test.TestCase): def testGradientNoiseWithClipping(self): random_seed.set_random_seed(42) - with self.test_session() as session: + with self.cached_session() as session: x, var, loss, global_step = _setup_model() train = optimizers_lib.optimize_loss( loss, @@ -198,7 +198,7 @@ class OptimizersTest(test.TestCase): self.assertEqual(global_step_value, 1) def testGradientClip(self): - with self.test_session() as session: + with self.cached_session() as session: x, var, loss, global_step = _setup_model() train = optimizers_lib.optimize_loss( loss, @@ -213,7 +213,7 @@ class OptimizersTest(test.TestCase): self.assertEqual(global_step_value, 1) def testAdaptiveGradientClip(self): - with self.test_session() as session: + with self.cached_session() as session: x, var, loss, global_step = _setup_model() clip_gradients = optimizers_lib.adaptive_clipping_fn() train = optimizers_lib.optimize_loss( @@ -234,7 +234,7 @@ class OptimizersTest(test.TestCase): self.assertEqual(2, var_count) def testGradientMultiply(self): - with self.test_session() as session: + with self.cached_session() as session: x, var, loss, global_step = _setup_model() train = optimizers_lib.optimize_loss( loss, @@ -433,7 +433,7 @@ class OptimizersTest(test.TestCase): class AdaptiveClipping(test.TestCase): def testAverages(self): - with self.test_session() as session: + with self.cached_session() as session: scale = 2. grad = array_ops.ones([3, 4]) * scale log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements())) @@ -463,7 +463,7 @@ class AdaptiveClipping(test.TestCase): self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4) def testClip(self): - with self.test_session() as session: + with self.cached_session() as session: spike = 1000. 
multiplier = array_ops.placeholder(dtypes.float32, [], "multiplier") step = array_ops.placeholder(dtypes.int32, [], "step") diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py index 07191eeda7..51faba30c7 100644 --- a/tensorflow/contrib/layers/python/layers/regularizers_test.py +++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py @@ -71,7 +71,7 @@ class RegularizerTest(test.TestCase): with self.assertRaises(ValueError): regularizers.l1_l2_regularizer(0.5, 0) - with self.test_session(): + with self.cached_session(): shape = [5, 5, 5] num_elem = 5 * 5 * 5 tensor = constant_op.constant(1.0, shape=shape) @@ -84,7 +84,7 @@ class RegularizerTest(test.TestCase): num_elem = 5 * 5 * 5 tensor = constant_op.constant(1.0, shape=shape) loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor) - with self.test_session(): + with self.cached_session(): self.assertEquals(loss.op.name, 'l1_l2_regularizer') self.assertAlmostEqual(loss.eval(), num_elem / 2, 5) @@ -93,7 +93,7 @@ class RegularizerTest(test.TestCase): num_elem = 5 * 5 * 5 tensor = constant_op.constant(1.0, shape=shape) loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor) - with self.test_session(): + with self.cached_session(): self.assertEquals(loss.op.name, 'l1_l2_regularizer') self.assertAlmostEqual(loss.eval(), num_elem, 5) @@ -104,7 +104,7 @@ class RegularizerTest(test.TestCase): self.assertEquals(loss, None) def testL1L2RegularizerWithScope(self): - with self.test_session(): + with self.cached_session(): shape = [5, 5, 5] num_elem = 5 * 5 * 5 tensor = constant_op.constant(1.0, shape=shape) @@ -142,7 +142,7 @@ class RegularizerTest(test.TestCase): array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]] tensor_weights_list = [constant_op.constant(x) for x in array_weights_list] expected = sum([2 * x for l in array_weights_list for x in l]) - with self.test_session(): + with self.cached_session(): result = 
regularizers.apply_regularization(dummy_regularizer, tensor_weights_list) self.assertAllClose(expected, result.eval()) @@ -151,7 +151,7 @@ class RegularizerTest(test.TestCase): regularizer = regularizers.l2_regularizer(0.0) array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]] tensor_weights_list = [constant_op.constant(x) for x in array_weights_list] - with self.test_session(): + with self.cached_session(): result = regularizers.apply_regularization(regularizer, tensor_weights_list) self.assertAllClose(0.0, result.eval()) @@ -161,7 +161,7 @@ class RegularizerTest(test.TestCase): tensor_weights_list = [ constant_op.constant(x) for x in [[1.5], [2, 3, 4.2], [10, 42, 666.6]] ] - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): regularizers.apply_regularization(non_scalar_regularizer, tensor_weights_list) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index c34b5a8017..2c7463acc0 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase): y1, y2 = block.forward(x1, x2) x1_inv, x2_inv = block.backward(y1, y2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) @@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase): x1, x2 = block.backward(y1, y2) y1_inv, y2_inv = block.forward(x1, x2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) @@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase): grads_rev = gradients_impl.gradients(loss_rev, wrt) grads = gradients_impl.gradients(loss, wrt) - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(variables.global_variables_initializer()) y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) @@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase): for out, scope_vars in outputs_and_vars: all_grads.append(gradients_impl.gradients(out, scope_vars)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) outputs = list(zip(*outputs_and_vars))[0] outs, all_grads_val = sess.run([outputs, all_grads]) @@ -389,7 +389,7 @@ class RecomputeTest(test.TestCase): layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list))) grads = gradients_impl.gradients(layer_list[-1], layer_list[0]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(grads) def testErrorOnClosedOverTensor(self): diff --git a/tensorflow/contrib/layers/python/layers/summaries_test.py b/tensorflow/contrib/layers/python/layers/summaries_test.py index a1ef06feec..2ec2af9d44 100644 --- a/tensorflow/contrib/layers/python/layers/summaries_test.py +++ b/tensorflow/contrib/layers/python/layers/summaries_test.py @@ -29,19 +29,19 @@ from tensorflow.python.platform import test class SummariesTest(test.TestCase): def test_summarize_scalar_tensor(self): - with self.test_session(): + with self.cached_session(): scalar_var = variables.Variable(1) summary_op = summaries_lib.summarize_tensor(scalar_var) self.assertEquals(summary_op.op.type, 'ScalarSummary') def test_summarize_multidim_tensor(self): - with self.test_session(): + with self.cached_session(): tensor_var = variables.Variable([1, 2, 3]) summary_op = summaries_lib.summarize_tensor(tensor_var) self.assertEquals(summary_op.op.type, 'HistogramSummary') def test_summarize_activation(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(1) op = array_ops.identity(var, name='SummaryTest') summary_op = summaries_lib.summarize_activation(op) @@ 
-52,7 +52,7 @@ class SummariesTest(test.TestCase): self.assertIn(u'SummaryTest/activation', names) def test_summarize_activation_relu(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(1) op = nn_ops.relu(var, name='SummaryTest') summary_op = summaries_lib.summarize_activation(op) @@ -64,7 +64,7 @@ class SummariesTest(test.TestCase): self.assertIn(u'SummaryTest/activation', names) def test_summarize_activation_relu6(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(1) op = nn_ops.relu6(var, name='SummaryTest') summary_op = summaries_lib.summarize_activation(op) @@ -77,7 +77,7 @@ class SummariesTest(test.TestCase): self.assertIn(u'SummaryTest/activation', names) def test_summarize_collection_regex(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable(1) array_ops.identity(var, name='Test1') ops.add_to_collection('foo', array_ops.identity(var, name='Test2')) diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py index a9bd89532a..34f63f5d86 100644 --- a/tensorflow/contrib/layers/python/layers/utils_test.py +++ b/tensorflow/contrib/layers/python/layers/utils_test.py @@ -42,7 +42,7 @@ class ConstantValueTest(test.TestCase): c = constant_op.constant(v) value = utils.constant_value(c) self.assertEqual(value, v) - with self.test_session(): + with self.cached_session(): self.assertEqual(c.eval(), v) def test_variable(self): @@ -60,7 +60,7 @@ class ConstantValueTest(test.TestCase): x = array_ops.identity(p) value = utils.constant_value(p) self.assertEqual(value, None) - with self.test_session(): + with self.cached_session(): self.assertEqual(x.eval(feed_dict={p: v}), v) @@ -80,7 +80,7 @@ class StaticCondTest(test.TestCase): expected = lambda v: b'fn1' if v else b'fn2' for v in [True, False, 1, 0]: o = utils.static_cond(v, fn1, fn2) - with self.test_session(): + with self.cached_session(): 
self.assertEqual(o.eval(), expected(v)) def test_variable(self): @@ -89,7 +89,7 @@ class StaticCondTest(test.TestCase): expected = lambda v: b'fn1' if v else b'fn2' for v in [True, False, 1, 0]: o = utils.static_cond(v, fn1, fn2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) self.assertEqual(o.eval(), expected(v)) @@ -99,7 +99,7 @@ class StaticCondTest(test.TestCase): expected = lambda v: -1 if v else -2 for v in [True, False, 1, 0]: o = utils.static_cond(v, fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(), expected(v)) @@ -119,7 +119,7 @@ class SmartCondStaticTest(test.TestCase): expected = lambda v: b'fn1' if v else b'fn2' for v in [True, False, 1, 0]: o = utils.smart_cond(constant_op.constant(v), fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(), expected(v)) def test_variable(self): @@ -128,7 +128,7 @@ class SmartCondStaticTest(test.TestCase): expected = lambda v: b'fn1' if v else b'fn2' for v in [True, False, 1, 0]: o = utils.smart_cond(constant_op.constant(v), fn1, fn2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) self.assertEqual(o.eval(), expected(v)) @@ -138,7 +138,7 @@ class SmartCondStaticTest(test.TestCase): expected = lambda v: -1 if v else -2 for v in [True, False, 1, 0]: o = utils.smart_cond(constant_op.constant(v), fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(), expected(v)) @@ -151,7 +151,7 @@ class SmartCondDynamicTest(test.TestCase): p = array_ops.placeholder(dtypes.bool, []) for v in [True, False, 1, 0]: o = utils.smart_cond(p, fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) def test_constant(self): @@ -161,7 +161,7 @@ class SmartCondDynamicTest(test.TestCase): p = 
array_ops.placeholder(dtypes.bool, []) for v in [True, False, 1, 0]: o = utils.smart_cond(p, fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) def test_variable(self): @@ -171,7 +171,7 @@ class SmartCondDynamicTest(test.TestCase): p = array_ops.placeholder(dtypes.bool, []) for v in [True, False, 1, 0]: o = utils.smart_cond(p, fn1, fn2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) @@ -182,7 +182,7 @@ class SmartCondDynamicTest(test.TestCase): p = array_ops.placeholder(dtypes.bool, []) for v in [True, False, 1, 0]: o = utils.smart_cond(p, fn1, fn2) - with self.test_session(): + with self.cached_session(): self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) -- cgit v1.2.3 From 132babebf5b1026cb33cad7c4eb7e03810c2acdf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:36:15 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212336258 --- .../data/kernel_tests/batch_dataset_op_test.py | 22 +++++----- .../data/kernel_tests/cache_dataset_op_test.py | 14 +++---- .../kernel_tests/concatenate_dataset_op_test.py | 4 +- .../kernel_tests/dataset_constructor_op_test.py | 16 ++++---- .../kernel_tests/dataset_from_generator_op_test.py | 28 ++++++------- .../python/data/kernel_tests/dataset_ops_test.py | 2 +- .../data/kernel_tests/filter_dataset_op_test.py | 14 +++---- .../data/kernel_tests/flat_map_dataset_op_test.py | 8 ++-- .../kernel_tests/list_files_dataset_op_test.py | 18 ++++----- .../data/kernel_tests/map_dataset_op_test.py | 47 +++++++++++----------- .../python/data/kernel_tests/optional_ops_test.py | 2 +- .../data/kernel_tests/prefetch_dataset_op_test.py | 4 +- .../data/kernel_tests/range_dataset_op_test.py | 16 ++++---- .../data/kernel_tests/reader_dataset_ops_test.py | 26 ++++++------ .../data/kernel_tests/sequence_dataset_op_test.py | 10 ++--- .../data/kernel_tests/shard_dataset_op_test.py | 14 +++---- .../data/kernel_tests/shuffle_dataset_op_test.py | 12 +++--- .../data/kernel_tests/zip_dataset_op_test.py | 4 +- 18 files changed, 131 insertions(+), 130 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 89de55dd4f..c48708a2b9 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -82,7 +82,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[dim0] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -111,7 +111,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator()) get_next = iterator.get_next() - with 
self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -131,7 +131,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -158,7 +158,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -188,7 +188,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( @@ -214,7 +214,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -262,7 +262,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=4, padded_shapes=[5]).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.DataLossError): sess.run(get_next) @@ -318,7 +318,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): 
batch_size=4, padded_shapes=[-1]).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -342,7 +342,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) sess.run( @@ -381,7 +381,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None]))) padded_dataset = dataset.padded_batch( 2, padded_shapes=([None], [None]), padding_values=('', 0)) - with self.test_session() as sess: + with self.cached_session() as sess: next_element = padded_dataset.make_one_shot_iterator().get_next() sess.run(next_element) diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index 4f7fd3566e..d5f5b2fe05 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -68,7 +68,7 @@ class FileCacheDatasetTest(test.TestCase): get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # First run without caching to collect the "ground truth". 
sess.run(init_fifo_op) elements = [] @@ -132,7 +132,7 @@ class FileCacheDatasetTest(test.TestCase): get_next1 = iterator1.get_next() get_next2 = iterator2.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) sess.run(get_next1) # this should succeed @@ -162,7 +162,7 @@ class FileCacheDatasetTest(test.TestCase): get_next1 = iterator1.get_next() get_next2 = iterator2.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) elements = [] @@ -217,7 +217,7 @@ class MemoryCacheDatasetTest(test.TestCase): uncached_iterator = uncached_dataset.make_initializable_iterator() uncached_next = uncached_iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(repeat_count.initializer) sess.run(cached_iterator.initializer) @@ -261,7 +261,7 @@ class MemoryCacheDatasetTest(test.TestCase): get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an empty upstream and a missing cache file (should # throw errors.OutOfRangeError immediately). 
sess.run(init_cache_op, feed_dict={count_placeholder: 0}) @@ -278,7 +278,7 @@ class MemoryCacheDatasetTest(test.TestCase): i1 = d1.make_initializable_iterator() i2 = d2.make_initializable_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(i1.initializer) self.assertEqual(1, sess.run(i1.get_next())) @@ -304,7 +304,7 @@ class MemoryCacheDatasetTest(test.TestCase): expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] - with self.test_session() as sess: + with self.cached_session() as sess: for i, expected in enumerate(expected_values): self.assertEqual(expected, sess.run(n), "Unexpected value at index %s" % i) diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py index 159218c99b..5dfb84f28e 100644 --- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py @@ -49,7 +49,7 @@ class ConcatenateDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(9): result = sess.run(get_next) @@ -83,7 +83,7 @@ class ConcatenateDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(9): result = sess.run(get_next) diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py index ea5b41e5d8..e43564a2eb 100644 --- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py @@ -50,7 +50,7 @@ class DatasetConstructorTest(test.TestCase): self.assertEqual([c.shape for c in components], [t.shape for t in get_next]) - with 
self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): @@ -84,7 +84,7 @@ class DatasetConstructorTest(test.TestCase): [tensor_shape.TensorShape(c.dense_shape) for c in components], [shape for shape in iterator.output_shapes]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): @@ -115,7 +115,7 @@ class DatasetConstructorTest(test.TestCase): if sparse_tensor.is_sparse(c) else c.shape for c in components ], [shape for shape in iterator.output_shapes]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) results = sess.run(get_next) for component, result_component in zip(components, results): @@ -142,7 +142,7 @@ class DatasetConstructorTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(4): results = sess.run(get_next) @@ -172,7 +172,7 @@ class DatasetConstructorTest(test.TestCase): [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], [shape for shape in iterator.output_shapes]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) expected = [ (sparse_tensor.SparseTensorValue( @@ -232,7 +232,7 @@ class DatasetConstructorTest(test.TestCase): if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components ], [shape for shape in iterator.output_shapes]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) expected = [ (sparse_tensor.SparseTensorValue( @@ -283,7 +283,7 @@ class DatasetConstructorTest(test.TestCase): self.assertEqual((), iterator.output_shapes["foo"]) self.assertEqual((1,), iterator.output_shapes["bar"]) - with 
self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(3): results = sess.run(get_next) @@ -300,7 +300,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] # Test with sparse tensor in the appropriate order. diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index fb55ae1400..cd0c1ddf1e 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -44,7 +44,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): # Run twice to test reinitialization. 
sess.run(init_op) for _ in range(num_repeats): @@ -61,7 +61,7 @@ class DatasetConstructorTest(test.TestCase): .make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(num_repeats): for elem in elem_sequence: self.assertAllEqual(elem, sess.run(get_next)) @@ -131,7 +131,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_inner_repeats * num_outer_repeats): for elem in input_list: @@ -190,7 +190,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for elem in [0, 1]: for _ in range(num_parallel_iterators): @@ -213,7 +213,7 @@ class DatasetConstructorTest(test.TestCase): self.assertEqual(dtype, get_next.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for expected in [[1], [2], [3]]: next_val = sess.run(get_next) @@ -234,7 +234,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for expected in [b"foo", b"bar", b"baz"]: next_val = sess.run(get_next) @@ -255,7 +255,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) self.assertAllEqual([4, 5, 6], sess.run(get_next)) @@ -278,7 +278,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) self.assertAllEqual([4, 5, 6], sess.run(get_next)) @@ -302,7 +302,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertEqual((1, 2), sess.run(get_next)) self.assertEqual((3, 4), sess.run(get_next)) @@ -327,7 +327,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(1, sess.run(get_next)) self.assertAllEqual([2, 3], sess.run(get_next)) @@ -347,7 +347,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(0, sess.run(get_next)) self.assertAllEqual(1, sess.run(get_next)) @@ -405,7 +405,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4] for x in expected: @@ -434,7 +434,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) expected = [(0, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"), @@ -468,7 +468,7 @@ class DatasetConstructorTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(37, sess.run(get_next)) self.assertAllEqual(37, sess.run(get_next)) diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py 
b/tensorflow/python/data/kernel_tests/dataset_ops_test.py index 2c4c11e132..239aa85175 100644 --- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -27,7 +27,7 @@ class DatasetOpsTest(test.TestCase): def testAsSerializedGraph(self): dataset = dataset_ops.Dataset.range(10) - with self.test_session() as sess: + with self.cached_session() as sess: graph = graph_pb2.GraphDef().FromString( sess.run(dataset._as_serialized_graph())) self.assertTrue(any([node.op != "RangeDataset" for node in graph.node])) diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 4f2216f0a3..19944d389f 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -59,7 +59,7 @@ class FilterDatasetTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Test that we can dynamically feed a different modulus value for each # iterator. 
def do_test(count_val, modulus_val): @@ -84,7 +84,7 @@ class FilterDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(get_next)) self.assertEqual(1, sess.run(get_next)) self.assertEqual(3, sess.run(get_next)) @@ -98,7 +98,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): if (i ** 2) % 2 == 0: @@ -123,7 +123,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(input_data[0], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -151,7 +151,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(5): actual = sess.run(get_next) @@ -169,7 +169,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual((i, True), sess.run(get_next)) @@ -181,7 +181,7 @@ class FilterDatasetTest(test.TestCase): lambda x: math_ops.equal(x % 2, 0)) iterators = [dataset.make_one_shot_iterator() for _ in range(10)] next_elements = [iterator.get_next() for iterator in iterators] - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([0 for _ in range(10)], sess.run(next_elements)) diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py index 
350234a839..1123cbff62 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py @@ -43,7 +43,7 @@ class FlatMapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in repeats: for _ in range(i): @@ -62,7 +62,7 @@ class FlatMapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for row in repeats: for i in row: @@ -113,7 +113,7 @@ class FlatMapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for _ in range(i ** 2): @@ -137,7 +137,7 @@ class FlatMapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for j in range(2): diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index 579096f880..c4b338a58f 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -44,7 +44,7 @@ class ListFilesDatasetOpTest(test.TestCase): def testEmptyDirectory(self): dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_one_shot_iterator() next_element = itr.get_next() with self.assertRaises(errors.OutOfRangeError): @@ -55,7 +55,7 @@ class ListFilesDatasetOpTest(test.TestCase): self._touchTempFiles(filenames) dataset = 
dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_one_shot_iterator() next_element = itr.get_next() @@ -75,7 +75,7 @@ class ListFilesDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.list_files( path.join(self.tmp_dir, '*'), shuffle=False) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_one_shot_iterator() next_element = itr.get_next() @@ -91,7 +91,7 @@ class ListFilesDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.list_files( path.join(self.tmp_dir, '*'), shuffle=True, seed=37) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_initializable_iterator() next_element = itr.get_next() @@ -121,7 +121,7 @@ class ListFilesDatasetOpTest(test.TestCase): filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) dataset = dataset_ops.Dataset.list_files(filename_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_initializable_iterator() with self.assertRaisesRegexp( errors.InvalidArgumentError, 'No files matched pattern: '): @@ -136,7 +136,7 @@ class ListFilesDatasetOpTest(test.TestCase): filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) dataset = dataset_ops.Dataset.list_files(filename_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_initializable_iterator() next_element = itr.get_next() sess.run( @@ -162,7 +162,7 @@ class ListFilesDatasetOpTest(test.TestCase): filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) dataset = dataset_ops.Dataset.list_files(filename_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_initializable_iterator() next_element = itr.get_next() sess.run( @@ -187,7 +187,7 @@ class ListFilesDatasetOpTest(test.TestCase): 
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) dataset = dataset_ops.Dataset.list_files(filename_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_initializable_iterator() next_element = itr.get_next() sess.run( @@ -221,7 +221,7 @@ class ListFilesDatasetOpTest(test.TestCase): # more meaningful. dataset = dataset_ops.Dataset.list_files( path.join(self.tmp_dir, '*'), shuffle=False).repeat(2) - with self.test_session() as sess: + with self.cached_session() as sess: itr = dataset.make_one_shot_iterator() next_element = itr.get_next() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index fde785be6e..7685d8dbdc 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Test single-threaded access to the iterator. sess.run(init_op, feed_dict={count: 14}) for _ in range(14): @@ -138,7 +138,8 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: + def do_test(num_parallel_calls_val, output_buffer_size_val): # Test single-threaded access to the iterator. 
sess.run(init_op, feed_dict={ @@ -203,7 +204,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) @@ -218,7 +219,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) @@ -233,7 +234,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) @@ -254,7 +255,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) @@ -285,7 +286,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(table.init) sess.run(init_op) sess.run(get_next) @@ -303,7 +304,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(enqueue_op) sess.run(close_op) sess.run(init_op) @@ -328,7 +329,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(enqueue_op) sess.run(close_op) sess.run(init_op) @@ -347,7 +348,7 @@ class 
MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(counter_var.initializer) sess.run(init_op) for i in range(10): @@ -367,7 +368,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.NotFoundError): sess.run(get_next) @@ -379,7 +380,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) random_values = [] with self.assertRaises(errors.OutOfRangeError): @@ -404,7 +405,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) @@ -436,7 +437,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next() # make sure both datasets contain the same data - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(count): tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple]) self.assertEqual(tuple_, namedtuple_) @@ -454,7 +455,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(row ** 2, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -485,7 +486,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = 
iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Simple test that prefetch yields the expected values in the # expected order. for buffer_size in [1, 10, 100, 1000]: @@ -523,7 +524,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual((i, 37.0), sess.run(get_next)) @@ -544,7 +545,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual((i, 37.0), sess.run(get_next)) @@ -570,7 +571,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): actual = sess.run(get_next) @@ -597,7 +598,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): actual = sess.run(get_next) @@ -621,7 +622,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(100): self.assertEqual(i, sess.run(get_next)) @@ -635,7 +636,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual((i, b"hello", 
10), sess.run(get_next)) @@ -702,7 +703,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): dataset = dataset.map(broken_function) iterator = dataset.make_initializable_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"): sess.run(iterator.initializer) diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py index a32527af8d..c344513e71 100644 --- a/tensorflow/python/data/kernel_tests/optional_ops_test.py +++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py @@ -158,7 +158,7 @@ class OptionalTest(test.TestCase): self.assertEqual(ds.output_classes, next_elem.output_classes) elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() - with self.test_session() as sess: + with self.cached_session() as sess: # Before initializing the iterator, evaluating the optional fails with # a FailedPreconditionError. 
with self.assertRaises(errors.FailedPreconditionError): diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py index 63a0830272..cc97bac609 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py @@ -36,7 +36,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) for m in range(10): self.assertEqual(m, sess.run(get_next)) @@ -51,7 +51,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py index ad87f31b01..51e90785e7 100644 --- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py @@ -49,7 +49,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={stop: 5}) for i in range(5): self.assertEqual(i, sess.run(get_next)) @@ -64,7 +64,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 2, stop: 5}) for i in range(2, 5): self.assertEqual(i, sess.run(get_next)) @@ -80,7 +80,7 @@ class RangeDatasetTest(test.TestCase): 
init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2}) for i in range(2, 10, 2): self.assertEqual(i, sess.run(get_next)) @@ -95,7 +95,7 @@ class RangeDatasetTest(test.TestCase): step).make_initializable_iterator() init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0}) @@ -108,7 +108,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1}) # This for loop is a no-op but will ensure that the implementation is # consistent with range if it ever changes. @@ -125,7 +125,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 10, stop: 2}) # This for loop is a no-op but will ensure that the implementation is # consistent with range if it ever changes. @@ -143,7 +143,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2}) # This for loop is a no-op but will ensure that the implementation is # consistent with range if it ever changes. 
@@ -161,7 +161,7 @@ class RangeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1}) for i in range(10, 2, -1): self.assertEqual(i, sess.run(get_next)) diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index 431362aa9a..aa3636364d 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -100,7 +100,7 @@ class TextLineDatasetTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. sess.run( init_op, feed_dict={filenames: [test_filenames[0]], @@ -163,7 +163,7 @@ class TextLineDatasetTest(test.TestCase): repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10) iterator = repeat_dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(2): for i in range(5): self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next())) @@ -240,7 +240,7 @@ class FixedLengthRecordReaderTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. 
sess.run( init_op, feed_dict={filenames: [test_filenames[0]], @@ -302,7 +302,7 @@ class FixedLengthRecordReaderTest(test.TestCase): buffer_size=10) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertEqual(self._record(j, i), sess.run(iterator.get_next())) @@ -319,7 +319,7 @@ class FixedLengthRecordReaderTest(test.TestCase): buffer_size=10) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input " @@ -661,7 +661,7 @@ class TFRecordDatasetTest(test.TestCase): return filenames def testReadOneEpoch(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from file 0. sess.run( self.init_op, @@ -698,7 +698,7 @@ class TFRecordDatasetTest(test.TestCase): sess.run(self.get_next) def testReadTenEpochs(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: self.test_filenames, @@ -711,7 +711,7 @@ class TFRecordDatasetTest(test.TestCase): sess.run(self.get_next) def testReadTenEpochsOfBatches(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_batch_op, feed_dict={ @@ -738,7 +738,7 @@ class TFRecordDatasetTest(test.TestCase): f.write(cdata) zlib_files.append(zfn) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: zlib_files, @@ -758,7 +758,7 @@ class TFRecordDatasetTest(test.TestCase): gzf.write(f.read()) gzip_files.append(gzfn) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={self.filenames: gzip_files, @@ -774,7 +774,7 @@ class 
TFRecordDatasetTest(test.TestCase): d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) iterator = d.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertAllEqual(self._record(j, i), sess.run(next_element)) @@ -786,7 +786,7 @@ class TFRecordDatasetTest(test.TestCase): d = readers.TFRecordDataset(files) iterator = d.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(self._num_files): for i in range(self._num_records): self.assertAllEqual(self._record(j, i), sess.run(next_element)) @@ -801,7 +801,7 @@ class TFRecordDatasetTest(test.TestCase): next_element = iterator.get_next() expected = [] actual = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): for j in range(self._num_files): for i in range(self._num_records): diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py index 1d27b036eb..37e2333560 100644 --- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py @@ -44,7 +44,7 @@ class SequenceDatasetTest(test.TestCase): self.assertEqual([c.shape for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Test a finite repetition. 
sess.run(init_op, feed_dict={count_placeholder: 3}) for _ in range(3): @@ -90,7 +90,7 @@ class SequenceDatasetTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Take fewer than input size sess.run(init_op, feed_dict={count_placeholder: 4}) for i in range(4): @@ -136,7 +136,7 @@ class SequenceDatasetTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Skip fewer than input size, we should skip # the first 4 elements and then read the rest. sess.run(init_op, feed_dict={count_placeholder: 4}) @@ -183,7 +183,7 @@ class SequenceDatasetTest(test.TestCase): self.assertEqual([c.shape for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14}) for _ in range(7 * 14): results = sess.run(get_next) @@ -199,7 +199,7 @@ class SequenceDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py index cefe872d0f..137f6341ce 100644 --- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py @@ -28,7 +28,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(10).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(2, sess.run(iterator.get_next())) self.assertEqual(7, sess.run(iterator.get_next())) 
with self.assertRaises(errors.OutOfRangeError): @@ -40,7 +40,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual((2, 8), sess.run(iterator.get_next())) self.assertEqual((7, 3), sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -50,7 +50,7 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(10).shard(5, 0) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(iterator.get_next())) self.assertEqual(5, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -76,14 +76,14 @@ class ShardDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.range(1).shard(5, 2) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) def testLargerWorkerPool(self): dataset = dataset_ops.Dataset.range(10).shard(7, 5) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(5, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) @@ -91,7 +91,7 @@ class ShardDatasetOpTest(test.TestCase): def testIndexEqualsNumShards(self): dataset = dataset_ops.Dataset.range(10).shard(5, 4) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(4, sess.run(iterator.get_next())) self.assertEqual(9, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): @@ -100,7 +100,7 @@ class ShardDatasetOpTest(test.TestCase): def testIndexEqualsNumShards2(self): dataset = 
dataset_ops.Dataset.range(10).shard(4, 3) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(iterator.get_next())) self.assertEqual(7, sess.run(iterator.get_next())) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index 5fcc48831f..f294840706 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -60,7 +60,7 @@ class ShuffleDatasetTest(test.TestCase): get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # First run without shuffling to collect the "ground truth". sess.run(init_fifo_op) unshuffled_elements = [] @@ -140,7 +140,7 @@ class ShuffleDatasetTest(test.TestCase): get_next = iterator.get_next() elems = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): elems.append(sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -152,7 +152,7 @@ class ShuffleDatasetTest(test.TestCase): .make_initializable_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer, feed_dict={seed_placeholder: 0}) for elem in elems: self.assertEqual(elem, sess.run(get_next)) @@ -166,7 +166,7 @@ class ShuffleDatasetTest(test.TestCase): get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: counts = collections.defaultdict(lambda: 0) for _ in range(10): for _ in range(5): @@ -183,7 +183,7 @@ class ShuffleDatasetTest(test.TestCase): .make_one_shot_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: initial_permutation = sess.run(next_element) 
self.assertAllEqual(initial_permutation, sess.run(next_element)) self.assertAllEqual(initial_permutation, sess.run(next_element)) @@ -198,7 +198,7 @@ class ShuffleDatasetTest(test.TestCase): .make_one_shot_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: initial_permutation = list(sess.run(next_element)) for _ in range(2): next_permutation = list(sess.run(next_element)) diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py index 55933118b9..3106effbd3 100644 --- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py @@ -45,7 +45,7 @@ class ZipDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: equal_length_components = [ np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(np.array([[12], [13], [14], [15]]), 22), @@ -93,7 +93,7 @@ class ZipDatasetTest(test.TestCase): self.assertEqual([22], get_next[1][0].shape) self.assertEqual([], get_next[1][1].shape) - with self.test_session() as sess: + with self.cached_session() as sess: equal_length_components = [ np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(np.array([[12], [13], [14], [15]]), 22), -- cgit v1.2.3 From 890e16594a005fe703a5556530b0dc3e6527fa47 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:36:26 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212336321 --- .../batch/categorical_split_handler_test.py | 10 +- .../learner/batch/ordinal_split_handler_test.py | 32 +- .../cloud/python/ops/bigquery_reader_ops_test.py | 4 +- .../cloud/python/ops/gcs_config_ops_test.py | 4 +- .../contrib/crf/python/kernel_tests/crf_test.py | 20 +- .../contrib/integrate/python/ops/odes_test.py | 34 +- .../contrib/layers/python/ops/sparse_ops_test.py | 46 +-- .../python/kernel_tests/decode_libsvm_op_test.py | 4 +- tensorflow/contrib/lite/python/convert_test.py | 12 +- tensorflow/contrib/lookup/lookup_ops_test.py | 206 +++++----- .../contrib/losses/python/losses/loss_ops_test.py | 214 +++++----- .../metrics/python/ops/metric_ops_large_test.py | 2 +- .../contrib/metrics/python/ops/metric_ops_test.py | 456 ++++++++++----------- .../model_pruning/python/layers/rnn_cells_test.py | 4 +- .../kernel_tests/hyperplane_lsh_probes_test.py | 2 +- .../kernel_tests/periodic_resample_op_test.py | 14 +- .../python/kernel_tests/recurrent_test.py | 4 +- .../python/saved_model/keras_saved_model_test.py | 10 +- tensorflow/python/client/session_test.py | 4 +- .../ops/parallel_for/control_flow_ops_test.py | 2 +- .../python/ops/parallel_for/gradients_test.py | 6 +- tensorflow/python/util/nest_test.py | 2 +- tensorflow/python/util/tf_should_use_test.py | 5 +- .../compatibility/testdata/test_file_v0_11.py | 16 +- .../compatibility/testdata/test_file_v1_10.py | 2 +- 25 files changed, 558 insertions(+), 557 deletions(-) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index d9f03c3840..94ea7bc2eb 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class 
EqualitySplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gains[0], 0.00001) def testGenerateFeatureSplitCandidatesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testGenerateFeatureSplitCandidatesMulticlass(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2]) @@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] @@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 5532bd026a..74b0ea6989 100644 --- 
a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class DenseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -183,7 +183,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testObliviousFeatureSplitGeneration(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 1 | 3 | @@ -320,7 +320,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(2, oblivious_split_info.children_parent_id[1]) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -458,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. 
gradients = array_ops.constant( @@ -546,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( @@ -633,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -708,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -842,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -951,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1074,7 +1074,7 @@ class 
SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1207,7 +1207,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1302,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1397,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1475,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. 
values = array_ops.constant([], dtype=dtypes.float32) @@ -1545,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmptyBuckets(self): """Test that reproduces the case when quantile buckets were empty.""" - with self.test_session() as sess: + with self.cached_session() as sess: sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. @@ -1638,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testDegenerativeCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same # situation is when all examples have the same values. This case was # causing before a failure. diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py index 493b3c6f1b..11e177cd0c 100644 --- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py @@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase): def _ReadAndCheckRowsUsingFeatures(self, num_rows): self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: feature_configs = { "int64_col": parsing_ops.FixedLenFeature( @@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase): num_rows = 10 self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: reader = cloud.BigQueryReader( project_id=_PROJECT, dataset_id=_DATASET, diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py index 9b6c056d6c..4f2ecbcb17 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ 
-26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase): def testSetBlockCache(self): cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) - with self.test_session() as sess: + with self.cached_session() as sess: gcs_config_ops.configure_gcs(sess, block_cache=cfg) def testConfigureGcsHook(self): @@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase): 'type': 'authorized_user'} hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) hook.begin() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None hook.after_create_session(sess, None) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 8cfe142059..556d731840 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -61,7 +61,7 @@ class CrfTest(test.TestCase): for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, inputs_list, tag_indices_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_indices=array_ops.expand_dims(tag_indices, 0), @@ -96,7 +96,7 @@ class CrfTest(test.TestCase): ] for sequence_lengths, inputs, tag_bitmap in zip( sequence_lengths_list, inputs_list, tag_bitmap_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_multitag_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), @@ -124,7 +124,7 @@ class CrfTest(test.TestCase): for dtype in (np.int32, np.int64): tag_indices = np.array([1, 2, 1, 0], dtype=dtype) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: unary_score = crf.crf_unary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), 
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -140,7 +140,7 @@ class CrfTest(test.TestCase): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: binary_score = crf.crf_binary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -176,7 +176,7 @@ class CrfTest(test.TestCase): tag_indices_list): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] # Compare the dynamic program with brute force computation. @@ -206,7 +206,7 @@ class CrfTest(test.TestCase): """ Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], @@ -226,7 +226,7 @@ class CrfTest(test.TestCase): sequence_lengths = np.array(3, dtype=np.int32) num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_log_likelihoods = [] # Make sure all probabilities sum to 1. @@ -254,7 +254,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -310,7 +310,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -351,7 +351,7 @@ class CrfTest(test.TestCase): """ Test that crf_decode works when sequence_length contains one or more zeros. 
""" - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index c7b4e2faa8..be915ef96f 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -49,7 +49,7 @@ class OdeIntTest(test.TestCase): y_solved = odes.odeint(func, y0, t) self.assertIn('odeint', y_solved.name) self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11])) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved = sess.run(y_solved) y_true = np.exp(t) self.assertAllClose(y_true, y_solved) @@ -62,7 +62,7 @@ class OdeIntTest(test.TestCase): func = lambda y, t: k * y t = np.linspace(0.0, 1.0, 11) y_solved = odes.odeint(func, 1.0 + 0.0j, t) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved = sess.run(y_solved) y_true = np.exp(k * t) self.assertAllClose(y_true, y_solved) @@ -74,7 +74,7 @@ class OdeIntTest(test.TestCase): func = lambda t, y: (y - t)**2 + 1.0 t = np.linspace(0.0, 1.0, 11) y_solved = odes.odeint(func, np.float64(0.5), t) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved = sess.run(y_solved) y_true = 1.0 / (2.0 - t) + t self.assertAllClose(y_true, y_solved) @@ -96,7 +96,7 @@ class OdeIntTest(test.TestCase): t = np.linspace(0.0, 1.0, 11) y_solved = odes.odeint(func, y0, t) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved = sess.run(y_solved) y_true = np.zeros((len(t), 2, 1)) @@ -113,7 +113,7 @@ class OdeIntTest(test.TestCase): y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t) self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape(expected_shape)) - with self.test_session() as sess: + with 
self.cached_session() as sess: y_solved = sess.run(y_solved) self.assertEquals(y_solved.shape, expected_shape) @@ -126,7 +126,7 @@ class OdeIntTest(test.TestCase): for t_dtype in [dtypes.float32, dtypes.float64]: y0 = math_ops.cast(1.0, y0_dtype) y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype)) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved = sess.run(y_solved) expected = np.asarray(np.exp(t)) self.assertAllClose(y_solved, expected, rtol=1e-5) @@ -148,13 +148,13 @@ class OdeIntTest(test.TestCase): self.y0, [0, 1], method='dopri5', options={'max_num_steps': 0}) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'max_num_steps'): sess.run(y) y = odes.odeint(self.func, self.y0, [1, 0]) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'monotonic increasing'): sess.run(y) @@ -164,7 +164,7 @@ class OdeIntTest(test.TestCase): times0 = np.linspace(0, 10, num=11, dtype=float) times1 = np.linspace(0, 10, num=101, dtype=float) - with self.test_session() as sess: + with self.cached_session() as sess: y_solved_0, info_0 = sess.run( odes.odeint(self.func, self.y0, times0, full_output=True)) y_solved_1, info_1 = sess.run( @@ -179,7 +179,7 @@ class OdeIntTest(test.TestCase): t = [0, 20] kwargs = dict( full_output=True, method='dopri5', options=dict(max_num_steps=2000)) - with self.test_session() as sess: + with self.cached_session() as sess: _, info_0 = sess.run( odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs)) _, info_1 = sess.run( @@ -196,7 +196,7 @@ class StepSizeTest(test.TestCase): new_step = odes._optimal_step_size( last_step=constant_op.constant(1.0), error_ratio=constant_op.constant(1.0)) - with self.test_session() as sess: + with self.cached_session() as sess: new_step = sess.run(new_step) self.assertAllClose(new_step, 0.9) @@ -204,7 
+204,7 @@ class StepSizeTest(test.TestCase): new_step = odes._optimal_step_size( last_step=constant_op.constant(1.0), error_ratio=constant_op.constant(0.0)) - with self.test_session() as sess: + with self.cached_session() as sess: new_step = sess.run(new_step) self.assertAllClose(new_step, 10.0) @@ -212,7 +212,7 @@ class StepSizeTest(test.TestCase): new_step = odes._optimal_step_size( last_step=constant_op.constant(1.0), error_ratio=constant_op.constant(1e6)) - with self.test_session() as sess: + with self.cached_session() as sess: new_step = sess.run(new_step) self.assertAllClose(new_step, 0.2) @@ -229,13 +229,13 @@ class InterpolationTest(test.TestCase): y_fit = array_ops.stack( [odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times]) y_expected = f(times) - with self.test_session() as sess: + with self.cached_session() as sess: y_actual = sess.run(y_fit) self.assertAllClose(y_expected, y_actual) # attempt interpolation outside bounds y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): sess.run(y_invalid) @@ -251,7 +251,7 @@ class OdeIntFixedTest(test.TestCase): y0 = [0., 1.] y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) - with self.test_session() as sess: + with self.cached_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( @@ -265,7 +265,7 @@ class OdeIntFixedTest(test.TestCase): y0 = [1.] 
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) - with self.test_session() as sess: + with self.cached_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py index d50750001e..b6c2cab64a 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py +++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py @@ -42,7 +42,7 @@ def _assert_sparse_tensor_value(test_case, expected, actual): class DenseToSparseTensorTest(test.TestCase): def test_dense_to_sparse_tensor_1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([1, 0, 2, 0]) result = sess.run(st) self.assertEqual(result.indices.dtype, np.int64) @@ -53,7 +53,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_1d_float(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([1.5, 0.0, 2.3, 0.0]) result = sess.run(st) self.assertEqual(result.indices.dtype, np.int64) @@ -64,7 +64,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_1d_bool(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([True, False, True, False]) result = sess.run(st) self.assertEqual(result.indices.dtype, np.int64) @@ -75,7 +75,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_1d_str(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b'']) result = sess.run(st) self.assertEqual(result.indices.dtype, np.int64) @@ -86,7 +86,7 @@ class 
DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_1d_str_special_ignore(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor( [b'qwe', b'', b'ewq', b''], ignore_value=b'qwe') result = sess.run(st) @@ -98,7 +98,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([[1, 2, 0, 0], [3, 4, 5, 0]]) result = sess.run(st) self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], @@ -107,7 +107,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([2, 4], result.dense_shape) def test_dense_to_sparse_tensor_3d(self): - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor([[[1, 2, 0, 0], [3, 4, 5, 0]], [[7, 8, 0, 0], [9, 0, 0, 0]]]) result = sess.run(st) @@ -117,7 +117,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([2, 2, 4], result.dense_shape) def test_dense_to_sparse_tensor_unknown_1d_shape(self): - with self.test_session() as sess: + with self.cached_session() as sess: tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32) st = sparse_ops.dense_to_sparse_tensor(tensor) result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]}) @@ -126,7 +126,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([4], result.dense_shape) def test_dense_to_sparse_tensor_unknown_3d_shape(self): - with self.test_session() as sess: + with self.cached_session() as sess: tensor = array_ops.placeholder( shape=[None, None, None], dtype=dtypes.int32) st = sparse_ops.dense_to_sparse_tensor(tensor) @@ -142,7 +142,7 @@ class DenseToSparseTensorTest(test.TestCase): def test_dense_to_sparse_unknown_rank(self): ph = 
array_ops.placeholder(dtype=dtypes.int32) - with self.test_session() as sess: + with self.cached_session() as sess: st = sparse_ops.dense_to_sparse_tensor(ph) result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]}) self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], @@ -155,7 +155,7 @@ class SparseRowEnvelopeTest(test.TestCase): def test_sparse_row_envelope(self): expected_sparse_row_envelope = [1, 0, 3] - with self.test_session() as sess: + with self.cached_session() as sess: sparse_input = sparse_tensor.SparseTensor( indices=[[0, 0], [2, 0], [2, 1], [2, 2]], values=[0, 1, 2, 3], @@ -167,7 +167,7 @@ class SparseRowEnvelopeTest(test.TestCase): def test_sparse_row_envelope_unsorted_indices(self): expected_sparse_row_envelope = [1, 0, 3] - with self.test_session() as sess: + with self.cached_session() as sess: sparse_input = sparse_tensor.SparseTensor( indices=[[2, 0], [2, 2], [2, 1], [0, 0]], values=[0, 1, 2, 3], @@ -179,7 +179,7 @@ class SparseRowEnvelopeTest(test.TestCase): def test_sparse_row_envelope_empty_in_the_end(self): expected_sparse_row_envelope = [1, 0, 3, 0, 0] - with self.test_session() as sess: + with self.cached_session() as sess: sparse_input = sparse_tensor.SparseTensor( indices=[[0, 0], [2, 0], [2, 1], [2, 2]], values=[0, 1, 2, 3], @@ -191,7 +191,7 @@ class SparseRowEnvelopeTest(test.TestCase): def test_sparse_row_envelope_empty_3d(self): expected_sparse_row_envelope = [1, 0, 3, 0, 0] - with self.test_session() as sess: + with self.cached_session() as sess: sparse_input = sparse_tensor.SparseTensor( indices=[[0, 0, 0], [0, 2, 0], [0, 2, 1], [0, 2, 2]], values=[0, 1, 2, 3], @@ -207,7 +207,7 @@ class IndicatorToSparseIdsTest(test.TestCase): def test_indicators_to_sparse_ids_1d(self): indicators = (0, 0, 1, 0) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0,),), values=(2,), @@ 
-220,7 +220,7 @@ class IndicatorToSparseIdsTest(test.TestCase): (1, 0, 0, 1), ) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1)), values=(2, 0, 3), @@ -235,7 +235,7 @@ class IndicatorToSparseIdsTest(test.TestCase): ((1, 0, 0, 1, 1), (0, 0, 1, 0, 0)), ) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=( (0, 0, 0), @@ -255,7 +255,7 @@ class IndicatorToSparseIdsTest(test.TestCase): ) sparse_ids = sparse_ops.indicators_to_sparse_ids( indicators, dtype=dtypes.int16) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1)), values=np.array((2, 0, 3), dtype=np.int16), @@ -269,7 +269,7 @@ class IndicatorToSparseIdsTest(test.TestCase): ) sparse_ids = sparse_ops.indicators_to_sparse_ids( indicators, ignore_value=-1) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), values=(2, 0, 3, 2), @@ -282,7 +282,7 @@ class IndicatorToSparseIdsTest(test.TestCase): (('B', '', '', 'C'), ('', '', 'D', '')), ) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), values=(2, 0, 3, 2), @@ -296,7 +296,7 @@ class IndicatorToSparseIdsTest(test.TestCase): ) sparse_ids = sparse_ops.indicators_to_sparse_ids( indicators, ignore_value='x') - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0, 0), 
(1, 0, 0), (1, 0, 1), (1, 1, 0)), values=(2, 0, 3, 2), @@ -311,7 +311,7 @@ class IndicatorToSparseIdsTest(test.TestCase): indicators = array_ops.placeholder( dtype=dtypes.int32, shape=(None, None, None)) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), values=(2, 0, 3, 2), @@ -325,7 +325,7 @@ class IndicatorToSparseIdsTest(test.TestCase): ) indicators = array_ops.placeholder(dtype=dtypes.int32) sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) - with self.test_session(): + with self.cached_session(): _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), values=(2, 0, 3, 2), diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py index 423dcce8de..8390ddda90 100644 --- a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py +++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class DecodeLibsvmOpTest(test.TestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: content = [ "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503", "2 3:2.5 2:nan 1:0.105" @@ -48,7 +48,7 @@ class DecodeLibsvmOpTest(test.TestCase): [0, 0.105, np.nan, 2.5, 0, 0]]) def testNDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py index 59f537b82a..40a8b5fafb 100644 --- 
a/tensorflow/contrib/lite/python/convert_test.py +++ b/tensorflow/contrib/lite/python/convert_test.py @@ -188,7 +188,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): return output output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput") - with self.test_session() as sess: + with self.cached_session() as sess: # check if identities have been put into the graph (2 input, 1 output, # and 1 final output). self.assertEqual(self._countIdentities(sess.graph_def.node), 4) @@ -215,7 +215,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b), name="ModelOutput") - with self.test_session() as sess: + with self.cached_session() as sess: # make sure one identity for each input (3) and output (2) => 3 + 2 = 5 # +1 for the final output self.assertEqual(self._countIdentities(sess.graph_def.node), 6) @@ -242,7 +242,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): output = array_ops.identity( math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput") - with self.test_session() as sess: + with self.cached_session() as sess: # make sure one identity for each input (2) and output (2) => 2 + 2 # +1 for the final output self.assertEqual(self._countIdentities(sess.graph_def.node), 5) @@ -279,7 +279,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): aggregate=op_hint.OpHint.AGGREGATE_STACK) res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b)) custom.add_outputs([res]) - with self.test_session(): + with self.cached_session(): self.assertEqual(self._get_input_index(a), 0) self.assertEqual(self._get_sort_index(a), 0) self.assertEqual(self._get_input_index(b), 1) @@ -294,7 +294,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): b = custom.add_input(b) # should auto assign 0 a = custom.add_input(a, index_override=1) c = custom.add_input(c) # should auto assign 2 - with self.test_session(): + with self.cached_session(): 
self.assertEqual(self._get_input_index(a), 1) self.assertEqual(self._get_input_index(b), 0) self.assertEqual(self._get_input_index(c), 2) @@ -320,7 +320,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase): curr = array_ops.stack([c0, c1]) output = array_ops.identity(curr, name="FINAL_OUTPUT") - with self.test_session() as sess: + with self.cached_session() as sess: stubbed_graphdef = op_hint.convert_op_hints_to_stubs( graph_def=sess.graph_def) self.assertCountEqual( diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 0a54bb1f5e..89b538d1ba 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -44,7 +44,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): def testHashTable(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -68,7 +68,7 @@ class HashTableOpTest(test.TestCase): self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval()) def testHashTableFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -86,7 +86,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [-1, -1]], result) def testHashTableInitWithPythonArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = ["brain", "salad", "surgery"] values = [0, 1, 2] @@ -105,7 +105,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableInitWithNumPyArrays(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = np.array(["brain", "salad", "surgery"], dtype=np.str) values = np.array([0, 1, 
2], dtype=np.int64) @@ -122,7 +122,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMultipleHashTables(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testHashTableWithTensorDefault(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -165,7 +165,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testHashTableWithSparseTensorInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -188,7 +188,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual(sp_shape, out_shape) def testSignatureMismatch(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -210,7 +210,7 @@ class HashTableOpTest(test.TestCase): lookup.KeyValueTensorInitializer(keys, values), "UNK") def testDTypes(self): - with self.test_session(): + with self.cached_session(): default_val = -1 with self.assertRaises(TypeError): lookup.HashTable( @@ -218,7 +218,7 @@ class HashTableOpTest(test.TestCase): dtypes.int64), default_val) def testNotInitialized(self): - with self.test_session(): + with self.cached_session(): default_val = -1 table = lookup.HashTable( lookup.KeyValueTensorInitializer( @@ -232,7 +232,7 @@ class HashTableOpTest(test.TestCase): 
output.eval() def testInitializeTwice(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -244,7 +244,7 @@ class HashTableOpTest(test.TestCase): table.init.run() def testInitializationWithInvalidDimensions(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) @@ -283,7 +283,7 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) def testHashTableInt32String(self): - with self.test_session(): + with self.cached_session(): default_val = "n/a" keys = constant_op.constant([0, 1, 2], dtypes.int32) values = constant_op.constant(["brain", "salad", "surgery"]) @@ -301,7 +301,7 @@ class HashTableOpTest(test.TestCase): class MutableHashTableOpTest(test.TestCase): def testMutableHashTable(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -470,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([b"-", b"a", b"b"], output.eval()) def testMutableHashTableOfTensors(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) @@ -500,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) def testMutableHashTableExportInsert(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 
1], [2, 3], [4, 5]], dtypes.int64) @@ -531,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(expected_output, output2.eval()) def testMutableHashTableOfTensorsInvalidShape(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, @@ -563,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(3, table.size().eval()) def testMutableHashTableInvalidDefaultValue(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([[-1, -1]], dtypes.int64) table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) @@ -571,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(0, table.size().eval()) def testMutableHashTableDuplicateInsert(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) values = constant_op.constant([0, 1, 2, 3], dtypes.int64) @@ -589,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([3, 1, -1], result) def testMutableHashTableFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -608,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [-1, -1]], result) def testMutableHashTableInsertHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) @@ -625,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, 3, -1], result) def 
testMutableHashTableOfTensorsFindHighRank(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant([-1, -1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], @@ -646,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase): [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) def testMultipleMutableHashTables(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -676,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testMutableHashTableWithTensorDefault(self): - with self.test_session(): + with self.cached_session(): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -693,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testSignatureMismatch(self): - with self.test_session(): + with self.cached_session(): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) @@ -734,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase): lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") def testMutableHashTableStringFloat(self): - with self.test_session(): + with self.cached_session(): default_val = -1.5 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) @@ -752,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllClose([0, 1.1, default_val], result) def testMutableHashTableIntFloat(self): - with self.test_session(): + with self.cached_session(): default_val = -1.0 keys = 
constant_op.constant([3, 7, 0], dtypes.int64) values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) @@ -770,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllClose([-1.2, 9.9, default_val], result) def testMutableHashTableInt64String(self): - with self.test_session(): + with self.cached_session(): default_val = "n/a" keys = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant(["brain", "salad", "surgery"]) @@ -791,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase): class MutableDenseHashTableOpTest(test.TestCase): def testBasic(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -809,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testBasicBool(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([True, True, True], dtypes.bool) table = lookup.MutableDenseHashTable( @@ -827,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([True, True, False], result) def testLookupUnknownShape(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -843,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMapStringToFloat(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant(["a", "b", "c"], dtypes.string) values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32) default_value = constant_op.constant(-1.5, dtypes.float32) @@ -866,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase): def testMapInt64ToFloat(self): 
for float_dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0.0, 1.1, 2.2], float_dtype) default_value = constant_op.constant(-1.5, float_dtype) @@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllClose([0, 1.1, -1.5], result) def testVectorValues(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], dtypes.int64) @@ -918,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase): result) def testVectorKeys(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) values = constant_op.constant([10, 11, 12], dtypes.int64) empty_key = constant_op.constant([0, 3], dtypes.int64) @@ -949,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([10, 11, -1], result) def testResize(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -977,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval()) def testExport(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([1, 2, 3], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -1238,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) def testReprobe(self): - with self.test_session(): + with self.cached_session(): # Insert 6 keys into a table with 8 buckets. 
# The values are chosen to make sure collisions occur when using GCC STL keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) @@ -1263,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) def testCustomEmptyKey(self): - with self.test_session(): + with self.cached_session(): keys = constant_op.constant([11, 0, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.MutableDenseHashTable( @@ -1281,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testErrors(self): - with self.test_session(): + with self.cached_session(): table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) @@ -1328,7 +1328,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1339,7 +1339,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_file = constant_op.constant(vocabulary_file) table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) @@ -1353,7 +1353,7 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file_placeholder_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.test_session(): + with self.cached_session(): vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) table = lookup.index_table_from_file( vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) @@ -1370,7 
+1370,7 @@ class IndexTableFromFile(test.TestCase): def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, key_dtype=dtypes.int32) @@ -1384,7 +1384,7 @@ class IndexTableFromFile(test.TestCase): def test_int64_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1, key_dtype=dtypes.int64) @@ -1398,7 +1398,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_default_value(self): default_value = -42 vocabulary_file = self._createVocabFile("f2i_vocab4.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1409,7 +1409,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab5.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1000) ids = table.lookup( @@ -1439,7 +1439,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1451,7 +1451,7 @@ class IndexTableFromFile(test.TestCase): def 
test_index_table_from_file_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, @@ -1466,7 +1466,7 @@ class IndexTableFromFile(test.TestCase): vocabulary_file=vocabulary_file, vocab_size=0) - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1478,7 +1478,7 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_invalid_hashers(self): vocabulary_file = self._createVocabFile("invalid_hasher.txt") - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): lookup.index_table_from_file( vocabulary_file=vocabulary_file, @@ -1499,21 +1499,21 @@ class IndexTableFromFile(test.TestCase): class KeyValueTensorInitializerTest(test.TestCase): def test_string(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) table = lookup.HashTable(init, default_value=-1) table.init.run() def test_int64(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) table = lookup.HashTable(init, default_value=-1) table.init.run() def test_int32(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64) table = lookup.HashTable(init, default_value=-1) @@ -1542,7 
+1542,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) ids = table.lookup( @@ -1553,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) ids = table.lookup( @@ -1565,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase): def test_index_table_from_tensor_with_default_value(self): default_value = -42 - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1575,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_mapping(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, "mapping must be specified"): lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1) def test_index_table_from_tensor_empty_mapping(self): - with self.test_session(): + with self.cached_session(): table = lookup.index_table_from_tensor( mapping=np.array([], dtype=np.str_), num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) @@ -1590,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase): lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): 
lookup.index_table_from_tensor( mapping=["brain", "salad", "surgery"], @@ -1609,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase): class StringToIndexTest(test.TestCase): def test_string_to_index(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) indices = lookup.string_to_index(feats, mapping=mapping_strings) @@ -1620,7 +1620,7 @@ class StringToIndexTest(test.TestCase): self.assertAllEqual((1, 2, -1), indices.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) feats = constant_op.constant(["hello", "hola"]) _ = lookup.string_to_index(feats, mapping=mapping_strings) @@ -1630,7 +1630,7 @@ class StringToIndexTest(test.TestCase): def test_string_to_index_with_default_value(self): default_value = -42 - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) indices = lookup.string_to_index( @@ -1651,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table(self): vocabulary_file = self._createVocabFile("i2f_vocab1.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) @@ -1663,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) features = 
table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1675,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_small(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2, @@ -1688,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1700,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.test_session(): + with self.cached_session(): table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1713,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase): class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_table_from_tensor(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) @@ -1727,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) @@ -1738,7 +1738,7 @@ class 
IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings, default_value=default_value) @@ -1754,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): class IndexToStringTest(test.TestCase): def test_index_to_string(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) @@ -1766,7 +1766,7 @@ class IndexToStringTest(test.TestCase): feats.eval()) def test_duplicate_entries(self): - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) @@ -1778,7 +1778,7 @@ class IndexToStringTest(test.TestCase): def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.test_session(): + with self.cached_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([1, 2, 4], dtypes.int64) feats = lookup.index_to_string( @@ -1818,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase): vocabulary_file = self._createVocabFile( "one_column_int64.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup.HashTable( lookup.TextFileInitializer(vocabulary_file, dtypes.int64, @@ -1837,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeIndexTable(self): vocabulary_file = self._createVocabFile("one_column_2.txt") - with self.test_session(): + with 
self.cached_session(): default_value = "UNK" key_index = lookup.TextFileIndex.LINE_NUMBER value_index = lookup.TextFileIndex.WHOLE_LINE @@ -1858,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 value_index = 2 @@ -1880,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase): with open(vocabulary_file, "w") as f: f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 2 value_index = 1 @@ -1894,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidDataType(self): vocabulary_file = self._createVocabFile("one_column_3.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" key_index = lookup.TextFileIndex.WHOLE_LINE value_index = lookup.TextFileIndex.LINE_NUMBER @@ -1907,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidIndex(self): vocabulary_file = self._createVocabFile("one_column_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 key_index = 1 # second column of the line value_index = lookup.TextFileIndex.LINE_NUMBER @@ -1922,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInitializeSameTableWithMultipleNodes(self): vocabulary_file = self._createVocabFile("one_column_5.txt") - with self.test_session() as sess: + with self.cached_session() as sess: shared_name = "shared-one-columm" default_value = -1 table1 = lookup.HashTable( @@ -1961,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], out3) def testInitializeTableWithNoFilename(self): - with self.test_session(): + with self.cached_session(): default_value = -1 with self.assertRaises(ValueError): 
lookup.HashTable( @@ -1971,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value) def testInitializeWithVocabSize(self): - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 vocabulary_file1 = self._createVocabFile("one_column6.txt") @@ -2022,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testFeedVocabularyName(self): vocabulary_file = self._createVocabFile("feed_vocabulary.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 table = lookup.HashTable( lookup.TextFileInitializer("old_file.txt", dtypes.string, @@ -2049,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInvalidFilenames(self): vocabulary_file = self._createVocabFile("filename_shape.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 # Invalid data type @@ -2072,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testIdToStringTable(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): default_value = "UNK" vocab_size = 3 table = lookup.HashTable( @@ -2090,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testStringToIdTable(self): vocab_file = self._createVocabFile("feat_to_id_2.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup.HashTable( @@ -2108,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase): def testInt64ToIdTable(self): vocab_file = self._createVocabFile( "feat_to_id_3.txt", values=("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 table = lookup.HashTable( @@ -2133,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testStringIdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_1.txt") - with self.test_session(): + with self.cached_session(): 
default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2154,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2176,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt64IdTableWithHashBuckets(self): vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2196,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(vocab_size + oov_buckets, table.size().eval()) def testStringIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -2217,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testInt32IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): oov_buckets = 5 # Set a table that only uses hash buckets, for each input value returns @@ -2239,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertEquals(oov_buckets, table.size().eval()) def testFloat64IdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.float64) def testBoolIdTableWithOnlyHashBucket(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"): lookup.IdTableWithHashBuckets( None, num_oov_buckets=5, key_dtype=dtypes.bool) def testIdTableWithHashBucketsWithMultipleInitializers(self): vocab_file = 
self._createVocabFile("feat_to_id_4.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value = -1 vocab_size = 3 oov_buckets = 3 @@ -2294,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsInitializationAcrossSessions(self): vocab_file = self._createVocabFile("feat_to_id_5.txt") shared_name = "across-sessions" - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2316,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase): self.assertAllEqual([0, 1, 2, 3], out1.eval()) self.assertEquals(vocab_size + oov_buckets, table1.size().eval()) - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 @@ -2340,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self): vocab_file = self._createVocabFile("feat_to_id_6.txt") - with self.test_session() as sess: + with self.cached_session() as sess: default_value1 = -1 vocab_size = 3 oov_buckets = 0 @@ -2378,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase): vocab_file = self._createVocabFile("feat_to_id_7.txt") input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], @@ -2407,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testInt32SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), @@ -2436,7 +2436,7 @@ class 
IdTableWithHashBucketsTest(test.TestCase): def testInt64SparseTensor(self): input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] input_shape = [4, 4] - with self.test_session() as sess: + with self.cached_session() as sess: sp_features = sparse_tensor.SparseTensor( constant_op.constant(input_indices, dtypes.int64), constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), @@ -2464,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithInvalidHashers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.test_session(): + with self.cached_session(): default_value = -1 vocab_size = 3 oov_buckets = 1 diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 2a442a8fc8..c0aec09778 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -43,68 +43,68 @@ class AbsoluteDifferenceLossTest(test.TestCase): self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) def testValueErrorThrownWhenWeightIsNone(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.absolute_difference( self._predictions, self._predictions, weights=None) def testAllCorrectNoLossWeight(self): loss = loss_ops.absolute_difference(self._predictions, self._predictions) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testNonZeroLoss(self): loss = loss_ops.absolute_difference(self._predictions, self._labels) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(5.5, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): weights = 2.3 loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) def 
testNonZeroLossWithScalarTensorWeight(self): weights = 2.3 loss = loss_ops.absolute_difference(self._predictions, self._labels, constant_op.constant(weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): weights = constant_op.constant([1.2, 0.0], shape=[2,]) loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(5.6, loss.eval(), 3) def testNonZeroLossWithTwoDimBatchSpecificWeights(self): weights = constant_op.constant([1.2, 0.0], shape=[2, 1]) loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(5.6, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeights(self): weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(16.6, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZero(self): weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(6.0, loss.eval(), 3) def testLossWithSampleSpecificWeightsAllZero(self): weights = array_ops.zeros((2, 3)) loss = loss_ops.absolute_difference(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -117,12 +117,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.softmax_cross_entropy(logits, 
labels, weights=None) def testAllCorrect(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) @@ -141,7 +141,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.softmax_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 10.0, 3) @@ -154,7 +154,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = 2.3 - with self.test_session(): + with self.cached_session(): loss = loss_ops.softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) @@ -166,7 +166,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = 2.3 - with self.test_session(): + with self.cached_session(): loss = loss_ops.softmax_cross_entropy(logits, labels, constant_op.constant(weights)) self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) @@ -179,7 +179,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) @@ -191,7 +191,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = constant_op.constant([0, 0, 0], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -203,12 +203,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): [1, 0, 0], [0, 1, 0]]) weights = constant_op.constant([1.2, 0, 0], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = 
loss_ops.softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual(12.0, loss.eval(), 3) def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -223,7 +223,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval() def testSoftmaxLabelSmoothing(self): - with self.test_session(): + with self.cached_session(): # Softmax Cross Entropy Loss is: # -\sum_i p_i \log q_i # where for a softmax activation @@ -253,7 +253,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): weights = [2.3, 2.4, 2.5] weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None]) loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -268,7 +268,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): weights_placeholder = array_ops.placeholder( dtypes.float32, shape=[None, None]) loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -280,12 +280,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = constant_op.constant([[0], [1], [2]]) - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None) def testAllCorrectInt32Labels(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 
0.0, 10.0]]) @@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(loss.eval(), 0.0, 3) def testAllCorrectInt64Labels(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) @@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(loss.eval(), 0.0, 3) def testAllCorrectNonColumnLabels(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) @@ -320,7 +320,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 10.0, 3) @@ -331,7 +331,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 10.0, 3) @@ -342,7 +342,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([2, 0, 1]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 10.0, 3) @@ -353,7 +353,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = 2.3 - with self.test_session(): + with 
self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) @@ -363,7 +363,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = 2.3 - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy( logits, labels, constant_op.constant(weights)) self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) @@ -374,7 +374,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) @@ -384,7 +384,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = constant_op.constant([[1.2], [3.4], [5.6]]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) @@ -394,7 +394,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = constant_op.constant([0, 0, 0], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights) self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -404,12 +404,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): [0.0, 0.0, 10.0]]) labels = constant_op.constant([[2], [0], [1]]) weights = constant_op.constant([1.2, 0, 0], shape=[3]) - with self.test_session(): + with self.cached_session(): loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights) 
self.assertAlmostEqual(12.0, loss.eval(), 3) def testMeasurementSpecificWeightsRaisesException(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -422,7 +422,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testInconsistentWeightSizeRaisesException(self): """The weight tensor has incorrect number of elements.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -435,7 +435,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testInconsistentLabelSizeRaisesException(self): """The label tensor has incorrect number of elements.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -448,7 +448,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testInconsistentWeightShapeRaisesException(self): """The weight tensor has incorrect shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0], [-100.0, 100.0, -100.0, -100.0], [-100.0, -100.0, 100.0, -100.0], @@ -462,7 +462,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testInconsistentLabelShapeRaisesException(self): """The label tensor has incorrect shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0], [-100.0, 100.0, -100.0, -100.0], [-100.0, -100.0, 100.0, -100.0], @@ -484,7 +484,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): dtypes.float32, shape=[None]) loss = loss_ops.sparse_softmax_cross_entropy( logits, labels, weights_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, {weights_placeholder: 
weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -498,7 +498,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): dtypes.float32, shape=[None, None]) loss = loss_ops.sparse_softmax_cross_entropy( logits, labels, weights_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, {weights_placeholder: weights}) self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) @@ -506,7 +506,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): class SigmoidCrossEntropyLossTest(test.TestCase): def testAllCorrectSigmoid(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: np.ones((32, 1)), @@ -537,7 +537,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: np.ones((32, 2)), @@ -546,7 +546,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(0.313, loss, 3) def testAllWrongSigmoid(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3) def testAllWrongSigmoidWithMeasurementSpecificWeights(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0], [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) @@ -582,11 +582,11 @@ class 
SigmoidCrossEntropyLossTest(test.TestCase): loss = loss_ops.sigmoid_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(loss.eval(), 0.0, 3) def testSigmoidLabelSmoothingCorrect(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[100.0, -100.0, -100.0]]) labels = constant_op.constant([[1, 0, 1]]) # Sigmoid cross entropy loss is: @@ -608,7 +608,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(loss.eval(), expected_value, 3) def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self): - with self.test_session(): + with self.cached_session(): label_smoothing = 0.1 sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]]) sigmoid_labels = constant_op.constant([[1, 0, 1]]) @@ -641,33 +641,33 @@ class LogLossTest(test.TestCase): self._labels = constant_op.constant(labels) def testValueErrorThrownWhenWeightIsNone(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.log_loss(self._labels, self._labels, weights=None) def testAllCorrectNoLossWeight(self): loss = loss_ops.log_loss(self._labels, self._labels) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testAllCorrectNoLossWeightWithPlaceholder(self): tf_predictions = array_ops.placeholder( dtypes.float32, shape=self._np_labels.shape) loss = loss_ops.log_loss(tf_predictions, self._labels) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual( 0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3) def testNonZeroLoss(self): loss = loss_ops.log_loss(self._predictions, self._labels) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): weights = 2.3 
loss = loss_ops.log_loss(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss.eval(), 3) @@ -675,7 +675,7 @@ class LogLossTest(test.TestCase): weights = 2.3 loss = loss_ops.log_loss(self._predictions, self._labels, constant_op.constant(weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss.eval(), 3) @@ -685,7 +685,7 @@ class LogLossTest(test.TestCase): weights = 2.3 loss = loss_ops.log_loss(tf_predictions, self._labels, constant_op.constant(weights)) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss, 3) @@ -695,7 +695,7 @@ class LogLossTest(test.TestCase): weights = 2.3 loss = loss_ops.log_loss(tf_predictions, self._labels, constant_op.constant(weights)) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss, 3) @@ -706,7 +706,7 @@ class LogLossTest(test.TestCase): self._expected_losses, np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3))) loss = loss_ops.log_loss(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self): @@ -715,7 +715,7 @@ class LogLossTest(test.TestCase): np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape( (2, 3))) loss = loss_ops.log_loss(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3) def 
testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self): @@ -724,12 +724,12 @@ class LogLossTest(test.TestCase): np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape( (2, 3))) loss = loss_ops.log_loss(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3) def testWeightsWithSameNumDimsButWrongShapeThrowsException(self): weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4]) - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.log_loss(self._predictions, self._labels, weights) @@ -742,7 +742,7 @@ class LogLossTest(test.TestCase): self._labels, constant_op.constant( weights, shape=(2, 3))) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3) def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self): @@ -756,7 +756,7 @@ class LogLossTest(test.TestCase): constant_op.constant( weights, shape=(2, 3))) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3) @@ -769,7 +769,7 @@ class LogLossTest(test.TestCase): self._labels, constant_op.constant( weights, shape=(2, 3))) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self): @@ -780,35 +780,35 @@ class LogLossTest(test.TestCase): tf_weights = constant_op.constant(weights, shape=(2, 3)) loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(-np.sum(expected_losses), loss, 3) def 
testLossWithSampleSpecificWeightsAllZero(self): tf_weights = array_ops.zeros(shape=(2, 3)) loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) class HingeLossTest(test.TestCase): def testIncompatibleShapes(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[-1.0], [2.1]]) labels = constant_op.constant([0.0, 1.0]) with self.assertRaises(ValueError): _ = loss_ops.hinge_loss(logits, labels).eval() def testAllOutsideMargin(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([1.2, -1.4, -1.0, 2.1]) labels = constant_op.constant([1.0, 0.0, 0.0, 1.0]) loss = loss_ops.hinge_loss(logits, labels) self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3) def testSomeInsideMargin(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]]) labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]]) loss = loss_ops.hinge_loss(logits, labels) @@ -817,7 +817,7 @@ class HingeLossTest(test.TestCase): self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3) def testSomeMisclassified(self): - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]]) labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]]) loss = loss_ops.hinge_loss(logits, labels) @@ -834,62 +834,62 @@ class MeanSquaredErrorTest(test.TestCase): self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) def testValueErrorThrownWhenWeightIsNone(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.mean_squared_error( self._predictions, self._predictions, weights=None) def testAllCorrectNoLossWeight(self): loss = loss_ops.mean_squared_error(self._predictions, self._predictions) - with 
self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testNonZeroLoss(self): loss = loss_ops.mean_squared_error(self._predictions, self._labels) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(49.5, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): weights = 2.3 loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): weights = 2.3 loss = loss_ops.mean_squared_error(self._predictions, self._labels, constant_op.constant(weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): weights = constant_op.constant([1.2, 3.4], shape=[2,]) loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) def testNonZeroLossWithTwoDimBatchSpecificWeights(self): weights = constant_op.constant([1.2, 3.4], shape=[2, 1]) loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeights(self): weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(587 / 5.0, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZero(self): weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): 
self.assertAlmostEqual(18.0, loss.eval(), 3) def testLossWithSampleSpecificWeightsAllZero(self): weights = array_ops.zeros((2, 3)) loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -914,7 +914,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): self._expected_losses = np.divide(total, 9.0) def testValueErrorThrownWhenWeightIsNone(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.mean_pairwise_squared_error( predictions=constant_op.constant(self._labels), @@ -925,14 +925,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): loss = loss_ops.mean_pairwise_squared_error( predictions=constant_op.constant(self._labels), labels=constant_op.constant(self._labels)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testNonZeroLoss(self): loss = loss_ops.mean_pairwise_squared_error( predictions=constant_op.constant(self._predictions), labels=constant_op.constant(self._labels)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3) def testGradientWithZeroWeight(self): @@ -954,7 +954,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): init_op = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for grad, _ in gradients_to_variables: np_grad = sess.run(grad) @@ -966,7 +966,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): predictions=constant_op.constant(self._predictions), labels=constant_op.constant(self._labels), weights=weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss.eval(), 3) @@ -976,7 +976,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): 
predictions=constant_op.constant(self._predictions), labels=constant_op.constant(self._labels), weights=constant_op.constant(weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss.eval(), 3) @@ -986,7 +986,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): predictions=constant_op.constant(self._predictions), labels=constant_op.constant(self._labels), weights=constant_op.constant(weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0, loss.eval(), 3) def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self): @@ -998,7 +998,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): predictions=tf_predictions, labels=tf_labels, weights=constant_op.constant(weights)) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ tf_predictions: self._predictions, @@ -1015,7 +1015,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): labels=constant_op.constant(self._labels), weights=constant_op.constant( weights, shape=[2])) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3) def testZeroLossWithOneDimBatchZeroWeights(self): @@ -1025,7 +1025,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): labels=constant_op.constant(self._labels), weights=constant_op.constant( weights, shape=[2])) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self): @@ -1041,7 +1041,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): weights=constant_op.constant( weights, shape=[2])) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ tf_predictions: self._predictions, @@ -1056,7 +1056,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): 
labels=constant_op.constant(self._labels), weights=constant_op.constant( weights, shape=[2])) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testLossIsAssociativeAcrossBatchElements(self): @@ -1087,7 +1087,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): predictions=array_ops.concat([predictions0, predictions1], 0), labels=array_ops.concat([labels0, labels1], 0)) - with self.test_session() as session: + with self.cached_session() as session: loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1]) self.assertTrue(loss0 > 0) @@ -1115,7 +1115,7 @@ class CosineDistanceLossTest(test.TestCase): [0, 1, 0]]).reshape((3, 2, 3)) def testValueErrorThrownWhenWeightIsNone(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.cosine_distance( predictions=constant_op.constant(self._labels), @@ -1128,7 +1128,7 @@ class CosineDistanceLossTest(test.TestCase): predictions=constant_op.constant(self._labels), labels=constant_op.constant(self._labels), dim=2) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(0, loss.eval(), 5) def testPartiallyCorrectWithIntegerValues(self): @@ -1136,7 +1136,7 @@ class CosineDistanceLossTest(test.TestCase): predictions=constant_op.constant(self._predictions), labels=constant_op.constant(self._labels), dim=2) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(1, loss.eval(), 5) def testPartiallyCorrectFloatingPointValues(self): @@ -1154,7 +1154,7 @@ class CosineDistanceLossTest(test.TestCase): labels, shape=(3, 1, 3), dtype=dtypes.float32) loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(1.0, loss.eval(), 5) def testSampleSpecificWeights(self): @@ -1163,7 +1163,7 @@ class CosineDistanceLossTest(test.TestCase): labels=constant_op.constant(self._labels), dim=2, 
weights=constant_op.constant([1, 0, 0])) - with self.test_session(): + with self.cached_session(): self.assertEqual(1.0, loss.eval()) def testMeasurementSpecificWeights(self): @@ -1173,12 +1173,12 @@ class CosineDistanceLossTest(test.TestCase): dim=2, weights=constant_op.constant( [1, 0, 0, 1, 1, 1], shape=(3, 2))) - with self.test_session(): + with self.cached_session(): self.assertEqual(3.0 / 4.0, loss.eval()) def testValueErrorThrownWithShapelessPlaceholder(self): tf_predictions = array_ops.placeholder(dtypes.float32) - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): loss_ops.cosine_distance( predictions=tf_predictions, @@ -1196,7 +1196,7 @@ class CosineDistanceLossTest(test.TestCase): dim=2, weights=constant_op.constant( [1, 0, 0, 1, 1, 1], shape=(3, 2))) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._predictions}) self.assertEqual(3.0 / 4.0, loss) @@ -1206,7 +1206,7 @@ class CosineDistanceLossTest(test.TestCase): labels=constant_op.constant(self._labels), dim=2, weights=array_ops.zeros((3,))) - with self.test_session(): + with self.cached_session(): self.assertEqual(0, loss.eval()) def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self): @@ -1215,7 +1215,7 @@ class CosineDistanceLossTest(test.TestCase): labels=constant_op.constant(self._labels), dim=2, weights=array_ops.zeros((3, 2))) - with self.test_session(): + with self.cached_session(): self.assertEqual(0, loss.eval()) @@ -1228,7 +1228,7 @@ class ComputeWeightedLossTest(test.TestCase): self.assertFalse(loss_ops.get_losses()) loss = loss_ops.compute_weighted_loss(losses) self.assertTrue(loss_ops.get_losses()) - with self.test_session(): + with self.cached_session(): self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3) self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3) @@ -1243,7 +1243,7 @@ class AddLossTest(test.TestCase): 
loss_ops.add_loss(math_ops.reduce_mean(losses)) self.assertTrue(loss_ops.get_losses()) total_loss = loss_ops.get_total_loss() - with self.test_session(): + with self.cached_session(): self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3) self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3) @@ -1254,7 +1254,7 @@ class AddLossTest(test.TestCase): self.assertFalse(loss_ops.get_losses()) loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None) self.assertFalse(loss_ops.get_losses()) - with self.test_session(): + with self.cached_session(): self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3) def testNoCollectLosses(self): diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py index 7acfc383eb..5777e64c29 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): # code used float32 for accumulation. 
num_updates = 71 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_updates): sess.run(update_op) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 024bd54912..955b83b44d 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual(1.65, sess.run(mean), 5) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual(1.65, sess.run(mean), 5) def test1dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5) def test1dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. 
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5) def test2dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase): self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5) def test2dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean)) def testMultiDimensional(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2)) _enqueue_vector( @@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean)) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -396,7 +396,7 @@ class 
StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5) def testWeighted1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5) def testWeighted2d_1(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase): self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5) def testWeighted2d_2(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase): (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2) accuracy, update_op = metrics.streaming_accuracy(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase): self.assertEqual(initial_accuracy, accuracy.eval()) def testMultipleUpdates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. 
preds_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) @@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase): def testEffectivelyEquivalentSizes(self): predictions = array_ops.ones((40, 1)) labels = array_ops.ones((40,)) - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels) sess.run(variables.local_variables_initializer()) @@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase): weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]), 1) # shape 3, 1 - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels, weights) @@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase): dtype=dtypes_lib.int32, name='weights') feed_dict = {weights_placeholder: weights} - with self.test_session() as sess: + with self.cached_session() as sess: accuracy, update_op = metrics.streaming_accuracy(predictions, labels, weights_placeholder) @@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase): self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95) def testMultipleUpdatesWithWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. 
preds_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) @@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tp.eval()) self.assertEqual(1, tp_update_op.eval()) @@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives( predictions, labels, weights=37.0) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tp.eval()) self.assertEqual(37.0, tp_update_op.eval()) @@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fn.eval()) self.assertEqual(2, fn_update_op.eval()) @@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives( predictions, labels, weights=((3.0,), (5.0,), (7.0,))) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fn.eval()) self.assertEqual(8.0, fn_update_op.eval()) @@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase): fp, fp_update_op = metrics.streaming_false_positives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fp.eval()) self.assertEqual(4, fp_update_op.eval()) @@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase): weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0, 29.0, 31.0))) - with 
self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, fp.eval()) self.assertEqual(42.0, fp_update_op.eval()) @@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tn.eval()) self.assertEqual(5, tn_update_op.eval()) @@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives( predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, tn.eval()) self.assertEqual(15.0, tn_update_op.eval()) @@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), tp.eval()) self.assertAllEqual((3, 1, 0), tp_update_op.eval()) @@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), tp.eval()) self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval()) @@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as 
sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), fn.eval()) self.assertAllEqual((0, 2, 3), fn_update_op.eval()) @@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): weights=((3.0,), (5.0,), (7.0,)), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), fn.eval()) self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval()) @@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): fp, fp_update_op = metrics.streaming_false_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), fp.eval()) self.assertAllEqual((7, 4, 2), fp_update_op.eval()) @@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): 29.0, 31.0)), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), fp.eval()) self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval()) @@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAllEqual((0, 0, 0), tn.eval()) self.assertAllEqual((2, 5, 7), tn_update_op.eval()) @@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): weights=((0.0, 2.0, 3.0, 5.0),), thresholds=(0.15, 0.5, 0.85)) - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(variables.local_variables_initializer()) self.assertAllEqual((0.0, 0.0, 0.0), tn.eval()) self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval()) @@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant(inputs) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1, sess.run(update_op)) self.assertAlmostEqual(1, precision.eval()) @@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, precision.eval()) @@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase): precision, update_op = metrics.streaming_precision( predictions, labels, weights=constant_op.constant([[2], [5]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 2.0 + 5.0 weighted_positives = (2.0 + 2.0) + (5.0 + 5.0) @@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase): precision, update_op = metrics.streaming_precision( predictions, labels, weights=constant_op.constant([[2], [5]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 2.0 + 5.0 weighted_positives = (2.0 + 2.0) 
+ (5.0 + 5.0) @@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase): labels, weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 3.0 + 4.0 weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) @@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase): labels, weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])) - with self.test_session(): + with self.cached_session(): variables.local_variables_initializer().run() weighted_tp = 3.0 + 4.0 weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) @@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant(1 - inputs) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertAlmostEqual(0, precision.eval()) @@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase): labels = constant_op.constant([0, 0, 0, 0]) precision, update_op = metrics.streaming_precision(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0.0, precision.eval()) @@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant(np_inputs) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, recall.eval()) @@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, recall.eval()) @@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase): recall, update_op = metrics.streaming_recall( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_tp = 2.0 + 5.0 weighted_t = (2.0 + 2.0) + (5.0 + 5.0) @@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase): recall, update_op = metrics.streaming_recall( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_tp = 3.0 + 1.0 weighted_t = (2.0 + 3.0) + (4.0 + 1.0) @@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase): labels = constant_op.constant(1 - np_inputs) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, recall.eval()) @@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase): labels = array_ops.zeros((1, 4)) recall, update_op = metrics.streaming_recall(predictions, labels) - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, recall.eval()) @@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase): labels = constant_op.constant(np_inputs) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fpr.eval()) @@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, fpr.eval()) @@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase): fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fp = 2.0 + 5.0 weighted_f = (2.0 + 2.0) + (5.0 + 5.0) @@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase): fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fp = 1.0 + 3.0 weighted_f = (1.0 + 4.0) + (2.0 + 3.0) @@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase): labels = 
constant_op.constant(1 - np_inputs) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, fpr.eval()) @@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase): labels = array_ops.ones((1, 4)) fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fpr.eval()) @@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant(np_inputs) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fnr.eval()) @@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.5, update_op.eval()) self.assertAlmostEqual(0.5, fnr.eval()) @@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase): fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(variables.local_variables_initializer()) weighted_fn = 2.0 + 5.0 weighted_t = (2.0 + 2.0) + (5.0 + 5.0) @@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase): fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) weighted_fn = 2.0 + 4.0 weighted_t = (2.0 + 3.0) + (1.0 + 4.0) @@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase): labels = constant_op.constant(1 - np_inputs) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(1, fnr.eval()) @@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase): labels = array_ops.zeros((1, 4)) fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) self.assertEqual(0, fnr.eval()) @@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase): points, update_op = metric_ops.streaming_curve_points( labels, predictions=predictions, curve=curve) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) sess.run(update_op) @@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase): self._testValueTensorIsIdempotent(curve='PR') def _testCase(self, labels, predictions, curve, expected_points): - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( predictions, dtype=dtypes_lib.float32) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32) @@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase): (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, 
update_op = metrics.streaming_auc(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(initial_auc, auc.eval(), 5) def testPredictionsOutOfRange(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase): def allCorrectAsExpected(self, curve): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve) @@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase): self.assertEqual(1, auc.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.5, auc.eval()) def testWeighted1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.5, auc.eval(), 5) def testWeighted2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 
1, 0], shape=(1, 4)) @@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.7, auc.eval(), 5) def testAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4)) @@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], shape=(1, 7), @@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], shape=(1, 7), @@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_auc(predictions, labels) @@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(0, auc.eval()) def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) auc, update_op = metrics.streaming_auc(predictions, labels) @@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) def testRecallOneAndPrecisionOneGivesOnePRAUC(self): - with 
self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.ones([4], dtype=dtypes_lib.float32) labels = array_ops.ones([4]) auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') @@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase): np.random.exponential(scale=1.0, size=num_samples)): expected_auc = _np_auc(predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: enqueue_ops = [[] for i in range(num_batches)] tf_predictions = _enqueue_as_batches(predictions, enqueue_ops) tf_labels = _enqueue_as_batches(labels, enqueue_ops) @@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase): labels = random_ops.random_uniform( (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
for _ in xrange(10): @@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(initial_auc, auc.eval(), 5) def testAllLabelsOnes(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1., 1., 1.]) labels = constant_op.constant([1, 1, 1]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(0, auc.eval()) def testAllLabelsZeros(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1., 1., 1.]) labels = constant_op.constant([0, 0, 0]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(0, auc.eval()) def testNonZeroOnePredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32) labels = constant_op.constant([1, 0, 1, 0]) @@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs) labels = constant_op.constant(inputs) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertEqual(1, auc.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0, 1, 0]) labels = constant_op.constant([0, 1, 1, 0]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with 
self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0, auc.eval()) def testExceptionOnIncompatibleShapes(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.ones([5]) labels = array_ops.zeros([6]) with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): @@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) def testExceptionOnGreaterThanOneLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) labels = constant_op.constant([2, 1, 0]) _, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) def testExceptionOnNegativeLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) labels = constant_op.constant([1, 0, -1]) _, update_op = metrics.streaming_dynamic_auc(labels, predictions) @@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase): collections=[ops.GraphKeys.LOCAL_VARIABLES], dtype=dtypes_lib.float32) auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval()) def testAUCPRReverseIncreasingPredictions(self): - with 
self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 1]) @@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5) def testAUCPRJumbledPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32) labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1]) @@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6) def testAUCPRPredictionsLessThanHalf(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], shape=(1, 7), @@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase): auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions, weights=tf_weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): expected_result: The expected result (dict) that maps to tensors. weights: Optional weights tensor. 
""" - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( predictions, dtype=dtypes_lib.float32) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64) @@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): dtype=dtypes_lib.float32) auc, update_op = metrics.auc_with_confidence_intervals(tf_labels, tf_predictions) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for _ in xrange(num_batches): new_labels = np.random.randint(0, 2, size=batch_size) @@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): self.assertAllClose(expected_auc, auc.auc.eval()) def testExceptionOnFloatLabels(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([0.7, 0, 1, 0, 1]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): self.assertRaises(TypeError, sess.run(update_op)) def testExceptionOnGreaterThanOneLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([2, 1, 0, 1, 0]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase): sess.run(update_op) def testExceptionOnNegativeLabel(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) labels = constant_op.constant([1, 0, -1, 1, 0]) _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) @@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): result, 
update_op = metric_ops.precision_recall_at_equal_thresholds( labels=labels, predictions=predictions) - with self.test_session() as sess: + with self.cached_session() as sess: # Run several updates. sess.run(variables.local_variables_initializer()) for _ in range(3): @@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): default from assertAllClose. weights: Optional weights tensor. """ - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant(predictions, dtype=dtype) labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) weights_tensor = None @@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, specificity.eval()) @@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1.0, sess.run(update_op)) self.assertAlmostEqual(1.0, specificity.eval()) @@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.4) - with 
self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) @@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, weights=weights, sensitivity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) @@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, weights=weights, sensitivity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op)) @@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): sensitivity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, specificity.eval()) @@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, specificity.eval()) @@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, specificity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.6, sess.run(update_op)) self.assertAlmostEqual(0.6, specificity.eval()) @@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): specificity, update_op = metrics.streaming_sensitivity_at_specificity( predictions, labels, weights=weights, specificity=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.675, sess.run(update_op)) self.assertAlmostEqual(0.675, specificity.eval()) @@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertEqual(1, rec.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0, rec.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, rec_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, 
rec_high.eval(), places=5) def testExtremeThresholds(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, rec_high.eval()) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase): fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertEqual(0, fpr.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(1, fpr.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testExtremeThresholds(self): - with 
self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=1.0) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, recall.eval()) @@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, recall.eval()) @@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, precision=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) target_recall = 2.0 / 3.0 self.assertAlmostEqual(target_recall, sess.run(update_op)) @@ -3461,7 +3461,7 @@ class RecallAtPrecisionTest(test.TestCase): recall, update_op = metrics.recall_at_precision( labels, predictions, weights=weights, precision=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) target_recall = 2.0 / 3.0 self.assertAlmostEqual(target_recall, sess.run(update_op)) @@ -3486,7 +3486,7 @@ class RecallAtPrecisionTest(test.TestCase): precision=target_precision, strict_mode=strict_mode) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(expected_recall, sess.run(update_op)) self.assertAlmostEqual(expected_recall, recall.eval()) @@ -3565,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.7) - with self.test_session() 
as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3585,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.7) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, precision.eval()) @@ -3599,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(sess.run(label_prior), sess.run(update_op)) self.assertEqual(sess.run(label_prior), precision.eval()) @@ -3614,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.8) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(0.8, sess.run(update_op)) self.assertAlmostEqual(0.8, precision.eval()) @@ -3629,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.4) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(2.0/3, sess.run(update_op)) self.assertAlmostEqual(2.0/3, precision.eval()) @@ -3648,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase): precision, update_op = metrics.precision_at_recall( labels, predictions, target_recall=0.8, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(34.0/43, sess.run(update_op)) 
self.assertAlmostEqual(34.0/43, precision.eval()) @@ -3697,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase): fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -3712,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase): def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] @@ -3725,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertEqual(0, fnr.eval()) def testSomeCorrect(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) @@ -3741,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase): def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] @@ -3754,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1, fnr.eval()) def testWeights1d(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3774,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testWeights2d(self): - with self.test_session() as sess: + with 
self.cached_session() as sess: predictions = constant_op.constant( [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) @@ -3794,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testExtremeThresholds(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) @@ -3812,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase): self.assertAlmostEqual(1.0, fnr_high.eval()) def testZeroLabelsPredictions(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] @@ -3852,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: # Reshape the data so its easy to queue up: predictions_batches = predictions.reshape((batch_size, num_batches)) labels_batches = labels.reshape((batch_size, num_batches)) @@ -3940,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0.25, sess.run(update_op)) self.assertEqual(0.25, recall.eval()) @@ -3958,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2) - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(variables.local_variables_initializer()) self.assertEqual(0.5, sess.run(update_op)) self.assertEqual(0.5, recall.eval()) @@ -3976,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase): sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k( predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) @@ -4000,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase): k=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) @@ -4122,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase): self.assertAlmostEqual(expected, metric.eval()) def test_top_k_rank_invalid(self): - with self.test_session(): + with self.cached_session(): # top_k_predictions has rank < 2. 
top_k_predictions = [9, 4, 6, 2, 0] sp_labels = sparse_tensor.SparseTensorValue( @@ -4669,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] labels = [[0, 0, 0, 1], [0, 0, 1, 0]] expected_precision = 0.5 - with self.test_session(): + with self.cached_session(): _, precision = metrics.streaming_sparse_precision_at_k( predictions=constant_op.constant(predictions, dtypes_lib.float32), labels=_binary_2d_label_to_sparse_value(labels), @@ -5374,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] labels = [[0, 0, 1, 0], [0, 0, 0, 1]] expected_recall = 0.5 - with self.test_session(): + with self.cached_session(): _, recall = metrics.streaming_sparse_recall_at_k( predictions=constant_op.constant(predictions, dtypes_lib.float32), labels=_binary_2d_label_to_sparse_value(labels), @@ -5418,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_absolute_error( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5440,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_absolute_error( predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(3, sess.run(update_op)) self.assertEqual(3, error.eval()) @@ -5484,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -5509,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer=labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(expected_error, sess.run(update_op)) self.assertEqual(expected_error, error.eval()) @@ -5525,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_relative_error( predictions, labels, normalizer=array_ops.zeros_like(labels)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0.0, sess.run(update_op)) self.assertEqual(0.0, error.eval()) @@ -5563,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): labels = random_ops.random_normal((10, 3), seed=2) error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -5581,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -5594,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(6, sess.run(update_op)) self.assertEqual(6, error.eval()) @@ -5609,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_mean_squared_error( predictions, labels, weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(13, sess.run(update_op)) self.assertEqual(13, error.eval()) def testMultipleBatchesOfSizeOne(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5640,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(208.0 / 6, error.eval(), 5) def testMetricsComputedConcurrently(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates one set of predictions. preds_queue0 = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5683,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(79.0 / 6, mse1, 5) def testMultipleMetricsOnMultipleBatchesOfSizeOne(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. 
preds_queue = data_flow_ops.FIFOQueue( 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) @@ -5745,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): error, update_op = metrics.streaming_root_mean_squared_error( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -5758,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertEqual(initial_error, error.eval()) def testSingleUpdateZeroError(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( 0.0, shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32) @@ -5772,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertEqual(0, rmse.eval()) def testSingleUpdateWithError(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5786,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5842,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase): predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5 cov, update_op = metrics.streaming_covariance(predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -5855,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertEqual(initial_cov, cov.eval()) def testSingleUpdateIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = math_ops.to_float(math_ops.range(10)) labels = math_ops.to_float(math_ops.range(10)) @@ -5867,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval(), 5) def testSingleUpdateNonIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5881,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval()) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -5899,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase): self.assertAlmostEqual(expected_cov, cov.eval()) def testMultiUpdateWithErrorNoWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -5933,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase): prev_expected_cov = expected_cov def testMultiUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6023,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase): pearson_r, update_op = metrics.streaming_pearson_correlation( predictions, labels) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. 
@@ -6036,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertEqual(initial_r, pearson_r.eval()) def testSingleUpdateIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = math_ops.to_float(math_ops.range(10)) labels = math_ops.to_float(math_ops.range(10)) @@ -6049,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval(), 5) def testSingleUpdateNonIdentical(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant( [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant( @@ -6064,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval()) def testSingleUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: predictions = np.array([2, 4, 6, 8]) labels = np.array([1, 3, 2, 7]) weights = np.array([0, 1, 3, 1]) @@ -6085,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase): self.assertAlmostEqual(expected_r, pearson_r.eval()) def testMultiUpdateWithErrorNoWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6120,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase): prev_expected_r = expected_r def testMultiUpdateWithErrorAndWeights(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6162,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase): prev_expected_r = expected_r def testMultiUpdateWithErrorAndSingletonBatches(self): - with self.test_session() as sess: + with self.cached_session() as sess: np.random.seed(123) n = 100 predictions = np.random.randn(n) @@ -6243,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = 
metrics.streaming_mean_cosine_distance( predictions, labels, dim=1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -6266,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -6283,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1, sess.run(update_op), 5) self.assertAlmostEqual(1, error.eval(), 5) @@ -6305,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(1.0, sess.run(update_op), 5) self.assertAlmostEqual(1.0, error.eval(), 5) @@ -6324,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) @@ -6343,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase): error, update_op = metrics.streaming_mean_cosine_distance( predictions, labels, dim=2, weights=weights) - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(variables.local_variables_initializer()) self.assertEqual(1.5, update_op.eval()) self.assertEqual(1.5, error.eval()) @@ -6378,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testOneUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) @@ -6398,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase): self.assertAlmostEqual(0.0, pcnt2, 5) def testSomePresentOneUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = constant_op.constant( [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) weights = constant_op.constant( @@ -6475,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase): miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes=num_classes) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -6489,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase): def testMultipleUpdates(self): num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) @@ -6521,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase): def testMultipleUpdatesWithWeights(self): num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. preds_queue = data_flow_ops.FIFOQueue( 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) @@ -6569,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase): # one class, and thus there is one row and one column with # zero entries in the confusion matrix. 
num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the predictions. # There is no prediction for class 2. preds_queue = data_flow_ops.FIFOQueue( @@ -6611,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase): constant_op.constant(1, shape=[7]) ], 0) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6624,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase): predictions = array_ops.zeros([40]) labels = array_ops.zeros([40]) num_classes = 1 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6635,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase): predictions = array_ops.zeros([40]) labels = array_ops.ones([40]) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6657,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase): constant_op.constant(1, shape=[8]), constant_op.constant(0, shape=[1]) ], 0) - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes, weights=weights) sess.run(variables.local_variables_initializer()) @@ -6672,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase): [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1], [1, 1, 2, 0, 0, 0]]]) num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6684,7 +6684,7 @@ 
class StreamingMeanIOUTest(test.TestCase): labels = constant_op.constant([0]) predictions = constant_op.constant([0]) num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6698,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase): [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1], [1, 1, 1, 0, 0, 0]]]) num_classes = 3 - with self.test_session() as sess: + with self.cached_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, num_classes) sess.run(variables.local_variables_initializer()) @@ -6733,7 +6733,7 @@ class StreamingConcatTest(test.TestCase): def testNextArraySize(self): next_array_size = metric_ops._next_array_size # pylint: disable=protected-access - with self.test_session(): + with self.cached_session(): self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2) self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4) self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4) @@ -6741,7 +6741,7 @@ class StreamingConcatTest(test.TestCase): self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8) def testStreamingConcat(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.int32, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6758,7 +6758,7 @@ class StreamingConcatTest(test.TestCase): self.assertAllEqual(np.arange(10), concatenated.eval()) def testStreamingConcatStringValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.string, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6777,7 +6777,7 @@ class StreamingConcatTest(test.TestCase): 
concatenated.eval()) def testStreamingConcatMaxSize(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = math_ops.range(3) concatenated, update_op = metrics.streaming_concat(values, max_size=5) sess.run(variables.local_variables_initializer()) @@ -6794,7 +6794,7 @@ class StreamingConcatTest(test.TestCase): self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval()) def testStreamingConcat2D(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.reshape(math_ops.range(3), (3, 1)) concatenated, update_op = metrics.streaming_concat(values, axis=-1) sess.run(variables.local_variables_initializer()) @@ -6817,7 +6817,7 @@ class StreamingConcatTest(test.TestCase): array_ops.placeholder(dtypes_lib.float32, [None, None])) def testStreamingConcatReset(self): - with self.test_session() as sess: + with self.cached_session() as sess: values = array_ops.placeholder(dtypes_lib.int32, [None]) concatenated, update_op = metrics.streaming_concat(values) sess.run(variables.local_variables_initializer()) @@ -6845,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase): metrics.streaming_mean(values)) self.assertEqual(len(value_tensors), 1) self.assertEqual(len(update_ops), 1) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(1, update_ops[0].eval()) self.assertEqual(1, value_tensors[0].eval()) @@ -6858,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase): metrics.streaming_mean_squared_error(predictions, labels)) self.assertEqual(len(value_tensors), 2) self.assertEqual(len(update_ops), 2) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(2, update_ops[0].eval()) self.assertEqual(4, update_ops[1].eval()) @@ -6879,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase): self.assertEqual(2, len(names_to_values)) 
self.assertEqual(2, len(names_to_updates)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) self.assertEqual(2, names_to_updates['m1'].eval()) self.assertEqual(4, names_to_updates['m2'].eval()) @@ -6914,7 +6914,7 @@ class CountTest(test.TestCase): self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor)) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -6931,7 +6931,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(8.0, sess.run(result), 5) def testUpdateOpsReturnsCurrentValue(self): - with self.test_session() as sess: + with self.cached_session() as sess: values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) _enqueue_vector(sess, values_queue, [0, 1]) @@ -6952,7 +6952,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(8.0, sess.run(result), 5) def test1dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -6979,7 +6979,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(3.4, result.eval(), 5) def test1dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -7001,7 +7001,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(3.4, result.eval(), 5) def test2dWeightedValues(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. 
values_queue = data_flow_ops.FIFOQueue( 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) @@ -7028,7 +7028,7 @@ class CountTest(test.TestCase): self.assertAlmostEqual(4.1, result.eval(), 5) def test2dWeightedValues_placeholders(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Create the queue that populates the values. feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) values = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -7101,7 +7101,7 @@ class CohenKappaTest(test.TestCase): (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2) kappa, update_op = metrics.cohen_kappa(labels, predictions, 3) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) # Run several updates. @@ -7135,7 +7135,7 @@ class CohenKappaTest(test.TestCase): for dtype in dtypes: for shape in shapes: for weight in weights: - with self.test_session() as sess: + with self.cached_session() as sess: predictions_tensor = constant_op.constant( np.reshape(predictions, shape), dtype=dtype) labels_tensor = constant_op.constant( @@ -7156,7 +7156,7 @@ class CohenKappaTest(test.TestCase): # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs) expect = 1.0 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) kappa, update_op = metrics.cohen_kappa(labels, predictions, 4) @@ -7175,7 +7175,7 @@ class CohenKappaTest(test.TestCase): # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions) expect = -0.333333333333 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32) labels = constant_op.constant(labels) kappa, update_op = metrics.cohen_kappa(labels, predictions, 4) @@ -7193,7 +7193,7 @@ class CohenKappaTest(test.TestCase): # labels, predictions, 
sample_weight=weights) expect = 0.453466583385 - with self.test_session() as sess: + with self.cached_session() as sess: predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32) labels = constant_op.constant(labels) kappa, update_op = metrics.cohen_kappa( @@ -7218,7 +7218,7 @@ class CohenKappaTest(test.TestCase): weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,)) kappa, update_op = metrics.cohen_kappa( labels_t, predictions_t, num_classes, weights=weights_t) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for idx in range(0, num_samples, batch_size): @@ -7256,7 +7256,7 @@ class CohenKappaTest(test.TestCase): def testConditionalPackingOptimization(self): placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) values, update_op = metric_ops.streaming_concat(placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.local_variables_initializer()) for feed in range(10): sess.run(update_op, feed_dict={placeholder: [feed]}) diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py index e85ae7b22a..586c6c7bfc 100644 --- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py @@ -37,7 +37,7 @@ class RnnCellsTest(test.TestCase): expected_num_masks = 1 expected_num_rows = 2 * self.dim expected_num_cols = 4 * self.dim - with self.test_session(): + with self.cached_session(): inputs = variables.Variable( random_ops.random_normal([self.batch_size, self.dim])) c = variables.Variable( @@ -61,7 +61,7 @@ class RnnCellsTest(test.TestCase): expected_num_masks = 1 expected_num_rows = 2 * self.dim expected_num_cols = 4 * self.dim - with self.test_session(): + with self.cached_session(): inputs = variables.Variable( 
random_ops.random_normal([self.batch_size, self.dim])) c = variables.Variable( diff --git a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py index cb69c72970..d0955cbe11 100644 --- a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py +++ b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py @@ -31,7 +31,7 @@ class HyperplaneLshProbesTest(test.TestCase): # tests in hyperplane_lsh_probes_test.cc already cover most of the LSH # functionality. def simple_batch_test(self): - with self.test_session(): + with self.cached_session(): hyperplanes = np.eye(4) points = np.array([[1.2, 0.5, -0.9, -1.0], [2.0, -3.0, 1.0, -1.5]]) product = np.dot(points, hyperplanes) diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index 31a6fe1d94..9a19502276 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -38,7 +38,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): desired_shape = numpy.array([6, None]) output_tensor = input_tensor.reshape((6, 2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() result = periodic_resample(input_tensor, desired_shape).eval() self.assertAllEqual(result, output_tensor) @@ -49,7 +49,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): desired_shape = numpy.array([5, None]) output_tensor = input_tensor.reshape((6, 2))[:-1] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() result = periodic_resample(input_tensor, desired_shape).eval() self.assertAllEqual(result, 
output_tensor) @@ -63,7 +63,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): [15]]]) # NOTE: output_tensor != input_tensor.reshape((4, 4, -1)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() result = periodic_resample(input_tensor, desired_shape).eval() # input_tensor[0, 0, 0] == result[0, 0, 0] @@ -88,14 +88,14 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): [[49], [53], [57], [61]], [[51], [55], [59], [63]]]]) # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() result = periodic_resample(input_tensor, desired_shape).eval() self.assertAllEqual(result, output_tensor) def testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) - with self.test_session(): + with self.cached_session(): with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -109,7 +109,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): desired_shape = numpy.array([4, 4, None]) result_shape = (4, 4, 1) input_shape = (2, 2, 4) - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32, shape=input_shape) output = periodic_resample(x, desired_shape) error = gradient_checker.compute_gradient_error( @@ -117,7 +117,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): self.assertLess(error, 1e-4) def testPeriodicResampleShapeInference(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Case 1: output shape can be fully inferreed. 
x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) output = periodic_resample(x, [4, 4, None]) diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py index 00fbd4fbb8..aea80a5256 100644 --- a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py +++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py @@ -56,7 +56,7 @@ class RecurrentTest(test_util.TensorFlowTestCase): x_power=state.x_power * theta.x) return next_state, [] - with self.test_session() as sess: + with self.cached_session() as sess: theta = _PolyTheta(x=array_ops.constant(2.0)) state = _PolyState( value=array_ops.constant(0.0), @@ -142,7 +142,7 @@ class RecurrentTest(test_util.TensorFlowTestCase): def _ParameterizedTestElman(self, seqlen, use_grad): - with self.test_session() as sess: + with self.cached_session() as sess: random_seed.set_random_seed(342462) batch = 3 diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 8a0dbef788..12dd72a95b 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -50,7 +50,7 @@ class TestModelSavingandLoading(test.TestCase): return os.path.join(temp_dir, dirname) def test_saving_sequential_model(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.RepeatVector(3)) @@ -75,7 +75,7 @@ class TestModelSavingandLoading(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_saving_sequential_model_without_compile(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) 
model.add(keras.layers.RepeatVector(3)) @@ -92,7 +92,7 @@ class TestModelSavingandLoading(test.TestCase): self.assertAllClose(ref_y, y, atol=1e-05) def test_saving_functional_model(self): - with self.test_session(): + with self.cached_session(): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) output = keras.layers.Dense(3)(x) @@ -117,7 +117,7 @@ class TestModelSavingandLoading(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_saving_functional_model_without_compile(self): - with self.test_session(): + with self.cached_session(): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) output = keras.layers.Dense(3)(x) @@ -138,7 +138,7 @@ class TestModelSavingandLoading(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_saving_with_tf_optimizer(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index f87a96e547..4afc6399d5 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1762,7 +1762,7 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, feed_fn1, feed_fn2) - with self.test_session() as sess: + with self.cached_session() as sess: np1 = np.array([1.0, 1.5, 2.0, 2.5]) np2 = np.array([3.0, 3.5, 4.0, 4.5]) squared_tensor = SquaredTensor(np2) @@ -1922,7 +1922,7 @@ class SessionTest(test_util.TensorFlowTestCase): pass def testAutoConvertAndCheckData(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = array_ops.placeholder(dtype=dtypes.string) with self.assertRaisesRegexp( TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'): diff --git 
a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index c0e66cb0b8..d403b0c61a 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -1259,7 +1259,7 @@ class SparseTest(PForTest): [3]) # [0, 2, 0] pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(pfor, feed_dict={num_iters: 3}) def test_sparse_result_none_stacked(self): diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py index f9cf16f6a4..628c6764cd 100644 --- a/tensorflow/python/ops/parallel_for/gradients_test.py +++ b/tensorflow/python/ops/parallel_for/gradients_test.py @@ -356,7 +356,7 @@ class GradientsTest(test.TestCase): self.run_and_assert_equal(answer, jacobian_while) def test_jacobian_unknown_shape(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32, shape=[None, None]) y = math_ops.matmul(x, x, transpose_a=True) jacobian_pfor = gradients.jacobian(y, x, use_pfor=True) @@ -381,7 +381,7 @@ class GradientsTest(test.TestCase): gradients.batch_jacobian(y, x, use_pfor=True) def test_batch_jacobian_bad_unknown_shapes(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.concat([x, x], axis=0) jacobian = gradients.batch_jacobian(y, x) @@ -402,7 +402,7 @@ class GradientsTest(test.TestCase): self.run_and_assert_equal(answer, batch_jacobian_while) def test_batch_jacobian_unknown_shape(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) y = x * x batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True) diff --git a/tensorflow/python/util/nest_test.py 
b/tensorflow/python/util/nest_test.py index 2369eb610e..ef503137d1 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -461,7 +461,7 @@ class NestTest(parameterized.TestCase, test.TestCase): inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) } - with self.test_session() as sess: + with self.cached_session() as sess: output_np = sess.run(output, feed_dict=feed_dict) self.assertAllClose(output_np[0], feed_dict[inp_a][0] + feed_dict[inp_b][0]) diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 16fa1f547d..fedbe1dff6 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -106,7 +106,7 @@ class TfShouldUseTest(test.TestCase): def return_const(value): return constant_op.constant(value, name='blah3') with reroute_error() as (error, _): - with self.test_session(): + with self.cached_session(): return_const(0.0) # Creating another op and executing it does not mark the # unused op as being "used". 
@@ -124,7 +124,8 @@ class TfShouldUseTest(test.TestCase): @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah3') - with self.test_session(): + + with self.cached_session(): return_const(0.0).mark_used() if __name__ == '__main__': diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py index 01f37d8768..35a74c9664 100644 --- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py +++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py @@ -35,7 +35,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): """ def testArgRenames(self): - with self.test_session(): + with self.cached_session(): a = [[1., 2., 3.], [4., 5., 6.]] b = [[True, False, False], [False, True, True]] @@ -98,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): [[[1, 2]], [[3, 4]]]) def testArgMinMax(self): - with self.test_session(): + with self.cached_session(): self.assertAllEqual( tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(), [0, 2]) @@ -113,7 +113,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): [1, 0, 0]) def testExpandAndSqueeze(self): - with self.test_session(): + with self.cached_session(): # TODO(aselle): sparse_split, sparse_reduce_sum, # sparse_reduce_sum_sparse, reduce_join @@ -140,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): a) def testArithmeticRenames(self): - with self.test_session() as s: + with self.cached_session() as s: stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]]) vals = s.run(stuff) self.assertAllEqual(vals, @@ -164,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): # ] def testBatchAndSvd(self): - with self.test_session(): + with self.cached_session(): mat = [[1., 2.], [2., 3.]] batched_mat = tf.expand_dims(mat, [0]) result = tf.matmul(mat, mat).eval() @@ -176,7 +176,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def testCrossEntropy(self): # TODO(aselle): Test 
sparse_softmax_... - with self.test_session(): + with self.cached_session(): labels = [.8, .5, .2, .1] logits = [.9, .1, .3, .1] self.assertAllEqual( @@ -191,7 +191,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): labels=labels, logits=logits).eval()) def testVariables(self): - with self.test_session() as s: + with self.cached_session() as s: # make some variables _ = [tf.Variable([1, 2, 3], dtype=tf.float32), @@ -201,7 +201,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): _ = [v.name for v in tf.local_variables()] def testSummaries(self): - with self.test_session() as s: + with self.cached_session() as s: var = tf.Variable([1, 2, 3], dtype=tf.float32) s.run(tf.initialize_all_variables()) x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256)) diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py index a49035a1a0..e5ca8d3e2e 100644 --- a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py +++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py @@ -26,7 +26,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): """Test various APIs that have been changed in 2.0.""" def testRenames(self): - with self.test_session(): + with self.cached_session(): self.assertAllClose(1.04719755, tf.acos(0.5).eval()) self.assertAllClose(0.5, tf.rsqrt(4.0).eval()) -- cgit v1.2.3 From f1cc58bb4144de61a693076d8ff8a26b2644ebbb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:36:35 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212336352 --- tensorflow/python/training/adadelta_test.py | 4 +- tensorflow/python/training/adagrad_da_test.py | 10 +-- tensorflow/python/training/adagrad_test.py | 16 ++-- tensorflow/python/training/adam_test.py | 10 +-- .../training/basic_session_run_hooks_test.py | 10 +-- .../python/training/checkpoint_management_test.py | 6 +- tensorflow/python/training/checkpoint_ops_test.py | 18 ++--- .../python/training/checkpoint_utils_test.py | 24 +++--- tensorflow/python/training/ftrl_test.py | 28 +++---- .../python/training/gradient_descent_test.py | 18 ++--- tensorflow/python/training/input_test.py | 94 +++++++++++----------- .../python/training/learning_rate_decay_test.py | 2 +- tensorflow/python/training/momentum_test.py | 14 ++-- .../python/training/monitored_session_test.py | 58 ++++++------- tensorflow/python/training/moving_averages_test.py | 30 +++---- tensorflow/python/training/optimizer_test.py | 8 +- .../python/training/proximal_adagrad_test.py | 18 ++--- .../training/proximal_gradient_descent_test.py | 16 ++-- tensorflow/python/training/queue_runner_test.py | 26 +++--- tensorflow/python/training/rmsprop_test.py | 4 +- tensorflow/python/training/saver_test.py | 54 ++++++------- tensorflow/python/training/session_manager_test.py | 28 +++---- tensorflow/python/training/slot_creator_test.py | 14 ++-- tensorflow/python/training/supervisor_test.py | 6 +- .../python/training/warm_starting_util_test.py | 2 +- 25 files changed, 259 insertions(+), 259 deletions(-) diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py index 2678016d24..a14ac895ac 100644 --- a/tensorflow/python/training/adadelta_test.py +++ b/tensorflow/python/training/adadelta_test.py @@ -155,7 +155,7 @@ class AdadeltaOptimizerTest(test.TestCase): rtol=1e-5) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) @@ 
-168,7 +168,7 @@ class AdadeltaOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py index c3a242a75e..00801be3b4 100644 --- a/tensorflow/python/training/adagrad_da_test.py +++ b/tensorflow/python/training/adagrad_da_test.py @@ -34,7 +34,7 @@ class AdagradDAOptimizerTest(test.TestCase): def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False): for dtype in [dtypes.float64, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable(0, dtype=dtypes.int64) if use_resource: var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -81,7 +81,7 @@ class AdagradDAOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) @@ -101,7 +101,7 @@ class AdagradDAOptimizerTest(test.TestCase): def testAdagradDAwithoutRegularizationBasic2(self): for dtype in [dtypes.float64, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable(0, dtype=dtypes.int64) var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) @@ -133,7 +133,7 @@ class AdagradDAOptimizerTest(test.TestCase): def testAdagradDAWithL1(self): for dtype in [dtypes.float64, dtypes.float32]: - with 
self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable(0, dtype=dtypes.int64) var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) @@ -165,7 +165,7 @@ class AdagradDAOptimizerTest(test.TestCase): def testAdagradDAWithL1_L2(self): for dtype in [dtypes.float64, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable(0, dtype=dtypes.int64) var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py index 4e634fff84..7caf01f64d 100644 --- a/tensorflow/python/training/adagrad_test.py +++ b/tensorflow/python/training/adagrad_test.py @@ -98,7 +98,7 @@ class AdagradOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable( [[1.0, 2.0], [3.0, 4.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) @@ -117,7 +117,7 @@ class AdagradOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -141,7 +141,7 @@ class AdagradOptimizerTest(test.TestCase): def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( @@ -172,7 +172,7 @@ class AdagradOptimizerTest(test.TestCase): def 
testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( @@ -202,7 +202,7 @@ class AdagradOptimizerTest(test.TestCase): def testSparseRepeatedIndicesResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var_repeated = resource_variable_ops.ResourceVariable( [1.0, 2.0], dtype=dtype) loss_repeated = math_ops.reduce_sum( @@ -226,7 +226,7 @@ class AdagradOptimizerTest(test.TestCase): def testSparseStability(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): shape = [1, 6] var0 = variables.Variable( [[ @@ -262,7 +262,7 @@ class AdagradOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -295,7 +295,7 @@ class AdagradOptimizerTest(test.TestCase): np.array([2.715679168701172, 3.715679168701172]), var1.eval()) def testDynamicShapeVariable_Ok(self): - with self.test_session(): + with self.cached_session(): v = variable_scope.get_variable("v", initializer=constant_op.constant(1.), validate_shape=False) self.assertFalse(v.shape.is_fully_defined()) diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index 778c672077..48db6e3733 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -56,7 +56,7 @@ class AdamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with 
self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -122,7 +122,7 @@ class AdamOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( @@ -224,7 +224,7 @@ class AdamOptimizerTest(test.TestCase): opt.get_slot(var=var0, name="m").name) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) @@ -237,7 +237,7 @@ class AdamOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -274,7 +274,7 @@ class AdamOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. 
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index fe8a3e9062..2d469634e0 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -1145,7 +1145,7 @@ class SummarySaverHookTest(test.TestCase): summary_writer=self.summary_writer, summary_op=self.summary_op) - with self.test_session() as sess: + with self.cached_session() as sess: hook.begin() sess.run(variables_lib.global_variables_initializer()) mon_sess = monitored_session._HookedSession(sess, [hook]) @@ -1177,7 +1177,7 @@ class SummarySaverHookTest(test.TestCase): summary_writer=self.summary_writer, summary_op=[self.summary_op, self.summary_op2]) - with self.test_session() as sess: + with self.cached_session() as sess: hook.begin() sess.run(variables_lib.global_variables_initializer()) mon_sess = monitored_session._HookedSession(sess, [hook]) @@ -1205,7 +1205,7 @@ class SummarySaverHookTest(test.TestCase): summary_writer=self.summary_writer, summary_op=self.summary_op) - with self.test_session() as sess: + with self.cached_session() as sess: hook.begin() sess.run(variables_lib.global_variables_initializer()) mon_sess = monitored_session._HookedSession(sess, [hook]) @@ -1240,7 +1240,7 @@ class SummarySaverHookTest(test.TestCase): summary_writer=self.summary_writer, summary_op=self.summary_op) - with self.test_session() as sess: + with self.cached_session() as sess: hook.begin() sess.run(variables_lib.global_variables_initializer()) mon_sess = monitored_session._HookedSession(sess, [hook]) @@ -1388,7 +1388,7 @@ class ResourceSummarySaverHookTest(test.TestCase): summary_writer=self.summary_writer, summary_op=self.summary_op) - with self.test_session() as sess: + with self.cached_session() as sess: hook.begin() 
sess.run(variables_lib.global_variables_initializer()) mon_sess = monitored_session._HookedSession(sess, [hook]) diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py index 8ef5048299..3a061bcb35 100644 --- a/tensorflow/python/training/checkpoint_management_test.py +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -73,7 +73,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase): # Collides with the default name of the checkpoint state file. filepath = os.path.join(traindir, "checkpoint") - with self.test_session() as sess: + with self.cached_session() as sess: unused_a = variables.Variable(0.0) # So that Saver saves something. variables.global_variables_initializer().run() @@ -113,7 +113,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase): filename = "snapshot" filepath = os.path.join(traindir, filename) - with self.test_session() as sess: + with self.cached_session() as sess: # Build a simple graph. v0 = variables.Variable(0.0) inc = v0.assign_add(1.0) @@ -128,7 +128,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase): inc.eval() save.save(sess, filepath, global_step=2) - with self.test_session() as sess: + with self.cached_session() as sess: # Build a new graph with different initialization. v0 = variables.Variable(-1.0) diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py index 00611de862..dde8431497 100644 --- a/tensorflow/python/training/checkpoint_ops_test.py +++ b/tensorflow/python/training/checkpoint_ops_test.py @@ -43,7 +43,7 @@ class LoadAndRemapWrappersTest(test.TestCase): # 0., 1., ..., 79. reshaped into [5, 16]. 
initializer = init_ops.constant_initializer( np.reshape(np.linspace(0.0, 79, 5 * 16), (5, 16))) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope('some_scope'): variable_scope.get_variable(name='embeddings', shape=[5, 16], initializer=initializer) @@ -114,7 +114,7 @@ class LoadAndRemapWrappersTest(test.TestCase): ], axis=1) - with self.test_session(): + with self.cached_session(): self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval()) def test_load_and_remap_output_layer_weight_initializer_linear(self): @@ -150,7 +150,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_matrix, remapped_matrix.as_tensor().eval()) @@ -184,7 +184,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_matrix, remapped_matrix.as_tensor().eval()) @@ -222,7 +222,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_matrix, remapped_matrix.as_tensor().eval()) @@ -258,7 +258,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_matrix, remapped_matrix.as_tensor().eval()) @@ -292,7 +292,7 @@ 
class LoadAndRemapWrappersTest(test.TestCase): initializer=embedding_loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_embeddings, remapped_embeddings.as_tensor().eval()) @@ -338,7 +338,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=embedding_loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_embeddings, remapped_embeddings.as_tensor().eval()) @@ -376,7 +376,7 @@ class LoadAndRemapWrappersTest(test.TestCase): initializer=embedding_loading_initializer, partitioner=partitioned_variables.fixed_size_partitioner(2)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAllClose(expected_remapped_embeddings, remapped_embeddings.as_tensor().eval()) diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index 1aab16338a..61dcbdb2b8 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -84,7 +84,7 @@ class CheckpointsTest(test.TestCase): def testNoTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) with self.assertRaises(errors_impl.OpError): self.assertAllEqual( @@ -92,7 +92,7 @@ class CheckpointsTest(test.TestCase): def testGetTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) self.assertAllEqual( checkpoint_utils.load_variable(checkpoint_dir, 
"var1"), v1) @@ -105,7 +105,7 @@ class CheckpointsTest(test.TestCase): def testGetAllVariables(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _create_checkpoints(session, checkpoint_dir) self.assertEqual( checkpoint_utils.list_variables(checkpoint_dir), @@ -114,7 +114,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -148,7 +148,7 @@ class CheckpointsTest(test.TestCase): def testInitialValueComesFromCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -178,7 +178,7 @@ class CheckpointsTest(test.TestCase): def testInitWithScopeDoesNotCaptureSuffixes(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, v4 = _create_checkpoints(session, checkpoint_dir) with ops.Graph().as_default() as g: @@ -197,7 +197,7 @@ class CheckpointsTest(test.TestCase): def testRestoreRunsOnSameDevice(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _create_checkpoints(session, checkpoint_dir) with ops.Graph().as_default(): @@ -213,7 +213,7 @@ class CheckpointsTest(test.TestCase): def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. 
@@ -237,7 +237,7 @@ class CheckpointsTest(test.TestCase): def testInitToRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -260,7 +260,7 @@ class CheckpointsTest(test.TestCase): def testInitFromPartitionVar(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1 = _create_partition_checkpoints(session, checkpoint_dir) # New graph and session. @@ -322,7 +322,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpointMissing(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -367,7 +367,7 @@ class CheckpointsTest(test.TestCase): def testNoAdditionalReadOpsForResourceVariables(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. 
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py index 76ca5b45c9..09d6fe36d3 100644 --- a/tensorflow/python/training/ftrl_test.py +++ b/tensorflow/python/training/ftrl_test.py @@ -37,7 +37,7 @@ class FtrlOptimizerTest(test.TestCase): def doTestFtrlwithoutRegularization(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: if use_resource: var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -76,7 +76,7 @@ class FtrlOptimizerTest(test.TestCase): def testFtrlwithoutRegularization2(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -105,7 +105,7 @@ class FtrlOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) @@ -121,7 +121,7 @@ class FtrlOptimizerTest(test.TestCase): def testFtrlWithL1(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -150,7 +150,7 @@ class FtrlOptimizerTest(test.TestCase): def testFtrlWithL1_L2(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: var0 = 
variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -186,7 +186,7 @@ class FtrlOptimizerTest(test.TestCase): weights will tend to have smaller magnitudes with this parameter set. """ for dtype in [dtypes.half, dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -335,7 +335,7 @@ class FtrlOptimizerTest(test.TestCase): # FTRL-Proximal performs same updates as Adagrad or GradientDescent. def testEquivAdagradwithoutRegularization(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( ftrl.FtrlOptimizer( 3.0, @@ -346,7 +346,7 @@ class FtrlOptimizerTest(test.TestCase): l2_regularization_strength=0.0), dtype) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype) @@ -355,7 +355,7 @@ class FtrlOptimizerTest(test.TestCase): def testEquivSparseAdagradwithoutRegularization(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( ftrl.FtrlOptimizer( 3.0, @@ -367,7 +367,7 @@ class FtrlOptimizerTest(test.TestCase): dtype, is_sparse=True) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype, @@ -378,7 +378,7 @@ class FtrlOptimizerTest(test.TestCase): def testEquivSparseGradientDescentwithoutRegularization(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( ftrl.FtrlOptimizer( 3.0, @@ -390,7 +390,7 @@ class 
FtrlOptimizerTest(test.TestCase): dtype, is_sparse=True) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0), dtype, @@ -401,7 +401,7 @@ class FtrlOptimizerTest(test.TestCase): def testEquivGradientDescentwithoutRegularization(self): for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( ftrl.FtrlOptimizer( 3.0, @@ -412,7 +412,7 @@ class FtrlOptimizerTest(test.TestCase): l2_regularization_strength=0.0), dtype) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0), dtype) diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py index b304e92421..56d82a5b88 100644 --- a/tensorflow/python/training/gradient_descent_test.py +++ b/tensorflow/python/training/gradient_descent_test.py @@ -37,7 +37,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -60,7 +60,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testBasicResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -85,7 +85,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testBasicCallableParams(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with 
self.cached_session(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -111,7 +111,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testMinimizeResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) @@ -137,7 +137,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) @@ -164,7 +164,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -186,7 +186,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testGradWrtRef(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): opt = gradient_descent.GradientDescentOptimizer(3.0) values = [1.0, 3.0] vars_ = [variables.Variable([v], dtype=dtype) for v in values] @@ -197,7 +197,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testWithGlobalStep(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with 
self.cached_session(): global_step = variables.Variable(0, trainable=False) var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) @@ -220,7 +220,7 @@ class GradientDescentOptimizerTest(test.TestCase): def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 1b1e89cb26..a9b05dcc73 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -51,7 +51,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase): for name in additional: open(name, "w").write("Some contents") filenames = list(set(filenames + additional)) - with self.test_session(): + with self.cached_session(): star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*")) question = inp.match_filenames_once( os.path.join(self.get_temp_dir(), "match_filenames.?")) @@ -66,7 +66,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase): class LimitEpochsTest(test_lib.TestCase): def testNoLimit(self): - with self.test_session(): + with self.cached_session(): seven = constant_op.constant(7) seven_forever = inp.limit_epochs(seven) variables.local_variables_initializer().run() @@ -74,7 +74,7 @@ class LimitEpochsTest(test_lib.TestCase): self.assertEqual(7, seven_forever.eval()) def testLimit(self): - with self.test_session(): + with self.cached_session(): love_me = constant_op.constant("Love Me") love_me_two_times = inp.limit_epochs(love_me, num_epochs=2) variables.global_variables_initializer().run() @@ -88,7 +88,7 @@ class LimitEpochsTest(test_lib.TestCase): class InputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): input_tensor = 
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] @@ -111,7 +111,7 @@ class InputProducerTest(test_lib.TestCase): thread.join() def testNoShapeInference(self): - with self.test_session(): + with self.cached_session(): # Disable shape inference for the input. input_value = [[1, 2, 3, 4], [5, 6, 7, 8], @@ -144,7 +144,7 @@ class InputProducerTest(test_lib.TestCase): class StringInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): strings = [b"to", b"be", b"or", b"not", b"to", b"be"] num_epochs = 3 queue = inp.string_input_producer( @@ -166,7 +166,7 @@ class StringInputProducerTest(test_lib.TestCase): thread.join() def testShuffle(self): - with self.test_session(): + with self.cached_session(): strings = [b"a", b"b", b"c"] num_epochs = 600 queue = inp.string_input_producer( @@ -206,7 +206,7 @@ class StringInputProducerTest(test_lib.TestCase): def testNullStringPython(self): # Graph-construction time check for empty string list: - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): _ = inp.string_input_producer([]) @@ -214,7 +214,7 @@ class StringInputProducerTest(test_lib.TestCase): # Runtime check for empty string list. This is slightly oblique: # The queue runner should die with an assertion error on the null # input tensor, causing the dequeue to fail with an OutOfRangeError. 
- with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() queue = inp.string_input_producer( constant_op.constant( @@ -230,7 +230,7 @@ class StringInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): strings = [b"to", b"be", b"or", b"not", b"to", b"be"] queue = inp.string_input_producer( strings, shared_name="SHARED_NAME_XYZ", name="Q") @@ -238,7 +238,7 @@ class StringInputProducerTest(test_lib.TestCase): queue.queue_ref.op.node_def.attr["shared_name"]) def testConstructionRace(self): - with self.test_session() as sess: + with self.cached_session() as sess: strings = [b"to", b"be", b"or", b"not", b"to", b"be"] queue = inp.string_input_producer(strings, shuffle=False) coord = coordinator.Coordinator() @@ -260,7 +260,7 @@ class StringInputProducerTest(test_lib.TestCase): class RangeInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session(): + with self.cached_session(): num_epochs = 3 range_size = 5 queue = inp.range_input_producer( @@ -282,7 +282,7 @@ class RangeInputProducerTest(test_lib.TestCase): thread.join() def testShuffle(self): - with self.test_session(): + with self.cached_session(): num_epochs = 200 range_size = 2 queue = inp.range_input_producer( @@ -321,7 +321,7 @@ class RangeInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): range_size = 5 queue = inp.range_input_producer( range_size, shared_name="SHARED_NAME_XYZ", name="Q") @@ -332,7 +332,7 @@ class RangeInputProducerTest(test_lib.TestCase): class SliceInputProducerTest(test_lib.TestCase): def testNoShuffle(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_epochs = 3 source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"] source_ints = [2, 3, 5, 7] @@ -356,7 +356,7 @@ class SliceInputProducerTest(test_lib.TestCase): thread.join() def 
testShuffle(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_epochs = 1200 source_strings = ["A", "B", "D", "G"] source_ints = [7, 3, 5, 2] @@ -400,7 +400,7 @@ class SliceInputProducerTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): source_strings = ["A", "B", "D", "G"] source_ints = [7, 3, 5, 2] slices = inp.slice_input_producer( @@ -440,7 +440,7 @@ class DictHelperTest(test_lib.TestCase): class BatchTest(test_lib.TestCase): def _testOneThreadHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -500,7 +500,7 @@ class BatchTest(test_lib.TestCase): def testUint32DataTypes(self): values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32) batched = inp.batch([values], batch_size=2) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) sess.run(batched) @@ -511,7 +511,7 @@ class BatchTest(test_lib.TestCase): def testUint64DataTypes(self): values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64) batched = inp.batch([values], batch_size=2) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) sess.run(batched) @@ -520,7 +520,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadDynamicPad(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -550,7 +550,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadEnqueueMany(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 
10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -585,7 +585,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testManyThreads(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -625,7 +625,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testOneThreadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -682,7 +682,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testManyThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -737,7 +737,7 @@ class BatchTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -754,7 +754,7 @@ class BatchTest(test_lib.TestCase): batched[0].op.inputs[0].op.node_def.attr["shared_name"]) def testCannotInferRankError(self): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtype=dtypes.int64) with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"): inp.batch([x], batch_size=2) @@ -797,7 +797,7 @@ class BatchTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -934,7 +934,7 @@ class BatchTest(test_lib.TestCase): batched = inp.maybe_batch( [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True) - with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -952,7 +952,7 @@ class 
BatchTest(test_lib.TestCase): class BatchJoinTest(test_lib.TestCase): def _testTwoThreadsHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..69, "a"). num_a = 70 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1069,7 +1069,7 @@ class BatchJoinTest(test_lib.TestCase): batch_size=8) def DISABLED_testTwoThreadsDynamicPad(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..69, ["a"] * 1..70). num_a = 70 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1144,7 +1144,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def DISABLED_testTwoThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, "a"). num_a = 70 + extra_elements @@ -1243,7 +1243,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, ["a"] * 1..70). 
num_a = 70 + extra_elements @@ -1338,7 +1338,7 @@ class BatchJoinTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1360,7 +1360,7 @@ class BatchJoinTest(test_lib.TestCase): batched[0].op.inputs[0].op.node_def.attr["shared_name"]) def testCannotInferRankError(self): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtype=dtypes.int64) with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"): inp.batch_join([[x]], batch_size=2) @@ -1371,7 +1371,7 @@ class BatchJoinTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -1511,7 +1511,7 @@ class BatchJoinTest(test_lib.TestCase): batched = inp.maybe_batch_join( [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True) - with self.test_session(): + with self.cached_session(): coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -1529,7 +1529,7 @@ class BatchJoinTest(test_lib.TestCase): class ShuffleBatchTest(test_lib.TestCase): def _testOneThreadHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1594,7 +1594,7 @@ class ShuffleBatchTest(test_lib.TestCase): self._testOneThreadHelper(use_dict=True) def testOneThreadSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -1650,7 +1650,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testManyThreads(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 
10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1697,7 +1697,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testManyThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 10 num_batches = 3 extra_elements = 5 @@ -1755,7 +1755,7 @@ class ShuffleBatchTest(test_lib.TestCase): thread.join() def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -1775,7 +1775,7 @@ class ShuffleBatchTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) @@ -1906,7 +1906,7 @@ class ShuffleBatchTest(test_lib.TestCase): class ShuffleBatchJoinTest(test_lib.TestCase): def _testTwoThreadsHelper(self, use_dict): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..24, "a"). num_a = 25 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -2017,7 +2017,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): self._testTwoThreadsHelper(use_dict=True) def testTwoThreadsSmallerBatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Two threads, the first generates (0..26, "a"). 
extra_elements = 2 num_a = 25 + extra_elements @@ -2137,7 +2137,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): seed=223607) def testSharedName(self): - with self.test_session(): + with self.cached_session(): batch_size = 10 num_batches = 3 zero64 = constant_op.constant(0, dtype=dtypes.int64) @@ -2162,7 +2162,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase): def _testKeepInputHelper(self, num_threads, enqueue_many, keep_input_vector=False): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = 5 num_batches = 4 examples = variables.Variable(0) diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index 4f3cf01822..5a9215730e 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -62,7 +62,7 @@ class LRDecayTest(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) def testVariables(self): - with self.test_session(): + with self.cached_session(): step = variables.Variable(1) assign_1 = step.assign(1) assign_2 = step.assign(2) diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index f7e78071d8..8a21c39d32 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -123,7 +123,7 @@ class MomentumOptimizerTest(test.TestCase): ]), self.evaluate(var1)) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) @@ -162,7 +162,7 @@ class MomentumOptimizerTest(test.TestCase): def testNesterovMomentum(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) var0_np = np.array([1.0, 
2.0], dtype=dtype.as_numpy_dtype) @@ -188,7 +188,7 @@ class MomentumOptimizerTest(test.TestCase): def testSparseNesterovMomentum(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) @@ -282,7 +282,7 @@ class MomentumOptimizerTest(test.TestCase): def testTensorLearningRateAndMomentum(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -435,7 +435,7 @@ class MomentumOptimizerTest(test.TestCase): return db_grad, db_out def testLikeDistBeliefMom01(self): - with self.test_session(): + with self.cached_session(): db_grad, db_out = self._dbParamsMom01() num_samples = len(db_grad) var0 = variables.Variable([0.0] * num_samples) @@ -449,7 +449,7 @@ class MomentumOptimizerTest(test.TestCase): def testSparse(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype)) var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2])) grads0 = ops.IndexedSlices( @@ -518,7 +518,7 @@ class MomentumOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index ff586b6c03..2d7799d66a 100644 --- 
a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -80,7 +80,7 @@ class ScaffoldTest(test.TestCase): self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertItemsEqual([b'my_var', b'my_local_var'], sess.run(scaffold.ready_op)) self.assertItemsEqual([b'my_var'], @@ -513,21 +513,21 @@ class WrappedSessionTest(test.TestCase): """_WrappedSession tests.""" def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) wrapped_sess = monitored_session._WrappedSession(sess) self.assertEquals(sess.graph, wrapped_sess.graph) self.assertEquals(sess.sess_str, wrapped_sess.sess_str) def test_should_stop_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = monitored_session._WrappedSession(sess) self.assertFalse(wrapped_sess.should_stop()) wrapped_sess.close() self.assertTrue(wrapped_sess.should_stop()) def test_should_stop_uses_check_stop(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = StopAtNSession(sess, 3) self.assertFalse(wrapped_sess.should_stop()) self.assertFalse(wrapped_sess.should_stop()) @@ -535,7 +535,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess.should_stop()) def test_should_stop_delegates_to_wrapped_session(self): - with self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess0 = StopAtNSession(sess, 4) wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0) self.assertFalse(wrapped_sess1.should_stop()) @@ -545,7 +545,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess1.should_stop()) def test_close_twice(self): - with 
self.test_session() as sess: + with self.cached_session() as sess: wrapped_sess = monitored_session._WrappedSession(sess) wrapped_sess.close() self.assertTrue(wrapped_sess.should_stop()) @@ -553,7 +553,7 @@ class WrappedSessionTest(test.TestCase): self.assertTrue(wrapped_sess.should_stop()) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) self.assertEqual(42, sess.run(v, feed_dict={c: 42})) @@ -570,7 +570,7 @@ class CoordinatedSessionTest(test.TestCase): """_CoordinatedSession tests.""" def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) @@ -578,7 +578,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertEquals(sess.sess_str, coord_sess.sess_str) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -586,7 +586,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42})) def test_should_stop_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) self.assertFalse(coord_sess.should_stop()) @@ -594,7 +594,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def test_should_stop_on_coord_stop(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() coord_sess = monitored_session._CoordinatedSession(sess, coord) self.assertFalse(coord_sess.should_stop()) @@ -602,7 +602,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def 
test_dont_request_stop_on_exception_in_main_thread(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -616,7 +616,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertFalse(coord_sess.should_stop()) def test_stop_threads_on_close_after_exception(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) coord = coordinator.Coordinator() @@ -646,7 +646,7 @@ class CoordinatedSessionTest(test.TestCase): self.assertTrue(coord_sess.should_stop()) def test_stop_threads_on_close(self): - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = [ threading.Thread( @@ -664,7 +664,7 @@ class CoordinatedSessionTest(test.TestCase): def test_propagates_exception_trace(self): assertion = control_flow_ops.Assert(False, ['This should fail.']) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator(clean_stop_exception_types=()) coord_sess = monitored_session._CoordinatedSession(sess, coord) try: @@ -810,7 +810,7 @@ class RecoverableSessionTest(test.TestCase): return self._sess def test_properties(self): - with self.test_session() as sess: + with self.cached_session() as sess: constant_op.constant(0.0) recoverable_sess = monitored_session._RecoverableSession( self._SessionReturner(sess)) @@ -818,7 +818,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEquals(sess.sess_str, recoverable_sess.sess_str) def test_run(self): - with self.test_session() as sess: + with self.cached_session() as sess: c = constant_op.constant(0) v = array_ops.identity(c) recoverable_sess = monitored_session._RecoverableSession( @@ -826,7 +826,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51})) def test_recovery(self): - with 
self.test_session() as sess: + with self.cached_session() as sess: class StackSessionCreator(object): @@ -872,7 +872,7 @@ class RecoverableSessionTest(test.TestCase): recoverable_sess.run(v, feed_dict={c: -12}) def test_recovery_from_coordinator_exception(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -897,7 +897,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) hook = StopCoordinatorWithException( calls_before_stopping=2, @@ -926,7 +926,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -950,7 +950,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -980,7 +980,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) hook = 
StopCoordinatorWithException( calls_before_stopping=2, @@ -1014,7 +1014,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck_when_run_hooks(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = monitored_session.MonitoredSession( session_creator, @@ -1058,7 +1058,7 @@ class RecoverableSessionTest(test.TestCase): return session def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -1090,7 +1090,7 @@ class RecoverableSessionTest(test.TestCase): self.assertEqual(2, session_creator.number_of_sessions_created) def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -1127,7 +1127,7 @@ class RecoverableSessionTest(test.TestCase): session.close() def test_recovery_from_session_getting_stuck_with_raw_session(self): - with self.test_session() as test_session: + with self.cached_session() as test_session: session_creator = CountingSessionCreator(test_session) session = self.create_raw_session_with_failing_coordinator( session_creator, @@ -2047,7 +2047,7 @@ class MonitoredSessionTest(test.TestCase): return value - with self.test_session() as test_session: + with self.cached_session() as test_session: with monitored_session.MonitoredSession( CountingSessionCreator(test_session)) as session: session.run(variables.global_variables_initializer()) @@ -2110,7 +2110,7 @@ class MonitoredSessionTest(test.TestCase): 
step_context.session.run(graph_side_effect) return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3}) - with self.test_session() as test_session: + with self.cached_session() as test_session: with monitored_session.MonitoredSession( CountingSessionCreator(test_session), hooks=[Hook(self)]) as session: diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index fdb8d795c3..93991d0e14 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -35,7 +35,7 @@ from tensorflow.python.training import saver as saver_lib class MovingAveragesTest(test.TestCase): def testAssignMovingAverageWithoutZeroDebias(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable([10.0, 11.0]) val = constant_op.constant([1.0, 2.0], dtypes.float32) decay = 0.25 @@ -49,7 +49,7 @@ class MovingAveragesTest(test.TestCase): var.eval()) def testAssignMovingAverage(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable([0.0, 0.0]) val = constant_op.constant([1.0, 2.0], dtypes.float32) decay = 0.25 @@ -86,7 +86,7 @@ class MovingAveragesTest(test.TestCase): moving_averages.assign_moving_average(var, 0.0, 0.99) def testWeightedMovingAverage(self): - with self.test_session() as sess: + with self.cached_session() as sess: decay = 0.5 weight = array_ops.placeholder(dtypes.float32, []) val = array_ops.placeholder(dtypes.float32, []) @@ -187,53 +187,53 @@ class ExponentialMovingAverageTest(test.TestCase): self.assertAllClose(expected, avg2.eval()) def testAverageVariablesNoNumUpdates_Scalar(self): - with self.test_session(): + with self.cached_session(): ema = moving_averages.ExponentialMovingAverage(0.25) self._CheckDecay(ema, actual_decay=0.25, dim=1) def testAverageVariablesNoNumUpdates_Scalar_Debias(self): - with self.test_session(): + with self.cached_session(): ema = 
moving_averages.ExponentialMovingAverage(0.25, zero_debias=True) self._CheckDecay(ema, actual_decay=0.25, dim=1) def testAverageVariablesNoNumUpdates_Vector(self): - with self.test_session(): + with self.cached_session(): ema = moving_averages.ExponentialMovingAverage(0.25) self._CheckDecay(ema, actual_decay=0.25, dim=5) def testAverageVariablesNoNumUpdates_Vector_Debias(self): - with self.test_session(): + with self.cached_session(): ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True) self._CheckDecay(ema, actual_decay=0.25, dim=5) def testAverageVariablesNumUpdates_Scalar(self): - with self.test_session(): + with self.cached_session(): # With num_updates 1, the decay applied is 0.1818 ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1) self._CheckDecay(ema, actual_decay=0.181818, dim=1) def testAverageVariablesNumUpdates_Scalar_Debias(self): - with self.test_session(): + with self.cached_session(): # With num_updates 1, the decay applied is 0.1818 ema = moving_averages.ExponentialMovingAverage( 0.25, num_updates=1, zero_debias=True) self._CheckDecay(ema, actual_decay=0.181818, dim=1) def testAverageVariablesNumUpdates_Vector(self): - with self.test_session(): + with self.cached_session(): # With num_updates 1, the decay applied is 0.1818 ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1) self._CheckDecay(ema, actual_decay=0.181818, dim=5) def testAverageVariablesNumUpdates_Vector_Debias(self): - with self.test_session(): + with self.cached_session(): # With num_updates 1, the decay applied is 0.1818 ema = moving_averages.ExponentialMovingAverage( 0.25, num_updates=1, zero_debias=True) self._CheckDecay(ema, actual_decay=0.181818, dim=5) def testAverageVariablesWithControlDeps(self): - with self.test_session() as sess: + with self.cached_session() as sess: v0 = variables.Variable(0, name="v0") add_to_v0 = v0.assign_add(1) v1 = variables.Variable([10.0], name="v1") @@ -276,7 +276,7 @@ class 
ExponentialMovingAverageTest(test.TestCase): self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5) def averageVariablesNamesHelper(self, zero_debias): - with self.test_session(): + with self.cached_session(): v0 = variables.Variable(10.0, name="v0") v1 = variables.Variable(30.0, name="v1") # Add a non-trainable variable. @@ -320,7 +320,7 @@ class ExponentialMovingAverageTest(test.TestCase): def averageVariablesNamesRespectScopeHelper(self, zero_debias): # See discussion on #2740. - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope("scope1"): v0 = variables.Variable(10.0, name="v0") v1 = variables.Variable(30.0, name="v1") @@ -367,7 +367,7 @@ class ExponentialMovingAverageTest(test.TestCase): self.averageVariablesNamesRespectScopeHelper(zero_debias=False) def testSubsetAverageVariablesNames(self): - with self.test_session(): + with self.cached_session(): v0 = variables.Variable(10.0, name="v0") v1 = variables.Variable(30.0, name="v1") # Add a non-trainable variable. 
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index dfe9176bea..7a7d01d50e 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -64,7 +64,7 @@ class OptimizerTest(test.TestCase): def testAggregationMethod(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) cost = 5 * var0 + 3 * var1 @@ -89,7 +89,7 @@ class OptimizerTest(test.TestCase): def testPrecomputedGradient(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) cost = 5 * var0 + 3 * var1 @@ -231,7 +231,7 @@ class OptimizerTest(test.TestCase): sgd_op.apply_gradients(grads_and_vars) def testTrainOp(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([3.0, 4.0]) cost = 5 * var0 + 3 * var1 @@ -244,7 +244,7 @@ class OptimizerTest(test.TestCase): def testConstraint(self): constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.) constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.) 
- with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], constraint=constraint_01) var1 = variables.Variable([3.0, 4.0], diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py index 430c16b351..74e06a5e2e 100644 --- a/tensorflow/python/training/proximal_adagrad_test.py +++ b/tensorflow/python/training/proximal_adagrad_test.py @@ -35,7 +35,7 @@ from tensorflow.python.training import proximal_adagrad class ProximalAdagradOptimizerTest(test.TestCase): def doTestProximalAdagradwithoutRegularization(self, use_resource=False): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([0.0, 0.0]) var1 = variables.Variable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -71,7 +71,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): self.doTestProximalAdagradwithoutRegularization(use_resource=True) def testProximalAdagradwithoutRegularization2(self): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -98,7 +98,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) @@ -114,7 +114,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): [[0, 1]], var0.eval(), atol=0.01) def testProximalAdagradWithL1(self): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -140,7 +140,7 @@ class 
ProximalAdagradOptimizerTest(test.TestCase): self.assertAllClose(np.array([2.959304, 1.029232]), v1_val) def testProximalAdagradWithL1_L2(self): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -206,7 +206,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): return v0_val, v1_val def testEquivAdagradwithoutRegularization(self): - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -214,7 +214,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1)) @@ -223,7 +223,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): self.assertAllClose(val1, val3) def testEquivSparseAdagradwithoutRegularization(self): - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -232,7 +232,7 @@ class ProximalAdagradOptimizerTest(test.TestCase): l2_regularization_strength=0.0), is_sparse=True) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1), diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py index 4e4812fe60..f77f68b234 100644 --- a/tensorflow/python/training/proximal_gradient_descent_test.py +++ b/tensorflow/python/training/proximal_gradient_descent_test.py @@ -36,7 +36,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): def doTestProximalGradientDescentwithoutRegularization( self, use_resource=False): - with self.test_session() as sess: + with 
self.cached_session() as sess: if use_resource: var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) @@ -69,7 +69,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): self.doTestProximalGradientDescentwithoutRegularization(use_resource=True) def testProximalGradientDescentwithoutRegularization2(self): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -94,7 +94,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) @@ -111,7 +111,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): [[-111, -138]], var0.eval(), atol=0.01) def testProximalGradientDescentWithL1_L2(self): - with self.test_session() as sess: + with self.cached_session() as sess: var0 = variables.Variable([1.0, 2.0]) var1 = variables.Variable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -174,7 +174,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): return v0_val, v1_val def testEquivSparseGradientDescentwithoutRegularization(self): - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, @@ -182,7 +182,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase): l2_regularization_strength=0.0), is_sparse=True) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True) @@ -190,14 +190,14 @@ class 
ProximalGradientDescentOptimizerTest(test.TestCase): self.assertAllClose(val1, val3) def testEquivGradientDescentwithoutRegularization(self): - with self.test_session(): + with self.cached_session(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(): + with self.cached_session(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0)) diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py index 900f9706ac..9b9e28af2b 100644 --- a/tensorflow/python/training/queue_runner_test.py +++ b/tensorflow/python/training/queue_runner_test.py @@ -41,7 +41,7 @@ _MockOp = collections.namedtuple("MockOp", ["name"]) class QueueRunnerTest(test.TestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: # CountUpTo will raise OUT_OF_RANGE when it reaches the count. zero64 = constant_op.constant(0, dtype=dtypes.int64) var = variables.Variable(zero64) @@ -61,7 +61,7 @@ class QueueRunnerTest(test.TestCase): self.assertEqual(3, var.eval()) def testTwoOps(self): - with self.test_session() as sess: + with self.cached_session() as sess: # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 
zero64 = constant_op.constant(0, dtype=dtypes.int64) var0 = variables.Variable(zero64) @@ -84,7 +84,7 @@ class QueueRunnerTest(test.TestCase): self.assertEqual(30, var1.eval()) def testExceptionsCaptured(self): - with self.test_session() as sess: + with self.cached_session() as sess: queue = data_flow_ops.FIFOQueue(10, dtypes.float32) qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"), _MockOp("so fail")]) @@ -100,7 +100,7 @@ class QueueRunnerTest(test.TestCase): self.assertTrue("Operation not in the graph" in str(exceptions[1])) def testRealDequeueEnqueue(self): - with self.test_session() as sess: + with self.cached_session() as sess: q0 = data_flow_ops.FIFOQueue(3, dtypes.float32) enqueue0 = q0.enqueue((10.0,)) close0 = q0.close() @@ -128,7 +128,7 @@ class QueueRunnerTest(test.TestCase): dequeue1.eval() def testRespectCoordShouldStop(self): - with self.test_session() as sess: + with self.cached_session() as sess: # CountUpTo will raise OUT_OF_RANGE when it reaches the count. zero64 = constant_op.constant(0, dtype=dtypes.int64) var = variables.Variable(zero64) @@ -152,7 +152,7 @@ class QueueRunnerTest(test.TestCase): self.assertEqual(0, var.eval()) def testRequestStopOnException(self): - with self.test_session() as sess: + with self.cached_session() as sess: queue = data_flow_ops.FIFOQueue(10, dtypes.float32) qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")]) coord = coordinator.Coordinator() @@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase): coord.join() def testGracePeriod(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The enqueue will quickly block. 
queue = data_flow_ops.FIFOQueue(2, dtypes.float32) enqueue = queue.enqueue((10.0,)) @@ -181,7 +181,7 @@ class QueueRunnerTest(test.TestCase): coord.join(stop_grace_period_secs=1.0) def testMultipleSessions(self): - with self.test_session() as sess: + with self.cached_session() as sess: with session.Session() as other_sess: zero64 = constant_op.constant(0, dtype=dtypes.int64) var = variables.Variable(zero64) @@ -196,7 +196,7 @@ class QueueRunnerTest(test.TestCase): self.assertEqual(len(threads), len(other_threads)) def testIgnoreMultiStarts(self): - with self.test_session() as sess: + with self.cached_session() as sess: # CountUpTo will raise OUT_OF_RANGE when it reaches the count. zero64 = constant_op.constant(0, dtype=dtypes.int64) var = variables.Variable(zero64) @@ -212,7 +212,7 @@ class QueueRunnerTest(test.TestCase): self.assertEqual([], new_threads) def testThreads(self): - with self.test_session() as sess: + with self.cached_session() as sess: # CountUpTo will raise OUT_OF_RANGE when it reaches the count. 
zero64 = constant_op.constant(0, dtype=dtypes.int64) var = variables.Variable(zero64) @@ -256,7 +256,7 @@ class QueueRunnerTest(test.TestCase): init_op = variables.global_variables_initializer() qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) queue_runner_impl.add_queue_runner(qr) - with self.test_session() as sess: + with self.cached_session() as sess: init_op.run() threads = queue_runner_impl.start_queue_runners(sess) for t in threads: @@ -273,7 +273,7 @@ class QueueRunnerTest(test.TestCase): init_op = variables.global_variables_initializer() qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) queue_runner_impl.add_queue_runner(qr) - with self.test_session(): + with self.cached_session(): init_op.run() with self.assertRaisesRegexp(TypeError, "tf.Session"): queue_runner_impl.start_queue_runners("NotASession") @@ -286,7 +286,7 @@ class QueueRunnerTest(test.TestCase): init_op = variables.global_variables_initializer() qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) queue_runner_impl.add_queue_runner(qr) - with self.test_session(): + with self.cached_session(): init_op.run() threads = queue_runner_impl.start_queue_runners( monitored_session.MonitoredSession()) diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py index 6043327384..4f5f96e2b4 100644 --- a/tensorflow/python/training/rmsprop_test.py +++ b/tensorflow/python/training/rmsprop_test.py @@ -165,7 +165,7 @@ class RMSPropOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) @@ -187,7 +187,7 @@ class RMSPropOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariableCentered(self): for dtype in [dtypes.float32, 
dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index f5b2a22327..0ac84813c8 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -324,7 +324,7 @@ class SaverTest(test.TestCase): save_relative_paths=True) init_all_op = [variables.global_variables_initializer(), v2_init] - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize all variables sess.run(init_all_op) @@ -349,7 +349,7 @@ class SaverTest(test.TestCase): # Start a second session. In that session the parameter nodes # have not been initialized either. - with self.test_session() as sess: + with self.cached_session() as sess: v0 = variables.Variable(-1.0, name="v0") v1 = variables.Variable(-1.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") @@ -373,7 +373,7 @@ class SaverTest(test.TestCase): v0 = variables.Variable(0, name="v0") filename = b"somerandomfilename" save = saver_module.Saver({"v0": v0}, filename=filename) - with self.test_session() as sess: + with self.cached_session() as sess: tensor = sess.graph.get_tensor_by_name( save.saver_def.filename_tensor_name) self.assertEqual(sess.run(tensor), filename) @@ -381,7 +381,7 @@ class SaverTest(test.TestCase): def testInvalidPath(self): v0 = variables.Variable(0, name="v0") for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): - with self.test_session() as sess: + with self.cached_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) with self.assertRaisesRegexp( ValueError, "The passed save_path is not a valid checkpoint:"): @@ -390,7 +390,7 @@ class SaverTest(test.TestCase): def testInt64(self): save_path = 
os.path.join(self.get_temp_dir(), "int64") - with self.test_session() as sess: + with self.cached_session() as sess: # Build a graph with 1 node, and save and restore for them. v = variables.Variable(np.int64(15), name="v") save = saver_module.Saver({"v": v}, restore_sequentially=True) @@ -401,7 +401,7 @@ class SaverTest(test.TestCase): self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path, val) - with self.test_session() as sess: + with self.cached_session() as sess: v = variables.Variable(np.int64(-1), name="v") save = saver_module.Saver({"v": v}) @@ -559,12 +559,12 @@ class SaverTest(test.TestCase): def testAllowEmpty(self): save_path = os.path.join(self.get_temp_dir(), "allow_empty") - with self.test_session() as sess: + with self.cached_session() as sess: _ = constant_op.constant(1) save = saver_module.Saver(allow_empty=True) val = save.save(sess, save_path) self.assertIsNone(val) - with self.test_session() as sess: + with self.cached_session() as sess: save = saver_module.Saver(allow_empty=True) save.restore(sess, save_path) @@ -740,7 +740,7 @@ class SaverTest(test.TestCase): # save succeeds or fails is implementation dependent. Therefore we allow # both cases. try: - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize all variables sess.run(init_all_op) @@ -751,7 +751,7 @@ class SaverTest(test.TestCase): # Save the graph. save.save(sess, save_path) - with self.test_session() as sess: + with self.cached_session() as sess: # Restore the saved values in the parameter nodes. save.restore(sess, save_path) # Check that the parameter nodes have been restored. 
@@ -775,7 +775,7 @@ class SaverTest(test.TestCase): save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) init_all_op = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize all variables sess.run(init_all_op) @@ -983,7 +983,7 @@ class SaveRestoreShardedTest(test.TestCase): os.path.join(self.get_temp_dir(), "sharded_basics")) def testSaverDef(self): - with self.test_session(): + with self.cached_session(): v0 = variables.Variable(123, name="v0") save = saver_module.Saver({"v0": v0}, sharded=True) sd = save.as_saver_def() @@ -1209,7 +1209,7 @@ class MaxToKeepTest(test.TestCase): def testNonSharded(self): save_dir = self._get_test_dir("max_to_keep_non_sharded") - with self.test_session() as sess: + with self.cached_session() as sess: v = variables.Variable(10.0, name="v") save = saver_module.Saver({"v": v}, max_to_keep=2) variables.global_variables_initializer().run() @@ -1447,7 +1447,7 @@ class MaxToKeepTest(test.TestCase): save_dir = self._get_test_dir("no_max_to_keep") save_dir2 = self._get_test_dir("max_to_keep_0") - with self.test_session() as sess: + with self.cached_session() as sess: v = variables.Variable(10.0, name="v") variables.global_variables_initializer().run() @@ -1474,7 +1474,7 @@ class MaxToKeepTest(test.TestCase): def testNoMetaGraph(self): save_dir = self._get_test_dir("no_meta_graph") - with self.test_session() as sess: + with self.cached_session() as sess: v = variables.Variable(10.0, name="v") save = saver_module.Saver({"v": v}) variables.global_variables_initializer().run() @@ -1497,7 +1497,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): def testNonSharded(self, mock_time): save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") - with self.test_session() as sess: + with self.cached_session() as sess: v = variable_scope.variable([10.0], name="v") # Run the initializer NOW to avoid the 0.5s overhead of the first Run() # call, which 
throws the test timing off in fastbuild mode. @@ -1630,7 +1630,7 @@ class MetaGraphTest(test.TestCase): def testAddCollectionDef(self): test_dir = self._get_test_dir("good_collection") filename = os.path.join(test_dir, "metafile") - with self.test_session(): + with self.cached_session(): # Creates a graph. v0 = variables.Variable(1.0, name="v0") control_flow_ops.cond( @@ -1685,7 +1685,7 @@ class MetaGraphTest(test.TestCase): self, meta_graph_def, new_meta_graph_def) def testAddCollectionDefFails(self): - with self.test_session(): + with self.cached_session(): # Creates a graph. v0 = variables.Variable(10.0, name="v0") # Creates a saver. @@ -1870,7 +1870,7 @@ class MetaGraphTest(test.TestCase): def testSliceVariable(self): test_dir = self._get_test_dir("slice_saver") filename = os.path.join(test_dir, "metafile") - with self.test_session(): + with self.cached_session(): v1 = variables.Variable([20.0], name="v1") v2 = variables.Variable([20.0], name="v2") v2._set_save_slice_info( @@ -1946,7 +1946,7 @@ class MetaGraphTest(test.TestCase): ops_lib.add_to_collection("logits", logits) init_all_op = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: # Initializes all the variables. sess.run(init_all_op) # Runs to logit. @@ -2120,7 +2120,7 @@ class MetaGraphTest(test.TestCase): # pylint: enable=g-long-lambda def testStrippedOpListDef(self): - with self.test_session(): + with self.cached_session(): # Creates a graph. v0 = variables.Variable(0.0) var = variables.Variable(10.0) @@ -2160,7 +2160,7 @@ class MetaGraphTest(test.TestCase): # With strip_default_attrs enabled, attributes "T" (float32) and "Tout" # (complex64) in the "Complex" op must be removed. 
- with self.test_session(): + with self.cached_session(): real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") @@ -2397,7 +2397,7 @@ class CheckpointReaderTest(test.TestCase): }, write_version=self._WRITE_VERSION) save_path = os.path.join(self.get_temp_dir(), "ckpt_for_debug_string" + str(self._WRITE_VERSION)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_all_op) # Saves a checkpoint. save.save(sess, save_path) @@ -2853,7 +2853,7 @@ class CheckpointableCompatibilityTests(test.TestCase): saver = saver_module.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - with self.test_session() as sess: + with self.cached_session() as sess: self.evaluate(v.non_dep_variable.assign(42.)) save_path = saver.save(sess, prefix) self.evaluate(v.non_dep_variable.assign(43.)) @@ -2867,7 +2867,7 @@ class CheckpointableCompatibilityTests(test.TestCase): test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") self.evaluate(v.non_dep_variable.assign(42.)) - with self.test_session() as sess: + with self.cached_session() as sess: save_path = saver.save(sess, prefix) self.evaluate(v.non_dep_variable.assign(43.)) self.evaluate(v.mirrored.assign(44.)) @@ -2900,7 +2900,7 @@ class CheckpointableCompatibilityTests(test.TestCase): saver = saver_module.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - with self.test_session() as sess: + with self.cached_session() as sess: save_path = saver.save(sess, prefix) self.assertEqual(1, v.eval_count) saver.restore(sess, save_path) @@ -2957,7 +2957,7 @@ class CheckpointableCompatibilityTests(test.TestCase): b = resource_variable_ops.ResourceVariable(1., name="b") a_saver = saver_module.Saver([a]) b_saver = saver_module.Saver([b]) - with self.test_session() as sess: + with 
self.cached_session() as sess: sess.run(a.initializer) save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) with self.assertRaisesRegexp( diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index d7e6dac95b..f1d18f7704 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -98,7 +98,7 @@ class SessionManagerTest(test.TestCase): os.rename(checkpoint_dir, checkpoint_dir2) gfile.MakeDirs(checkpoint_dir) v = variables.Variable([6.0, 7.0, 8.0], name="v") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) session_manager.SessionManager( ready_op=variables.report_uninitialized_variables()) @@ -236,7 +236,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -294,7 +294,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -326,7 +326,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( ready_op=variables.report_uninitialized_variables(), @@ -362,7 +362,7 @@ class SessionManagerTest(test.TestCase): trainable=False, 
collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -467,7 +467,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="x") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) self.assertEqual(False, variables.is_variable_initialized(x).eval()) @@ -519,7 +519,7 @@ class SessionManagerTest(test.TestCase): collections=[ops.GraphKeys.LOCAL_VARIABLES], name="x_res") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) self.assertEqual(False, variables.is_variable_initialized(x).eval()) @@ -566,7 +566,7 @@ class SessionManagerTest(test.TestCase): with ops.Graph().as_default(): i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0]) v = variables.Variable(array_ops.identity(i), name="v") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) sm = session_manager.SessionManager( ready_op=variables.report_uninitialized_variables()) @@ -585,7 +585,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -602,7 +602,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], 
name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -619,7 +619,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -640,7 +640,7 @@ class SessionManagerTest(test.TestCase): trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="w") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) self.assertEqual(False, variables.is_variable_initialized(w).eval()) sm2 = session_manager.SessionManager( @@ -714,7 +714,7 @@ class ObsoleteSessionManagerTest(test.TestCase): os.rename(checkpoint_dir, checkpoint_dir2) gfile.MakeDirs(checkpoint_dir) v = variables.Variable([6.0, 7.0, 8.0], name="v") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) session_manager.SessionManager( ready_op=variables.assert_variables_initialized()) @@ -769,7 +769,7 @@ class ObsoleteSessionManagerTest(test.TestCase): # Create a new Graph and SessionManager and recover. 
with ops.Graph().as_default(): v = variables.Variable(2, name="v") - with self.test_session(): + with self.cached_session(): self.assertEqual(False, variables.is_variable_initialized(v).eval()) sm2 = session_manager.SessionManager( ready_op=variables.assert_variables_initialized()) diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py index 08a3c8dc53..6d6364169f 100644 --- a/tensorflow/python/training/slot_creator_test.py +++ b/tensorflow/python/training/slot_creator_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import slot_creator class SlotCreatorTest(test.TestCase): def testCreateSlotFromVariable(self): - with self.test_session(): + with self.cached_session(): v = variables.Variable([1.0, 2.5], name="var") slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") @@ -44,7 +44,7 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([1.0, 2.5], slot.eval()) def testCreateSlotFromTensor(self): - with self.test_session(): + with self.cached_session(): v = constant_op.constant([1.0, 2.5], name="const") slot = slot_creator.create_slot(v, v * 2, name="slot") @@ -56,7 +56,7 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([2.0, 5.0], slot.eval()) def testCreateZerosSlotFromVariable(self): - with self.test_session(): + with self.cached_session(): v = variables.Variable([1.0, 2.5], name="var") with ops.control_dependencies(None): slot = slot_creator.create_zeros_slot( @@ -70,7 +70,7 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([0.0, 0.0], slot.eval()) def testCreateZerosSlotFromDynamicShapedVariable(self): - with self.test_session(): + with self.cached_session(): dyn_shape = constant_op.constant([2], dtype=dtypes.int32) dyn_shape = array_ops.placeholder_with_default(dyn_shape, shape=[None]) @@ -91,7 +91,7 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([0.0, 0.0], slot.eval()) def testCreateZerosSlotFromTensor(self): - with 
self.test_session(): + with self.cached_session(): v = constant_op.constant([1.0, 2.5], name="const") with ops.control_dependencies(None): slot = slot_creator.create_zeros_slot(v, name="slot") @@ -104,7 +104,7 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([0.0, 0.0], slot.eval()) def testCreateZerosSlotFromDynamicShapedTensor(self): - with self.test_session(): + with self.cached_session(): v = random_ops.random_uniform([2], dtype=dtypes.float64) v = array_ops.placeholder_with_default(v, shape=[None], name="const") with ops.control_dependencies(None): @@ -120,7 +120,7 @@ class SlotCreatorTest(test.TestCase): def testCreateSlotFromVariableRespectsScope(self): # See discussion on #2740. - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope("scope"): v = variables.Variable([1.0, 2.5], name="var") slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 71ed88093a..caf6eba3e0 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -795,7 +795,7 @@ class SupervisorTest(test.TestCase): self.assertRaises(StopIteration, lambda: next(rr)) # There should be a checkpoint file with the variable "foo" - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([10.10], name="foo") sav = saver_lib.Saver([v]) sav.restore(sess, save_path) @@ -859,14 +859,14 @@ class SupervisorTest(test.TestCase): self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) self.assertRaises(StopIteration, lambda: next(rr)) # There should be a checkpoint file with the variable "foo" - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v = variables.Variable([-12], name="global_step") sav = 
saver_lib.Saver([v]) sav.restore(sess, save_path) self.assertEqual(123, v.eval()[0]) def testNoQueueRunners(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners")) self.assertEqual(0, len(sv.start_queue_runners(sess))) sv.stop() diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py index 3ee0f6aaa2..6c860cd452 100644 --- a/tensorflow/python/training/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -1133,7 +1133,7 @@ class WarmStartingUtilTest(test.TestCase): # Unused variable names raises ValueError. with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: x = variable_scope.get_variable( "x", shape=[4, 1], -- cgit v1.2.3 From acf0ee82092727afc2067316982407cf5e496f75 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:36:52 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212336417 --- tensorflow/compiler/tests/adam_test.py | 6 +-- tensorflow/compiler/tests/reshape_op_test.py | 2 +- tensorflow/compiler/tests/xla_ops_test.py | 2 +- tensorflow/contrib/autograph/utils/misc_test.py | 4 +- tensorflow/contrib/autograph/utils/py_func_test.py | 8 ++-- .../contrib/autograph/utils/tensor_list_test.py | 8 ++-- .../python/learn/learn_io/data_feeder_test.py | 4 +- .../python/learn/learn_io/generator_io_test.py | 26 +++++----- .../learn/python/learn/learn_io/pandas_io_test.py | 18 +++---- .../ops/sharded_mutable_dense_hashtable_test.py | 6 +-- .../python/ops/sparse_feature_column_test.py | 4 +- .../rnn/python/kernel_tests/core_rnn_test.py | 2 +- .../rnn/python/kernel_tests/fused_rnn_cell_test.py | 4 +- .../rnn/python/kernel_tests/rnn_cell_test.py | 56 +++++++++++----------- tensorflow/python/eager/function_test.py | 28 +++++------ tensorflow/python/eager/graph_only_ops_test.py | 4 +- tensorflow/python/eager/tape_test.py | 4 +- tensorflow/python/keras/layers/gru_test.py | 8 ++-- tensorflow/python/keras/layers/lstm_test.py | 22 ++++----- tensorflow/python/keras/layers/simplernn_test.py | 8 ++-- 20 files changed, 112 insertions(+), 112 deletions(-) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index df0f21471a..058576b3d4 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -98,7 +98,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. 
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -140,7 +140,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py index 84c6777940..96e0b07475 100644 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): ('64_bit_index', dtypes.int64)) def testBasic(self, index_dtype): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): shape = constant_op.constant([3, 2], dtype=index_dtype) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3f928a1bea..0f3843dc1e 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -34,7 +34,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _assertOpOutputMatchesExpected(self, op, args, expected, equality_fn=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py index 
71e358c33e..968ea03df6 100644 --- a/tensorflow/contrib/autograph/utils/misc_test.py +++ b/tensorflow/contrib/autograph/utils/misc_test.py @@ -31,7 +31,7 @@ class MiscTest(test.TestCase): new_a = alias_tensors(a) self.assertFalse(new_a is a) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(1, sess.run(new_a)) def test_alias_tensors(self): @@ -46,7 +46,7 @@ class MiscTest(test.TestCase): self.assertTrue(new_v is v) self.assertTrue(new_s is s) self.assertTrue(new_l is l) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(1, sess.run(new_a)) diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py index 2468263142..f60b57bcce 100644 --- a/tensorflow/contrib/autograph/utils/py_func_test.py +++ b/tensorflow/contrib/autograph/utils/py_func_test.py @@ -31,7 +31,7 @@ class PyFuncTest(test.TestCase): def test_fn(a, b, c): return a + b + c - with self.test_session() as sess: + with self.cached_session() as sess: result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, constant_op.constant(1), 1)) self.assertEqual(3, sess.run(result)) @@ -52,7 +52,7 @@ class PyFuncTest(test.TestCase): def test_fn(a, b): return a * b.foo - with self.test_session() as sess: + with self.cached_session() as sess: result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass())) self.assertEqual(35, sess.run(result)) result = py_func.wrap_py_func(test_fn, dtypes.int64, @@ -69,7 +69,7 @@ class PyFuncTest(test.TestCase): def test_fn(a, b, c, d): return a * b.foo + c * d.foo - with self.test_session() as sess: + with self.cached_session() as sess: result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), { 'c': 11, 'd': TestClass(13) @@ -89,7 +89,7 @@ class PyFuncTest(test.TestCase): def test_fn(_): side_counter[0] += 1 - with self.test_session() as sess: + with self.cached_session() as sess: result = py_func.wrap_py_func(test_fn, None, (5,), 
use_dummy_return=True) self.assertEqual(1, sess.run(result)) self.assertEqual([1], side_counter) diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py index d58489eb68..faaf7b7877 100644 --- a/tensorflow/contrib/autograph/utils/tensor_list_test.py +++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py @@ -42,18 +42,18 @@ class TensorListTest(test.TestCase): l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32) l = tl.dynamic_list_append(l, 1) s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual(sess.run(s), [1]) l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) l = tl.dynamic_list_append(l, 1) s = l.stack() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual(sess.run(s), [1]) l = tl.TensorList(self._shape(()), dtypes.int32) l = tl.dynamic_list_append(l, 1) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual(sess.run(l[0]), 1) def test_list_append_python(self): @@ -107,7 +107,7 @@ class TensorListTest(test.TestCase): l0 = l[0] l[0] = b l1 = l[0] - with self.test_session() as sess: + with self.cached_session() as sess: l0, l1, a, b = sess.run([l0, l1, a, b]) self.assertEqual(l0, a) self.assertEqual(l1, b) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index 5e07b9313f..284a4f45f6 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase): def test_unsupervised(self): def func(feeder): - with self.test_session(): + with self.cached_session(): inp, _ = feeder.input_builder() feed_dict_fn = feeder.get_feed_dict_fn() feed_dict = 
feed_dict_fn() @@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase): def test_epoch(self): def func(feeder): - with self.test_session(): + with self.cached_session(): feeder.input_builder() epoch = feeder.make_epoch_variable() feed_dict_fn = feeder.get_feed_dict_fn() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index 7e81f2b7d9..5e90d1fa20 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase): 'label': np.ones(1) * index - 32 } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key='label', @@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase): for index in range(2): yield {'a': np.ones(1) * index} - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1) features = input_fn() @@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase): 'label2': np.ones(1) * index - 64, } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key=['label', 'label2'], @@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase): 'label': np.ones((3, 3)) * index - 32 } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key='label', @@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase): def testGeneratorInputFnWithXAsNonGeneratorFunction(self): x = np.arange(32, 36) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'x must be generator function'): failing_input_fn = 
generator_io.generator_input_fn( x, batch_size=2, shuffle=False, num_epochs=1) @@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase): def generator(): return np.arange(32, 36) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) @@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase): def generator(): yield np.arange(32, 36) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) @@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase): } y = np.arange(32, 36) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'target_key must be str or' ' Container of str'): failing_input_fn = generator_io.generator_input_fn( @@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase): } y = ['label', np.arange(10)] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'target_key must be str or' ' Container of str'): failing_input_fn = generator_io.generator_input_fn( @@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase): } y = ['label', 'target'] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'): failing_input_fn = generator_io.generator_input_fn( generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1) @@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase): 'label': np.ones(1) * index - 32 } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1) features = input_fn() @@ -283,7 +283,7 @@ class 
GeneratorIoTest(test.TestCase): 'label': np.ones(1) * index - 32 } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1) features = input_fn() @@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase): 'label': np.ones(1) * index - 32 } - with self.test_session() as session: + with self.cached_session() as session: input_fn = generator_io.generator_input_fn( generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1) features = input_fn() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py index c738f0e8f3..396539a76a 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py @@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesExpectedOutputs(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) @@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: index = np.arange(100, 102) a = np.arange(2) b = np.arange(32, 34) @@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: index = np.arange(100, 105) a = np.arange(5) b = np.arange(32, 37) @@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_OnlyX(self): if not HAS_PANDAS: return - with self.test_session() 
as session: + with self.cached_session() as session: x, _ = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y=None, batch_size=2, shuffle=False, num_epochs=1) @@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ExcludesIndex(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) @@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_NoShuffle(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=4, shuffle=False, num_epochs=1) @@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_WithShuffle(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=4, shuffle=True, num_epochs=1) @@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2) @@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase): if not HAS_PANDAS: return x, y = self.makeTestDataFrame() - with self.test_session() as session: + with self.cached_session() as session: input_fn = pandas_io.pandas_input_fn( x, y, batch_size=3, shuffle=False, num_epochs=1) diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py 
index a2d82cf800..553b116a3b 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -30,7 +30,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testShardedMutableHashTable(self): for num_shards in [1, 3, 10]: - with self.test_session(): + with self.cached_session(): default_val = -1 empty_key = 0 keys = constant_op.constant([11, 12, 13], dtypes.int64) @@ -53,7 +53,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testShardedMutableHashTableVectors(self): for num_shards in [1, 3, 10]: - with self.test_session(): + with self.cached_session(): default_val = [-0.1, 0.2] empty_key = [0, 1] keys = constant_op.constant([[11, 12], [13, 14], [15, 16]], @@ -79,7 +79,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): output.eval()) def testExportSharded(self): - with self.test_session(): + with self.cached_session(): empty_key = -2 default_val = -1 num_shards = 2 diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 237a6812b7..51c4f68543 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -36,13 +36,13 @@ class SparseFeatureColumnTest(TensorFlowTestCase): self.assertTrue(isinstance(sfc.example_indices, ops.Tensor)) self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor)) self.assertEqual(sfc.feature_values, None) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_example_indices, sfc.example_indices.eval()) self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval()) expected_feature_values = [1.0, 2.0, 3.0, 4.0] sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0], expected_feature_values) - 
with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_feature_values, sfc.feature_values.eval()) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index aa4562be7c..bf699db3ed 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase): state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) out, state, state_saver = self._factory(scope=None, state_saver=state_saver) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) sess.run(variables_lib.local_variables_initializer()) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py index f2a032e41e..8d34b9e852 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py @@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase): def testBasicRNNFusedWrapper(self): """This test checks that using a wrapper for BasicRNN works as expected.""" - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890212) cell = rnn_cell.BasicRNNCell(10) @@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase): self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) def testTimeReversedFusedRNN(self): - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890213) fw_cell = rnn_cell.BasicRNNCell(10) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 
2df8f0ec05..6689664fb9 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -47,7 +47,7 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): def testCoupledInputForgetGateLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 state_size = num_units * 2 batch_size = 3 @@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], expected_state) def testTimeFreqLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 state_size = num_units * 2 batch_size = 3 @@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 input_size = 4 @@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase): .state_f00_b00_c[i, :]))) > 1e-6) def testGridLSTMCellWithFrequencyBlocks(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 feature_size = 2 @@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase): ]], dtype=np.float32) for state_is_tuple in [False, True]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple" + str(state_is_tuple), initializer=init_ops.constant_initializer(0.5)): @@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCellWithSliceOffset(self): - with self.test_session() 
as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase): input_size = 4 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase): batch_size = 3 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase): 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units) @@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase): 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_proj_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj) @@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase): [[0.13752282, 0.13752282], [0.10545051, 0.10545051], [0.10074195, 0.10074195]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.UGRNNCell(num_units=num_units) @@ -796,7 
+796,7 @@ class RNNCellTest(test.TestCase): [[2.00431061, 2.00431061], [4.00060606, 4.00060606], [6.00008249, 6.00008249]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "intersection_rnn_cell_test", initializer=init_ops.constant_initializer(0.5)): @@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase): cell(inputs, init_state) def testPhasedLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv1DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 1] filter_size = [3] num_features = 1 @@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv2DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 1] filter_size = [3, 3] num_features = 1 @@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv3DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 2, 1] filter_size = [3, 3, 3] num_features = 1 @@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testHighwayWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "base_cell", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase): # Try with input dimension equal to num_units or not. 
for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root1_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase): # Try with num_inputs equal to or not equal to num_units. for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root2_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase): batch_size = 2 num_units = 4 number_of_groups = 2 - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "glstm_failure", initializer=init_ops.constant_initializer(0.5)): gcell = contrib_rnn_cell.GLSTMCell( @@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): # NOTE: all the values in the current test case have been calculated. 
def testBasicLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithoutNorm(self): """Tests that BasicLSTMCell with layer_norm=False.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_h, 1e-5) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithStateTupleLayerNorm(self): """The results of LSTMCell and LayerNormBasicLSTMCell should be the same.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): num_units = 5 allowed_low = [1, 2, 3] - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "other", initializer=init_ops.constant_initializer(1)): x = array_ops.zeros([1, 5]) @@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase): self.assertAllClose(xla_g, non_xla_g, atol=atol) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1583,7 
+1583,7 @@ class WeightNormLSTMCellTest(test.TestCase): def _cell_output(self, cell): """Calculates cell output.""" - with self.test_session() as sess: + with self.cached_session() as sess: init = init_ops.constant_initializer(0.5) with variable_scope.variable_scope("root", initializer=init): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 37a9957cea..92254a2c00 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -104,7 +104,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) def testGraphGradientVariable(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) @function.defun @@ -211,7 +211,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(f(), x) def testSymGradGatherNd(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: @function.defun def f(x): @@ -481,7 +481,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) def testGraphModeCaptureVariable(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: class HasAVar(object): @@ -509,12 +509,12 @@ class FunctionTest(test.TestCase): x = constant_op.constant(1.0) l = f(x, v) _, dv = gradients_impl.gradients(l, [x, v]) - with self.test_session(): + with self.cached_session(): v.initializer.run() self.assertAllEqual(dv.eval(), 0.0) def testGraphModeManyFunctions(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): @function.defun def f(x): @@ -934,7 +934,7 @@ class FunctionTest(test.TestCase): self.assertEqual(1, int(read())) def testReturnCapturedGraphTensor(self): - with context.graph_mode(), self.test_session(): + with 
context.graph_mode(), self.cached_session(): t = constant_op.constant(1) @function.defun @@ -1497,7 +1497,7 @@ class FunctionTest(test.TestCase): class AutomaticControlDependenciesTest(test.TestCase): def testBasic(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() with function.AutomaticControlDependencies() as c: @@ -1508,7 +1508,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(), 4.0) def testCondMustRun(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1529,7 +1529,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) def testCondMustRunSeparateRead(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1552,7 +1552,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(v.read_value().eval(), 6.0) def testCondNested(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1586,7 +1586,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) def testCondOneBranch(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) 
variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1606,7 +1606,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) def testCondOneBranchUpdateBefore(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1627,7 +1627,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) def testCondOneBranchUpdateAfter(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1663,7 +1663,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(out, [3, 4, 5]) def testDecorator(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py index d2a2b4e223..3cf3a61a62 100644 --- a/tensorflow/python/eager/graph_only_ops_test.py +++ b/tensorflow/python/eager/graph_only_ops_test.py @@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase): def testGraphZerosLike(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) z_tf = graph_only_ops.graph_zeros_like(x) - with self.test_session(): + with self.cached_session(): self.assertAllClose(np.zeros((2, 3)), z_tf.eval()) def testGraphPlaceholder(self): x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,)) y_tf = math_ops.square(x_tf) - with self.test_session() as sess: + with 
self.cached_session() as sess: x = np.array([42]) y = sess.run(y_tf, feed_dict={x_tf: np.array([42])}) self.assertAllClose(np.square(x), y) diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index 4326d5efa3..acd0e569f1 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -72,7 +72,7 @@ class TapeTest(test.TestCase): a = constant_op.constant([[1., 0.], [0., 1.]]) b = constant_op.constant([[1., 2.], [3., 4.]]) da, db = backprop.gradients_function(fn, [0, 1])(a, b) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32) tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32) tf_c = tf_a + tf_b @@ -135,7 +135,7 @@ class TapeTest(test.TestCase): a = constant_op.constant([[1., 0.], [0., 1.]]) b = constant_op.constant([[1., 2.], [3., 4.]]) da, db = backprop.gradients_function(fn, [0, 1])(a, b) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32) tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32) tf_mm = math_ops.matmul(tf_a, tf_b) diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py index afef997b00..9988c9fae5 100644 --- a/tensorflow/python/keras/layers/gru_test.py +++ b/tensorflow/python/keras/layers/gru_test.py @@ -87,7 +87,7 @@ class GRULayerTest(test.TestCase): embedding_dim = 4 units = 2 layer_class = keras.layers.GRU - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.Embedding( @@ -146,7 +146,7 @@ class GRULayerTest(test.TestCase): def test_regularizers_GRU(self): embedding_dim = 4 layer_class = keras.layers.GRU - with self.test_session(): + with self.cached_session(): layer = layer_class( 5, return_sequences=False, @@ 
-166,7 +166,7 @@ class GRULayerTest(test.TestCase): def test_constraints_GRU(self): embedding_dim = 4 layer_class = keras.layers.GRU - with self.test_session(): + with self.cached_session(): k_constraint = keras.constraints.max_norm(0.01) r_constraint = keras.constraints.max_norm(0.01) b_constraint = keras.constraints.max_norm(0.01) @@ -186,7 +186,7 @@ class GRULayerTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_GRU(self): layer_class = keras.layers.GRU - with self.test_session(): + with self.cached_session(): inputs = np.random.random((2, 3, 4)) targets = np.abs(np.random.random((2, 3, 5))) targets /= targets.sum(axis=-1, keepdims=True) diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py index 9802820fd0..f536915324 100644 --- a/tensorflow/python/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/layers/lstm_test.py @@ -102,7 +102,7 @@ class LSTMLayerTest(test.TestCase): embedding_dim = 4 units = 2 layer_class = keras.layers.LSTM - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.Embedding( @@ -161,7 +161,7 @@ class LSTMLayerTest(test.TestCase): def test_regularizers_LSTM(self): embedding_dim = 4 layer_class = keras.layers.LSTM - with self.test_session(): + with self.cached_session(): layer = layer_class( 5, return_sequences=False, @@ -180,7 +180,7 @@ class LSTMLayerTest(test.TestCase): def test_constraints_LSTM(self): embedding_dim = 4 layer_class = keras.layers.LSTM - with self.test_session(): + with self.cached_session(): k_constraint = keras.constraints.max_norm(0.01) r_constraint = keras.constraints.max_norm(0.01) b_constraint = keras.constraints.max_norm(0.01) @@ -200,7 +200,7 @@ class LSTMLayerTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_LSTM(self): layer_class = keras.layers.LSTM - with self.test_session(): + with self.cached_session(): inputs = 
np.random.random((2, 3, 4)) targets = np.abs(np.random.random((2, 3, 5))) targets /= targets.sum(axis=-1, keepdims=True) @@ -225,7 +225,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): # Test with Keras tensor inputs = keras.Input((timesteps, embedding_dim)) initial_state = [keras.Input((units,)) for _ in range(num_states)] @@ -252,7 +252,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): # Test with non-Keras tensor inputs = keras.Input((timesteps, embedding_dim)) initial_state = [keras.backend.random_normal_variable( @@ -275,7 +275,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): layer = keras.layers.LSTM(units, stateful=True) layer.build((num_samples, timesteps, embedding_dim)) layer.reset_states() @@ -306,7 +306,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): inputs = keras.Input((timesteps, embedding_dim)) _ = keras.layers.Masking()(inputs) initial_state = [keras.Input((units,)) for _ in range(num_states)] @@ -329,7 +329,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim)) layer = keras.layers.LSTM(units, return_state=True, stateful=True) outputs = layer(inputs) @@ -347,7 +347,7 @@ class LSTMLayerTest(test.TestCase): units = 3 num_samples = 2 - with self.test_session(): + with self.cached_session(): inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim)) layer = keras.layers.LSTM(units, return_state=True, return_sequences=True) outputs = layer(inputs) @@ -366,7 +366,7 @@ class LSTMLayerTest(test.TestCase): num_states = 2 layer_class = keras.layers.LSTM - with self.test_session(): + with 
self.cached_session(): # Test with Keras tensor main_inputs = keras.Input((timesteps, embedding_dim)) initial_state = [keras.Input((units,)) for _ in range(num_states)] diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py index 1429537648..2f2295a793 100644 --- a/tensorflow/python/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/layers/simplernn_test.py @@ -87,7 +87,7 @@ class SimpleRNNLayerTest(test.TestCase): embedding_dim = 4 units = 2 layer_class = keras.layers.SimpleRNN - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.Embedding( @@ -146,7 +146,7 @@ class SimpleRNNLayerTest(test.TestCase): def test_regularizers_SimpleRNN(self): embedding_dim = 4 layer_class = keras.layers.SimpleRNN - with self.test_session(): + with self.cached_session(): layer = layer_class( 5, return_sequences=False, @@ -166,7 +166,7 @@ class SimpleRNNLayerTest(test.TestCase): def test_constraints_SimpleRNN(self): embedding_dim = 4 layer_class = keras.layers.SimpleRNN - with self.test_session(): + with self.cached_session(): k_constraint = keras.constraints.max_norm(0.01) r_constraint = keras.constraints.max_norm(0.01) b_constraint = keras.constraints.max_norm(0.01) @@ -186,7 +186,7 @@ class SimpleRNNLayerTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_SimpleRNN(self): layer_class = keras.layers.SimpleRNN - with self.test_session(): + with self.cached_session(): inputs = np.random.random((2, 3, 4)) targets = np.abs(np.random.random((2, 3, 5))) targets /= targets.sum(axis=-1, keepdims=True) -- cgit v1.2.3 From b828f89263e054bfa7c7a808cab1506834ab906d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:37:06 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). 
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336464 --- .../python/kernel_tests/prediction_ops_test.py | 4 +- .../python/kernel_tests/training_ops_test.py | 8 +- .../python/external_regret_optimizer_test.py | 4 +- .../python/swap_regret_optimizer_test.py | 10 +- .../optimization/latency_all_edges_test.py | 2 +- .../optimization/map_and_filter_fusion_test.py | 4 +- tensorflow/contrib/eager/python/evaluator_test.py | 4 +- tensorflow/contrib/eager/python/metrics_test.py | 4 +- .../python/framework/checkpoint_utils_test.py | 18 +- .../framework/python/framework/tensor_util_test.py | 20 +- .../gan/python/losses/python/losses_impl_test.py | 52 +++--- .../gan/python/losses/python/tuple_losses_test.py | 8 +- .../contrib/learn/python/learn/ops/ops_test.py | 6 +- .../learn/python/learn/ops/seq2seq_ops_test.py | 6 +- tensorflow/contrib/specs/python/specs_test.py | 22 +-- tensorflow/contrib/specs/python/summaries_test.py | 8 +- tensorflow/python/data/util/convert_test.py | 16 +- tensorflow/python/data/util/sparse_test.py | 2 +- .../python/estimator/canned/boosted_trees_test.py | 16 +- tensorflow/python/estimator/canned/head_test.py | 208 ++++++++++----------- .../python/estimator/inputs/numpy_io_test.py | 34 ++-- .../python/estimator/inputs/pandas_io_test.py | 24 +-- .../training/checkpointable/tracking_test.py | 2 +- .../python/training/checkpointable/util_test.py | 2 +- 24 files changed, 242 insertions(+), 242 deletions(-) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 4278a30ba9..46dfbdefeb 100644 --- 
a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -331,7 +331,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([[], []], dropout_info.eval()) def testObliviousEnsemble(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Bias tree. tree1 = tree_ensemble_config.trees.add() @@ -1399,7 +1399,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([0, 0], result.eval()) def testObliviousTreeNonFinalized(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Depth 3 tree. tree1 = tree_ensemble_config.trees.add() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index b3e4c2e5f7..86fd5770a0 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -411,7 +411,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEmptyEnsembleObliviousCase(self): """Test growing an empty ensemble in the oblivious case.""" - with self.test_session() as session: + with self.cached_session() as session: # Create empty ensemble. 
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1620,7 +1620,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleTreeLayerByLayerObliviousCase(self): """Test growing an existing ensemble with the last tree not finalized.""" - with self.test_session() as session: + with self.cached_session() as session: # Create existing ensemble with one root split tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -1810,7 +1810,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesMiddleCase(self): """Test case: The middle existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ @@ -2071,7 +2071,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesBorderCase(self): """Test case: The first and last existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py index 9b4bf62710..3e25079e02 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase): multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) - with self.test_session() as session: + with self.cached_session() as session: 
projected_multipliers1 = session.run( external_regret_optimizer._project_multipliers_wrt_euclidean_norm( multipliers1, 1.0)) @@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase): ] multipliers = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(multipliers) < len(expected_multipliers): multipliers.append(session.run(optimizer.lagrange_multipliers)) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py index 34c4543dca..df0eced631 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase): matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: eigenvector1 = session.run( swap_regret_optimizer._maximal_eigenvector_power_method( standard_ops.constant(matrix1))) @@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( matrix)) @@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( standard_ops.exp( swap_regret_optimizer. 
@@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) @@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py index 1850b6921a..db380c02a9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py @@ -40,7 +40,7 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): get_next = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertEqual(1 * 1, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 6a7ef877f9..dde115925e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -74,7 +74,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) iterator = 
dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(5): result = sess.run(get_next) r = x @@ -131,7 +131,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): def _testMapAndFilter(self, dataset, function, predicate): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(10): r = function(x) if isinstance(r, tuple): diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 7d2274db9b..48d093e075 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"].numpy()) def testDatasetGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) init_op, call_op, results_op = e.evaluate_on_dataset(ds) @@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"]) def testWriteSummariesGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) training_util.get_or_create_global_step() diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index dcc7b71d79..9d2d172752 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -216,7 +216,7 @@ class MetricsTest(test.TestCase): 
self.assertEqual(m1.numer.name, "has_space/numer:0") def testGraphWithPlaceholder(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) accumulate = m(p) @@ -309,7 +309,7 @@ class MetricsTest(test.TestCase): self.assertTrue(old_numer is m.numer) def testMetricsChain(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): m1 = metrics.Mean() m2 = metrics.Mean(name="m2") update_m2 = m2(3.0) diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py index 4f591367fd..77a424145a 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py @@ -82,7 +82,7 @@ class CheckpointsTest(test.TestCase): def testNoTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) with self.assertRaises(errors_impl.OpError): self.assertAllEqual( @@ -90,7 +90,7 @@ class CheckpointsTest(test.TestCase): def testGetTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) self.assertAllEqual( checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1) @@ -103,7 +103,7 @@ class CheckpointsTest(test.TestCase): def testGetAllVariables(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _create_checkpoints(session, checkpoint_dir) self.assertEqual( checkpoint_utils.list_variables(checkpoint_dir), @@ -112,7 +112,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpoint(self): 
checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -146,7 +146,7 @@ class CheckpointsTest(test.TestCase): def testInitWithScopeDoesNotCaptureSuffixes(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, v4 = _create_checkpoints(session, checkpoint_dir) with ops.Graph().as_default() as g: @@ -165,7 +165,7 @@ class CheckpointsTest(test.TestCase): def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -189,7 +189,7 @@ class CheckpointsTest(test.TestCase): def testInitToRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -212,7 +212,7 @@ class CheckpointsTest(test.TestCase): def testInitFromPartitionVar(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1 = _create_partition_checkpoints(session, checkpoint_dir) # New graph and session. @@ -266,7 +266,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpointMissing(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. 
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 2479fe5b8d..b1820c10c8 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class LocalVariabletest(test.TestCase): def test_local_variable(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEquals([], variables_lib.local_variables()) value0 = 42 variables_lib2.local_variable(value0) @@ -55,7 +55,7 @@ class LocalVariabletest(test.TestCase): class ReduceSumNTest(test.TestCase): def test_reduce_sum_n(self): - with self.test_session(): + with self.cached_session(): a = constant_op.constant(1) b = constant_op.constant([2]) c = constant_op.constant([[3, 4], [5, 6]]) @@ -119,13 +119,13 @@ class WithShapeTest(test.TestCase): })) def test_with_shape_invalid_expected_shape(self): - with self.test_session(): + with self.cached_session(): self.assertRaisesRegexp(ValueError, "Invalid rank", tensor_util.with_shape, [[1], [2]], constant_op.constant(1.0)) def test_with_shape_invalid_type(self): - with self.test_session(): + with self.cached_session(): self.assertRaisesRegexp(ValueError, "Invalid dtype", tensor_util.with_shape, [1.1], constant_op.constant([1.0])) @@ -138,7 +138,7 @@ class WithShapeTest(test.TestCase): constant_op.constant(1.0)) def test_with_shape_0(self): - with self.test_session(): + with self.cached_session(): value = 42 shape = [0] unexpected_shapes = [[1], [2], [1, 1]] @@ -150,7 +150,7 @@ class WithShapeTest(test.TestCase): unexpected_shapes) def test_with_shape_1(self): - with self.test_session(): + with self.cached_session(): value = [42] shape = [1] unexpected_shapes = [[0], [2], [1, 1]] @@ -162,7 +162,7 @@ class WithShapeTest(test.TestCase): unexpected_shapes) def test_with_shape_2(self): - with 
self.test_session(): + with self.cached_session(): value = [42, 43] shape = [2] unexpected_shapes = [[0], [1], [2, 1]] @@ -174,7 +174,7 @@ class WithShapeTest(test.TestCase): unexpected_shapes) def test_with_shape_2x2(self): - with self.test_session(): + with self.cached_session(): value = [[42, 43], [44, 45]] shape = [2, 2] unexpected_shapes = [[0], [1], [2, 1]] @@ -196,7 +196,7 @@ class WithShapeTest(test.TestCase): np.testing.assert_array_equal(value, tensor_with_shape.eval()) def test_with_shape_none(self): - with self.test_session(): + with self.cached_session(): tensor_no_shape = array_ops.placeholder(dtypes.float32) compatible_shape = [2, 2] @@ -220,7 +220,7 @@ class WithShapeTest(test.TestCase): @test_util.enable_c_shapes def test_with_shape_partial(self): - with self.test_session(): + with self.cached_session(): tensor_partial_shape = array_ops.placeholder(dtypes.float32) tensor_partial_shape.set_shape([None, 2]) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 9f5fee4542..e3c780ac1a 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -51,7 +51,7 @@ class _LossesTest(object): loss = self._g_loss_fn(self._discriminator_gen_outputs) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) self.assertEqual(self._generator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_all_correct(self): @@ -59,7 +59,7 @@ class _LossesTest(object): self._discriminator_real_outputs, self._discriminator_gen_outputs) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, 
loss.eval(), 5) def test_generator_loss_collection(self): @@ -90,7 +90,7 @@ class _LossesTest(object): loss = self._g_loss_fn( array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_patch(self): @@ -98,7 +98,7 @@ class _LossesTest(object): array_ops.reshape(self._discriminator_real_outputs, [2, 2]), array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_with_placeholder_for_logits(self): @@ -108,7 +108,7 @@ class _LossesTest(object): loss = self._g_loss_fn(logits, weights=weights) self.assertEqual(logits.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: [[10.0, 4.4, -5.5, 3.6]], @@ -125,7 +125,7 @@ class _LossesTest(object): logits, logits2, real_weights=real_weights, generated_weights=generated_weights) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run(loss, feed_dict={ logits: [self._discriminator_real_outputs_np], @@ -136,7 +136,7 @@ class _LossesTest(object): def test_generator_with_python_scalar_weight(self): loss = self._g_loss_fn( self._discriminator_gen_outputs, weights=self._weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -144,14 +144,14 @@ class _LossesTest(object): loss = self._d_loss_fn( self._discriminator_real_outputs, self._discriminator_gen_outputs, real_weights=self._weights, generated_weights=self._weights) - with self.test_session(): + with self.cached_session(): 
self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) def test_generator_with_scalar_tensor_weight(self): loss = self._g_loss_fn(self._discriminator_gen_outputs, weights=constant_op.constant(self._weights)) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -160,7 +160,7 @@ class _LossesTest(object): loss = self._d_loss_fn( self._discriminator_real_outputs, self._discriminator_gen_outputs, real_weights=weights, generated_weights=weights) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) @@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase): self.assertEqual( self._discriminator_gen_classification_logits.dtype, loss.dtype) self.assertEqual(self._generator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_all_correct(self): @@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase): self.assertEqual( self._discriminator_gen_classification_logits.dtype, loss.dtype) self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) def test_generator_loss_collection(self): @@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase): patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in self._generator_kwargs.items()} loss = self._g_loss_fn(**patch_args) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) def test_discriminator_patch(self): patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in self._discriminator_kwargs.items()} loss = self._d_loss_fn(**patch_args) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss, 
loss.eval(), 5) def test_generator_loss_with_placeholder_for_logits(self): @@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase): one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) loss = self._g_loss_fn(gen_logits, one_hot_labels) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run( loss, feed_dict={ gen_logits: self._discriminator_gen_classification_logits_np, @@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase): loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels) - with self.test_session() as sess: + with self.cached_session() as sess: loss = sess.run( loss, feed_dict={ gen_logits: self._discriminator_gen_classification_logits_np, @@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase): def test_generator_with_python_scalar_weight(self): loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase): loss = self._d_loss_fn( real_weights=self._weights, generated_weights=self._weights, **self._discriminator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) def test_generator_with_scalar_tensor_weight(self): loss = self._g_loss_fn( weights=constant_op.constant(self._weights), **self._generator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_g_loss * self._weights, loss.eval(), 4) @@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase): weights = constant_op.constant(self._weights) loss = self._d_loss_fn(real_weights=weights, generated_weights=weights, **self._discriminator_kwargs) - with self.test_session(): + with self.cached_session(): self.assertAlmostEqual(self._expected_d_loss * self._weights, loss.eval(), 4) @@ -404,7 +404,7 @@ class 
_PenaltyTest(object): loss = self._penalty_fn(**self._kwargs) self.assertEqual(self._expected_dtype, loss.dtype) self.assertEqual(self._expected_op_name, loss.op.name) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) @@ -419,13 +419,13 @@ class _PenaltyTest(object): def test_python_scalar_weight(self): loss = self._penalty_fn(weights=2.3, **self._kwargs) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) def test_scalar_tensor_weight(self): loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) @@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self._kwargs['discriminator_scope']) self.assertEqual(generated_data.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run(loss, feed_dict={ @@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): one_sided=True) self.assertEqual(generated_data.dtype, loss.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run(loss, feed_dict={ @@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self._kwargs['discriminator_scope'], target=2.0) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() loss = sess.run( loss, diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index a559bbfa11..25d74a8c23 100644 --- 
a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args): def consistency_test(self): self.assertEqual(arg_loss.__name__, tuple_loss.__name__) - with self.test_session(): + with self.cached_session(): self.assertEqual(arg_loss(**loss_args).eval(), tuple_loss(_tuple_from_dict(loss_args)).eval()) @@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase): self.discriminator_generated_data_source_predication) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) @@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase): discriminator_scope=self.discriminator_scope) wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) loss_result, wrapped_loss_result = sess.run( [loss_result_tensor, wrapped_loss_result_tensor]) diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py index 80d4923db3..ff190110c1 100644 --- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py +++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py @@ -33,7 +33,7 @@ class OpsTest(test.TestCase): """Ops tests.""" def 
test_softmax_classifier(self): - with self.test_session() as session: + with self.cached_session() as session: features = array_ops.placeholder(dtypes.float32, [None, 3]) labels = array_ops.placeholder(dtypes.float32, [None, 2]) weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]]) @@ -52,7 +52,7 @@ class OpsTest(test.TestCase): ids_shape = (2, 3, 4) embeds = np.random.randn(n_embed, d_embed) ids = np.random.randint(0, n_embed, ids_shape) - with self.test_session(): + with self.cached_session(): embed_np = embeds[ids] embed_tf = ops.embedding_lookup(embeds, ids).eval() self.assertEqual(embed_np.shape, embed_tf.shape) @@ -60,7 +60,7 @@ class OpsTest(test.TestCase): def test_categorical_variable(self): random_seed.set_random_seed(42) - with self.test_session() as sess: + with self.cached_session() as sess: cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2]) embeddings = ops.categorical_variable( cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var") diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py index 95aec61955..5a7e4ebfea 100644 --- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py +++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py @@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase): """Sequence-to-sequence tests.""" def test_sequence_classifier(self): - with self.test_session() as session: + with self.cached_session() as session: decoding = [ array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3) ] @@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase): def test_seq2seq_inputs(self): inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]]) out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]]) - with self.test_session() as session: + with self.cached_session() as session: x = array_ops.placeholder(dtypes.float32, [2, 3, 2]) y = array_ops.placeholder(dtypes.float32, [2, 2, 3]) in_x, 
in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2) @@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase): [[0, 0, 0], [0, 0, 0]]]) def test_rnn_decoder(self): - with self.test_session(): + with self.cached_session(): decoder_inputs = [ array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3) ] diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py index 9a4ad36793..b7ce6aa20a 100644 --- a/tensorflow/contrib/specs/python/specs_test.py +++ b/tensorflow/contrib/specs/python/specs_test.py @@ -38,7 +38,7 @@ def _rand(*size): class SpecsTest(test.TestCase): def testSimpleConv(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 18, 19, 5)) spec = "net = Cr(64, [5, 5])" outputs = specs.create_net(spec, inputs) @@ -53,7 +53,7 @@ class SpecsTest(test.TestCase): def testUnary(self): # This is just a quick and dirty check that these ops exist # and work as unary ops. - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(17, 55)) spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax" outputs = specs.create_net(spec, inputs) @@ -63,7 +63,7 @@ class SpecsTest(test.TestCase): self.assertEqual(tuple(result.shape), (17, 55)) def testAdd(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(17, 55)) spec = "net = Fs(10) + Fr(10)" outputs = specs.create_net(spec, inputs) @@ -77,7 +77,7 @@ class SpecsTest(test.TestCase): "<> variablev2 dot variablev2 biasadd relu add") def testMpPower(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 64, 64, 5)) spec = "M2 = Mp([2, 2]); net = M2**3" outputs = specs.create_net(spec, inputs) @@ -90,7 +90,7 @@ class SpecsTest(test.TestCase): "_ maxpool maxpool maxpool") def testAbbrevPower(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 64, 64, 5)) 
spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3" outputs = specs.create_net(spec, inputs) @@ -106,7 +106,7 @@ class SpecsTest(test.TestCase): " biasadd relu maxpool") def testAbbrevPower2(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 64, 64, 5)) spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);" spec += "net = (C3(_0=5) | M2)**3" @@ -123,7 +123,7 @@ class SpecsTest(test.TestCase): " maxpool") def testConc(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(10, 20)) spec = "net = Conc(1, Fs(20), Fs(10))" outputs = specs.create_net(spec, inputs) @@ -137,7 +137,7 @@ class SpecsTest(test.TestCase): "<> variablev2 dot variablev2 biasadd sig _ concatv2") def testImport(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(10, 20)) spec = ("S = Import('from tensorflow.python.ops" + " import math_ops; f = math_ops.sigmoid')") @@ -150,7 +150,7 @@ class SpecsTest(test.TestCase): self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig") def testKeywordRestriction(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(10, 20)) spec = "import re; net = Conc(1, Fs(20), Fs(10))" self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs)) @@ -179,7 +179,7 @@ class SpecsTest(test.TestCase): # XXX: the cleverness of this code is over 9000 # TODO: original author please fix def DISABLED_testVar(self): - with self.test_session() as sess: + with self.cached_session() as sess: with specs.ops: # pylint: disable=undefined-variable v = Var("test_var", @@ -196,7 +196,7 @@ class SpecsTest(test.TestCase): # XXX: the cleverness of this code is over 9000 # TODO: original author please fix def DISABLED_testShared(self): - with self.test_session(): + with self.cached_session(): with specs.ops: # pylint: disable=undefined-variable f = Shared(Fr(100)) diff --git 
a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py index 34ff4bc8ca..b82ba06d3f 100644 --- a/tensorflow/contrib/specs/python/summaries_test.py +++ b/tensorflow/contrib/specs/python/summaries_test.py @@ -34,7 +34,7 @@ def _rand(*size): class SummariesTest(test.TestCase): def testStructure(self): - with self.test_session(): + with self.cached_session(): inputs_shape = (1, 18, 19, 5) inputs = constant_op.constant(_rand(*inputs_shape)) spec = "net = Cr(64, [5, 5])" @@ -48,7 +48,7 @@ class SummariesTest(test.TestCase): "_ variablev2 conv variablev2 biasadd relu") def testStructureFromTensor(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 18, 19, 5)) spec = "net = Cr(64, [5, 5])" outputs = specs.create_net(spec, inputs) @@ -60,7 +60,7 @@ class SummariesTest(test.TestCase): "_ variablev2 conv variablev2 biasadd relu") def testPrint(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 18, 19, 5)) spec = "net = Cr(64, [5, 5])" outputs = specs.create_net(spec, inputs) @@ -70,7 +70,7 @@ class SummariesTest(test.TestCase): summaries.tf_spec_print(spec, inputs) def testSummary(self): - with self.test_session(): + with self.cached_session(): inputs = constant_op.constant(_rand(1, 18, 19, 5)) spec = "net = Cr(64, [5, 5])" outputs = specs.create_net(spec, inputs) diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py index 6a67093e48..89c3afb296 100644 --- a/tensorflow/python/data/util/convert_test.py +++ b/tensorflow/python/data/util/convert_test.py @@ -30,28 +30,28 @@ class ConvertTest(test.TestCase): def testInteger(self): resp = convert.optional_param_to_tensor("foo", 3) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(resp)) def testIntegerDefault(self): resp = convert.optional_param_to_tensor("foo", None) - with 
self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(resp)) def testStringDefault(self): resp = convert.optional_param_to_tensor("bar", None, "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("default"), sess.run(resp)) def testString(self): resp = convert.optional_param_to_tensor("bar", "value", "default", dtypes.string) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(compat.as_bytes("value"), sess.run(resp)) def testPartialShapeToTensorKnownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([1])))) self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,)))) @@ -60,7 +60,7 @@ class ConvertTest(test.TestCase): constant_op.constant([1], dtype=dtypes.int64)))) def testPartialShapeToTensorUnknownDimension(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([None])))) self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor( @@ -84,7 +84,7 @@ class ConvertTest(test.TestCase): convert.partial_shape_to_tensor(constant_op.constant([1., 1.])) def testPartialShapeToTensorMultipleDimensions(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( tensor_shape.TensorShape([3, 6])))) self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor( @@ -113,7 +113,7 @@ class ConvertTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64)))) def testPartialShapeToTensorScalar(self): - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor( 
tensor_shape.TensorShape([])))) self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(()))) diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py index d49b3ff34b..056b32480f 100644 --- a/tensorflow/python/data/util/sparse_test.py +++ b/tensorflow/python/data/util/sparse_test.py @@ -291,7 +291,7 @@ class SparseTest(test.TestCase): self.assertEqual(a, b) return self.assertTrue(isinstance(b, sparse_tensor.SparseTensor)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(a.eval().indices, b.eval().indices) self.assertAllEqual(a.eval().values, b.eval().values) self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape) diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 08026a93c5..6e28c72151 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -1560,7 +1560,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third = ( self._get_expected_ensembles_for_classification()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train with train_in_memory mode. 
with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1593,7 +1593,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): expected_first, expected_second, expected_third, expected_forth = ( self._get_expected_ensembles_for_classification_with_bias()) - with self.test_session() as sess: + with self.cached_session() as sess: with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( boosted_trees._create_classification_head(n_classes=2), @@ -1633,7 +1633,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third = ( self._get_expected_ensembles_for_classification()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train without train_in_memory mode. with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1666,7 +1666,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): expected_first, expected_second, expected_third, expected_forth = ( self._get_expected_ensembles_for_classification_with_bias()) - with self.test_session() as sess: + with self.cached_session() as sess: with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( boosted_trees._create_classification_head(n_classes=2), @@ -1704,7 +1704,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third = ( self._get_expected_ensembles_for_regression()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train with train_in_memory mode. 
with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1734,7 +1734,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third, expected_forth = ( self._get_expected_ensembles_for_regression_with_bias()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train with train_in_memory mode. with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1774,7 +1774,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third = ( self._get_expected_ensembles_for_regression()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train without train_in_memory mode. with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( @@ -1804,7 +1804,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): ops.reset_default_graph() expected_first, expected_second, expected_third, expected_forth = ( self._get_expected_ensembles_for_regression_with_bias()) - with self.test_session() as sess: + with self.cached_session() as sess: # Train with train_in_memory mode. 
with sess.graph.as_default(): train_op, ensemble_serialized = self._get_train_op_and_ensemble( diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index bd2e0ae943..de9c84d2ef 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -260,7 +260,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): features={'x': np.array(((30.,), (42.,),))}, mode=model_fn.ModeKeys.PREDICT, logits=logits_placeholder) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors.OpError, 'logits shape'): spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({ logits_placeholder: logits_2x2 @@ -293,7 +293,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'): @@ -347,14 +347,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError('Labels must <= n_classes - 1'): training_loss.eval({ labels_placeholder: labels_2x1_with_large_id, logits_placeholder: logits_2x3 }) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError('Labels must >= 0'): training_loss.eval({ labels_placeholder: labels_2x1_with_negative_id, @@ -413,7 +413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, 
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'): @@ -449,7 +449,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): spec.export_outputs.keys()) # Assert predictions and export_outputs. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) predictions = sess.run(spec.predictions) @@ -484,7 +484,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.PREDICT, logits=logits) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertAllEqual( expected_classes, @@ -510,7 +510,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.PREDICT, logits=logits) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) predictions = sess.run(spec.predictions) self.assertAllClose(logits, @@ -534,7 +534,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -561,7 +561,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_input, labels=labels_input)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(np.sum(loss), actual_training_loss.eval()) @@ -581,7 +581,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with 
self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -632,7 +632,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, and metrics. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -698,7 +698,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, and metrics. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -727,7 +727,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -755,7 +755,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): } tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} @@ -804,7 +804,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert loss, and metrics. 
tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -837,7 +837,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels) tol = 1e-2 - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) @@ -866,7 +866,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels) tol = 1e-2 - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) @@ -921,7 +921,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, train_op, and summaries. 
tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -962,7 +962,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): optimizer=_Optimizer()) tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) @@ -992,7 +992,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): labels=np.array(((1,), (1,)), dtype=np.int64), train_op_fn=_train_op_fn) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) sess.run(spec.train_op) w_value, t_value = sess.run([w, t]) @@ -1023,7 +1023,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert summaries. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) summary_str = sess.run(spec.scaffold.summary_op) @@ -1064,7 +1064,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, train_op, and summaries. 
tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -1104,7 +1104,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels_rank_1) tol = 1e-2 - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) @@ -1153,7 +1153,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, train_op, and summaries. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -1183,7 +1183,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -1211,7 +1211,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): train_op_fn=_train_op_fn) tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss = sess.run(spec.loss) self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) @@ -1253,7 +1253,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, train_op, and summaries. 
tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -1292,7 +1292,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels) tol = 1e-2 - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) @@ -1327,7 +1327,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, train_op, and summaries. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) @@ -1353,7 +1353,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -1380,7 +1380,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -1413,7 +1413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): # Assert predictions, loss, and metrics. 
tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} @@ -1506,7 +1506,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): features={'x': np.array(((42.,),))}, mode=model_fn.ModeKeys.PREDICT, logits=logits_placeholder) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors.OpError, 'logits shape'): spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({ logits_placeholder: logits_2x2 @@ -1536,7 +1536,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'): @@ -1577,7 +1577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[3 1\] \[labels_shape: \] \[2 1\]'): @@ -1585,7 +1585,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): labels_placeholder: values_2x1, logits_placeholder: values_3x1 }) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'): @@ -1624,7 +1624,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) predictions = sess.run(spec.predictions) @@ -1660,7 +1660,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.PREDICT, logits=logits) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertAllEqual( expected_classes, @@ -1680,7 +1680,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -1733,7 +1733,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -1808,7 +1808,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): } # Assert predictions, loss, and metrics. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -1832,7 +1832,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(41., training_loss.eval()) @@ -1849,7 +1849,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=labels) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -1877,7 +1877,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -1924,7 +1924,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): } self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -1957,7 +1957,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) 
self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -1983,7 +1983,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -2011,7 +2011,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_input, labels=labels_input)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(np.sum(loss), actual_training_loss.eval()) @@ -2031,7 +2031,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -2086,7 +2086,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, train_op, and summaries. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -2126,7 +2126,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): labels=labels, optimizer=_Optimizer()) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAllClose(expected_loss, loss) @@ -2153,7 +2153,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): labels=np.array(((1,), (1,),), dtype=np.float64), train_op_fn=_train_op_fn) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) sess.run(spec.train_op) w_value, t_value = sess.run([w, t]) @@ -2182,7 +2182,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): labels=labels, train_op_fn=_train_op_fn) # Assert summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) summary_str = sess.run(spec.scaffold.summary_op) @@ -2227,7 +2227,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): regularization_losses=regularization_losses) # Assert predictions, loss, train_op, and summaries. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, @@ -2254,7 +2254,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'Labels must <= n_classes - 1'): - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) training_loss.eval() @@ -2277,7 +2277,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -2309,7 +2309,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): train_op_fn=_train_op_fn) # Assert predictions, loss, train_op, and summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAlmostEqual(expected_loss, loss, delta=1.e-5) @@ -2334,7 +2334,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) @@ -2360,7 +2360,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): expected_loss = 1.2484322 # Assert loss. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} @@ -2385,7 +2385,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) predictions = sess.run(spec.predictions) self.assertAllClose( @@ -2447,7 +2447,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} @@ -2483,7 +2483,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels_rank_1) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), @@ -2531,7 +2531,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertIsNotNone(spec.train_op) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run(( @@ -2577,7 +2577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertIsNotNone(spec.train_op) # Assert predictions, loss, and metrics. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) loss, train_result, summary_str = sess.run(( @@ -2612,7 +2612,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=labels) tol = 1e-2 - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( expected_training_loss, training_loss.eval(), @@ -2649,7 +2649,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): # Assert predictions, loss, train_op, and summaries. tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) @@ -2675,7 +2675,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -2700,7 +2700,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -2744,7 +2744,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): } tol = 1e-2 - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} @@ -2825,7 +2825,7 @@ class 
RegressionHead(test.TestCase): features={'x': np.array(((42.,),))}, mode=model_fn.ModeKeys.PREDICT, logits=logits_placeholder) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors.OpError, 'logits shape'): spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval({ logits_placeholder: logits_1d @@ -2857,7 +2857,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors.OpError, 'logits shape'): spec.loss.eval({ labels_placeholder: values_3d, @@ -2868,7 +2868,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'): @@ -2908,7 +2908,7 @@ class RegressionHead(test.TestCase): logits=logits_placeholder, labels=labels_placeholder, train_op_fn=lambda x: x) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(errors.OpError, 'logits shape'): spec.loss.eval({ labels_placeholder: values_3d, @@ -2919,7 +2919,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits_placeholder, labels=labels_placeholder)[0] - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'): @@ -2957,7 +2957,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions. 
- with self.test_session(): + with self.cached_session(): _initialize_variables(self, spec.scaffold) self.assertAllClose(logits, spec.predictions[prediction_key].eval()) self.assertAllClose( @@ -2992,7 +2992,7 @@ class RegressionHead(test.TestCase): spec.export_outputs.keys()) # Assert predictions. - with self.test_session(): + with self.cached_session(): _initialize_variables(self, spec.scaffold) self.assertAllClose( expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) @@ -3019,7 +3019,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) # loss = [(43-45)^2, (44-41)] = [4, 9] self.assertAllClose(13., training_loss.eval()) @@ -3045,7 +3045,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits_input, labels=labels_input)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(np.sum(loss), actual_training_loss.eval()) @@ -3064,7 +3064,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -3112,7 +3112,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[ @@ -3180,7 +3180,7 @@ class RegressionHead(test.TestCase): } # Assert predictions, loss, and metrics. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} @@ -3212,7 +3212,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -3237,7 +3237,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -3294,7 +3294,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, train_op, and summaries. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) predictions, loss, train_result, summary_str = sess.run(( @@ -3337,7 +3337,7 @@ class RegressionHead(test.TestCase): labels=labels, optimizer=_Optimizer()) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss, train_result = sess.run((spec.loss, spec.train_op)) self.assertAllClose(expected_loss, loss) @@ -3364,7 +3364,7 @@ class RegressionHead(test.TestCase): labels=np.array(((43.,), (44.,),), dtype=np.float64), train_op_fn=_train_op_fn) - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) sess.run(spec.train_op) w_value, t_value = sess.run([w, t]) @@ -3394,7 +3394,7 @@ class RegressionHead(test.TestCase): train_op_fn=_train_op_fn) # Assert summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) summary_str = sess.run(spec.scaffold.summary_op) @@ -3441,7 +3441,7 @@ class RegressionHead(test.TestCase): regularization_losses=regularization_losses) # Assert predictions, loss, train_op, and summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) prediction_key = prediction_keys.PredictionKeys.PREDICTIONS @@ -3487,7 +3487,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, and metrics. 
- with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[ @@ -3523,7 +3523,7 @@ class RegressionHead(test.TestCase): labels=np.array(((35,), (42,), (45,)), dtype=np.int32)) # Assert loss. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) loss = sess.run(spec.loss) # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6 @@ -3565,7 +3565,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, train_op, and summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) predictions, loss, train_result, summary_str = sess.run(( @@ -3600,7 +3600,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels_rank_1) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -3648,7 +3648,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, train_op, and summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) predictions, loss, train_result, summary_str = sess.run(( @@ -3679,7 +3679,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1]. 
# weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6 @@ -3718,7 +3718,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Assert predictions, loss, and metrics. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[ @@ -3750,7 +3750,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels)[0] - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1]. # weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6 @@ -3796,7 +3796,7 @@ class RegressionHead(test.TestCase): _assert_no_hooks(self, spec) # Evaluate predictions, loss, train_op, and summaries. - with self.test_session() as sess: + with self.cached_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) predictions, loss, train_result, summary_str = sess.run(( @@ -3857,7 +3857,7 @@ class RegressionHead(test.TestCase): self.assertIsNone(spec.train_op) _assert_no_hooks(self, spec) - with self.test_session() as sess: + with self.cached_session() as sess: # Finalize graph and initialize variables. _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) @@ -3915,7 +3915,7 @@ class RegressionHead(test.TestCase): self.assertEqual(dtypes.float32, spec.loss.dtype) self.assertIsNotNone(spec.train_op) - with self.test_session() as sess: + with self.cached_session() as sess: # Finalize graph and initialize variables. 
_initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) @@ -3955,7 +3955,7 @@ class RegressionHead(test.TestCase): mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_training_loss, training_loss.eval()) self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) @@ -3988,7 +3988,7 @@ class RegressionHead(test.TestCase): logits=logits, labels=labels, train_op_fn=_train_op_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose(expected_loss, spec.loss.eval()) @@ -4013,7 +4013,7 @@ class RegressionHead(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -4042,7 +4042,7 @@ class RegressionHead(test.TestCase): logits=logits, labels=labels, train_op_fn=_no_op_train_fn) - with self.test_session(): + with self.cached_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py index 4e7b00b307..632908415f 100644 --- a/tensorflow/python/estimator/inputs/numpy_io_test.py +++ b/tensorflow/python/estimator/inputs/numpy_io_test.py @@ -42,7 +42,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -28) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) features, target = input_fn() @@ -68,7 +68,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -30) - with 
self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=128, shuffle=False, num_epochs=2) features, target = input_fn() @@ -93,7 +93,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -28) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=0) features, target = input_fn() @@ -114,7 +114,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -27) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=batch_size, shuffle=False, num_epochs=1) features, target = input_fn() @@ -150,7 +150,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -29) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=batch_size, shuffle=False, num_epochs=3) features, target = input_fn() @@ -196,7 +196,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -28) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=batch_size, shuffle=False, num_epochs=1) features, target = input_fn() @@ -221,7 +221,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = np.arange(-32, -30) - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) features, target = input_fn() @@ -240,7 +240,7 @@ class NumpyIoTest(test.TestCase): def testNumpyInputFnWithXAsNonDict(self): x = list(range(32, 36)) y = np.arange(4) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'): failing_input_fn = numpy_io.numpy_input_fn( x, y, 
batch_size=2, shuffle=False, num_epochs=1) @@ -249,7 +249,7 @@ class NumpyIoTest(test.TestCase): def testNumpyInputFnWithXIsEmptyDict(self): x = {} y = np.arange(4) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'): failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) failing_input_fn() @@ -257,7 +257,7 @@ class NumpyIoTest(test.TestCase): def testNumpyInputFnWithXIsEmptyArray(self): x = np.array([[], []]) y = np.arange(4) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'): failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) failing_input_fn() @@ -268,7 +268,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = None - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) features_tensor = input_fn() @@ -291,7 +291,7 @@ class NumpyIoTest(test.TestCase): def testNumpyInputFnWithNonBoolShuffle(self): x = np.arange(32, 36) y = np.arange(4) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, 'shuffle must be provided and explicitly ' 'set as boolean'): @@ -303,7 +303,7 @@ class NumpyIoTest(test.TestCase): x = {'__target_key__': array} y = np.arange(4) - with self.test_session(): + with self.cached_session(): input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) input_fn() @@ -318,7 +318,7 @@ class NumpyIoTest(test.TestCase): x_mismatch_length = {'a': np.arange(1), 'b': b} y_longer_length = np.arange(10) - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, 'Length of tensors in x and y is mismatched.'): failing_input_fn = numpy_io.numpy_input_fn( @@ -341,7 +341,7 @@ class NumpyIoTest(test.TestCase): x = {'a': a, 'b': b} y = {'y1': np.arange(-32, -28), 'y2': 
np.arange(32, 28, -1)} - with self.test_session() as session: + with self.cached_session() as session: input_fn = numpy_io.numpy_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) features_tensor, targets_tensor = input_fn() @@ -369,7 +369,7 @@ class NumpyIoTest(test.TestCase): b = np.arange(32, 36) x = {'a': a, 'b': b} y = {} - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, 'y cannot be empty'): failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) failing_input_fn() @@ -379,7 +379,7 @@ class NumpyIoTest(test.TestCase): b = np.arange(32, 36) x = {'a': a, 'b': b} y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b} - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp( ValueError, '2 duplicate keys are found in both x and y'): failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False) diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py index 6f13bc95d2..9e69fc72dc 100644 --- a/tensorflow/python/estimator/inputs/pandas_io_test.py +++ b/tensorflow/python/estimator/inputs/pandas_io_test.py @@ -102,7 +102,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesExpectedOutputs(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) @@ -116,7 +116,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrameWithYAsDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) @@ -131,7 +131,7 @@ class PandasIoTest(test.TestCase): def 
testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrameWithYAsDataFrame() y = y.rename(columns={'a_target': 'a', 'b_target': 'b'}) input_fn = pandas_io.pandas_input_fn( @@ -147,7 +147,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrameWithYAsDataFrame() y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'}) input_fn = pandas_io.pandas_input_fn( @@ -163,7 +163,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: index = np.arange(100, 102) a = np.arange(2) b = np.arange(32, 34) @@ -191,7 +191,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: index = np.arange(100, 105) a = np.arange(5) b = np.arange(32, 37) @@ -230,7 +230,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_OnlyX(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, _ = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y=None, batch_size=2, shuffle=False, num_epochs=1) @@ -243,7 +243,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_ExcludesIndex(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=False, num_epochs=1) @@ -266,7 +266,7 @@ class 
PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_NoShuffle(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=4, shuffle=False, num_epochs=1) @@ -276,7 +276,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_WithShuffle(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=4, shuffle=True, num_epochs=1) @@ -286,7 +286,7 @@ class PandasIoTest(test.TestCase): def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self): if not HAS_PANDAS: return - with self.test_session() as session: + with self.cached_session() as session: x, y = self.makeTestDataFrame() input_fn = pandas_io.pandas_input_fn( x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2) @@ -297,7 +297,7 @@ class PandasIoTest(test.TestCase): if not HAS_PANDAS: return x, y = self.makeTestDataFrame() - with self.test_session() as session: + with self.cached_session() as session: input_fn = pandas_io.pandas_input_fn( x, y, batch_size=3, shuffle=False, num_epochs=1) diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py index e85f812ce2..a44c570fb9 100644 --- a/tensorflow/python/training/checkpointable/tracking_test.py +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -165,7 +165,7 @@ class InterfaceTests(test.TestCase): self.assertEqual([c], a.attribute["c"].layers) checkpoint = util.Checkpoint(a=a) save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) - with self.test_session(): + with self.cached_session(): checkpoint.restore(save_path).assert_consumed().initialize_or_restore() @test_util.run_in_graph_and_eager_modes diff --git 
a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 0d32d21426..f8b5bd8501 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -384,7 +384,7 @@ class CheckpointingTests(test.TestCase): saver = saver_lib.Saver(var_list=[v]) test_dir = self.get_temp_dir() prefix = os.path.join(test_dir, "ckpt") - with self.test_session() as sess: + with self.cached_session() as sess: self.evaluate(v.non_dep_variable.assign(42.)) save_path = saver.save(sess, prefix) self.evaluate(v.non_dep_variable.assign(43.)) -- cgit v1.2.3 From 6d3af1df20f611641665f63e8bb49a875823432b Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 10 Sep 2018 14:40:21 -0700 Subject: Add support for list literals in template replacement values. PiperOrigin-RevId: 212337233 --- tensorflow/contrib/autograph/pyct/templates.py | 6 ++-- .../contrib/autograph/pyct/templates_test.py | 36 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 5831d57ceb..d81c50f524 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer): if isinstance(node, gast.Attribute): self._check_inner_children_have_context(node.value) self._check_has_context(node) - elif isinstance(node, gast.Tuple): + elif isinstance(node, (gast.Tuple, gast.List)): for e in node.elts: self._check_inner_children_have_context(e) self._check_has_context(node) @@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, gast.Load()) node.ctx = ctx - elif isinstance(node, gast.Tuple): + elif isinstance(node, (gast.Tuple, gast.List)): for e in node.elts: 
self._set_inner_child_context(e, ctx) node.ctx = ctx @@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer): # Preserve the target context. for n in new_nodes: - if isinstance(n, gast.Tuple): + if isinstance(n, (gast.Tuple, gast.List)): for e in n.elts: self._set_inner_child_context(e, node.ctx) if isinstance(n, gast.Attribute): diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py index 77e8ff62fd..074105ea50 100644 --- a/tensorflow/contrib/autograph/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase): self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) + def test_replace_list_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) + + def test_replace_tuple_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store) + + def test_replace_complex_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + function_call_arg = node.body[0].targets[0].value.args[0] + self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load) + 
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load) + self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): -- cgit v1.2.3 From a5752eb9cb266262f3b7a289f12c21e268b3041d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 14:44:43 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212338134 --- .../python/kernel_tests/batch_dataset_op_test.py | 54 +++++++++--------- .../data/python/kernel_tests/bucketing_test.py | 32 +++++------ .../directed_interleave_dataset_test.py | 6 +- .../python/kernel_tests/get_single_element_test.py | 4 +- .../kernel_tests/indexed_dataset_ops_test.py | 6 +- .../kernel_tests/interleave_dataset_op_test.py | 36 ++++++------ .../python/kernel_tests/lmdb_dataset_op_test.py | 2 +- .../python/kernel_tests/map_dataset_op_test.py | 6 +- .../data/python/kernel_tests/parsing_ops_test.py | 2 +- .../python/kernel_tests/prefetching_ops_test.py | 28 +++++----- .../python/kernel_tests/range_dataset_op_test.py | 4 +- .../python/kernel_tests/reader_dataset_ops_test.py | 2 +- .../data/python/kernel_tests/resample_test.py | 6 +- .../python/kernel_tests/scan_dataset_op_test.py | 6 +- .../python/kernel_tests/shuffle_dataset_op_test.py | 2 +- .../python/kernel_tests/slide_dataset_op_test.py | 14 ++--- .../python/kernel_tests/sql_dataset_op_test.py | 64 +++++++++++----------- .../contrib/data/python/kernel_tests/test_utils.py | 4 +- .../kernel_tests/threadpool_dataset_ops_test.py | 2 +- .../python/kernel_tests/unique_dataset_op_test.py | 2 +- 
.../python/kernel_tests/window_dataset_op_test.py | 22 ++++---- .../data/python/kernel_tests/writer_ops_test.py | 6 +- 22 files changed, 155 insertions(+), 155 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 67242fecfe..8e368bf2bc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an input tensor of incompatible rank. 
sess.run(init_op, feed_dict={input_tensor: [[1]]}) with self.assertRaisesRegexp(errors.InvalidArgumentError, @@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i,) * 3, sess.run(op)) @@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) @@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): st_row = sess.run(next_element) self.assertEqual([i], st_row.indices) @@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): dense_elem, st_row = sess.run(next_element) self.assertEqual(i, dense_elem) @@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i,),) * 3, sess.run(op)) @@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) @@ -266,7 +266,7 @@ class 
BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Mismatch in the 0th dimension. sess.run( iterator.initializer, @@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Batch of a finite input, where the batch_size divides the # total number of elements. 
sess.run(init_op, feed_dict={count: 28, batch_size: 14}) @@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) if not drop_remainder: @@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_one_shot_iterator()) self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) self.assertAllEqual([[64], [81]], sess.run(next_element)) @@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) 
.make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) @@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp(errors.InvalidArgumentError, "number of elements does not match"): @@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(3): sess.run(get_next) @@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=10)).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) if threshold % 10 != 0: @@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) @@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = 
iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 2022c1f2bd..293be2bd06 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for expected in values: got = sess.run(get_next) self.assertEqual(got, expected) @@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase): self.assertIs(None, dataset.output_shapes[1].ndims) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual([0] * (2**i), x) self.assertAllEqual(np.array(1, ndmin=i), y) @@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase): (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual(x, np.asarray([x for x in range(10)])) self.assertEqual(y, 45) @@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase): 
init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # The input is infinite, so this test demonstrates that: # 1. We produce output without having to consume the entire input, @@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) @@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -376,7 +376,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) which_bucket, bucketed_values = sess.run(get_next) @@ -411,7 +411,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches (one containing even values, one containing odds) @@ -482,7 +482,7 @@ class 
BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches ([0, 2, ...] and [64, 66, ...]) @@ -515,7 +515,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): batches = 0 @@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase): element_len, boundaries, batch_sizes)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(4): batches.append(sess.run(batch)) @@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(3): batches.append(sess.run(batch)) @@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(5): batches.append(sess.run(batch)) diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 9020a499c4..eb110324d1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(iterator.initializer) for _ in range(100): for i in range(10): @@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: freqs = np.zeros([num_datasets]) for _ in range(num_samples): freqs[sess.run(next_element)] += 1 @@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in choice_array: self.assertEqual(words[i], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index e6883d53e0..f3968cdc15 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): lambda x: (x * x, make_sparse(x))).take(take_t) element = get_single_element.get_single_element(dataset) - with self.test_session() as sess: + with self.cached_session() as sess: if error is None: dense_val, sparse_val = sess.run( element, feed_dict={ @@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(stop_t) element = get_single_element.reduce_dataset(dataset, sum_reducer) - with self.test_session() as sess: + with self.cached_session() as sess: value = sess.run(element, feed_dict={stop_t: stop}) self.assertEqual(stop * (stop - 1) / 2, value) diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index db2ab815ee..9c508d686d 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase): get_op = gen_dataset_ops.indexed_dataset_get( handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialize) self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) def testIdentityIndexedDataset(self): ds = indexed_dataset_ops.IdentityIndexedDataset(16) materialized = ds.materialize() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialized.initializer) placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) for i in range(16): @@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase): ds = indexed_dataset_ops.IdentityIndexedDataset(16) itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(itr.initializer) for i in range(16): output = sess.run(n) diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 7a3215f6cc..b9e74dfddb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. 
- with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) def _testEmptyInput(self, sloppy=False): - with self.test_session() as sess: + with self.cached_session() as sess: # Empty input. 
self._clear_coordination_events() sess.run( @@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) def testBlockLengthWithContentionSloppy(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: output_values = [] for _ in range(30): 
output_values.append(sess.run(iterator.get_next())) @@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for j in range(2): @@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(get_next) def testErrorsInOutputFn(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): next_element = iterator.get_next() results = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): elements = [] sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 7bc582ebaa..1cc5ddc9a2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_repeats): # Dataset is repeated. for i in range(10): # 10 records. 
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 55c9ac68dd..e8519381d6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # All of the files are present. 
sess.run(init_op) for filename in filenames: diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index f6c4a984b8..c4623bca73 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase): expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 361fe0dd39..0166ba0d44 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): destroy_op = resource_variable_ops.destroy_resource_op( buffer_resource_handle, ignore_lookup_error=True) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([b"a"], sess.run(prefetch_op)) self.assertEqual([b"b"], sess.run(prefetch_op)) self.assertEqual([b"c"], sess.run(prefetch_op)) @@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with 
self.assertRaises(errors.OutOfRangeError): @@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) 
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase): iterator = back_to_cpu_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase): elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() - with self.test_session() as sess: + with self.cached_session() as sess: # Before initializing the iterator, evaluating the optional fails with # a FailedPreconditionError. 
with self.assertRaises(errors.FailedPreconditionError): diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 592642da0c..db8fe6aa1b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase): self.assertEqual([tensor_shape.TensorShape([])] * 3, [t.shape for t in get_next[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) @@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase): .make_one_shot_iterator()) negative_get_next = negative_iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(get_next)) self.assertEqual(3 + 4, sess.run(get_next)) self.assertEqual(3 + 2 * 4, sess.run(get_next)) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index fd00cdc5c6..ed75b27a44 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -116,7 +116,7 @@ class ReadBatchFeaturesTest( init_op = iterator.initializer next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( range(self._num_files), 2, 10): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index c5cfddb72b..16b1441baa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): class_func=lambda c, _: c, seed=27)).make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] while len(returned) < 4000: returned.append(sess.run(get_next)) @@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 42cada0b97..dde678bd54 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase): start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase): make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase): iterator = 
dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): (longer_vector_val, larger_rank_val), _ = sess.run(next_element) self.assertAllEqual([0] * (2**i), longer_vector_val) diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 077abd6b30..440e48db30 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase): def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): get_next = ds_fn().make_one_shot_iterator().get_next() outputs = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(num_outputs): outputs.append(sess.run(get_next)) if verify_exhausted: diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 6b3e8e9f6e..90d18dca2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): 
window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Slide: 1st batch. 
actual = sess.run(get_next) @@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 2c2cfbebff..52823d3fca 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string), 2) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): # Run twice to verify statelessness of db operations. 
sess.run( init_op, @@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetJoinQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetNullTerminator(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetReuseSqlDataset(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadEmptyResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidDriverName(self): init_op = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string))[0] - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidColumnName(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ 
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfQueryWithSyntaxError(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfInsertQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int8` tensor. def testReadResultSetInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt8NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int8` tensor. 
def testReadResultSetInt8MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int16` tensor. def testReadResultSetInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt16NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int16` tensor. def testReadResultSetInt16MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int32` tensor. def testReadResultSetInt32(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int32` tensor. 
def testReadResultSetInt32NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # table and place it in an `int32` tensor. def testReadResultSetInt32VarCharColumnAsInt(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in an `int64` tensor. def testReadResultSetInt64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int64` tensor. 
def testReadResultSetInt64MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in a `uint8` tensor. def testReadResultSetUInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint8` tensors. def testReadResultSetUInt8MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in a `uint16` tensor. def testReadResultSetUInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint16` tensors. def testReadResultSetUInt16MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # in `bool` tensors. 
def testReadResultSetBool(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # from a SQLite database table and place it as `True` in a `bool` tensor. def testReadResultSetBoolNotZeroOrOne(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64OverlyPrecise(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1d70b16041..1def07179a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase): # TODO(rachelim): support sparse tensor 
outputs next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: while True: try: op1 = sess.run(next1) @@ -54,7 +54,7 @@ class DatasetTestBase(test.TestCase): replacements=None): next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: try: sess.run(next1) raise ValueError( diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 4b08ec759d..8d335e87d5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) thread_ids = [] try: diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index d79a842e7a..f994c8563f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_case, expected in test_cases: current_test_case = test_case sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py 
b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index ff4d9b3260..6eaa0b1959 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).apply( grouping.window_dataset(5)).flat_map(fn) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run(self._structuredElement(structure, shape, dtype)) actual = sess.run(get_next) self._assertEqual(expected, actual) @@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredElement(structure, np.concatenate( ([5], shape), axis=0), dtype)) @@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredElement(None, np.concatenate(([5], shape), axis=0), @@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredSparseElement(structure, np.concatenate(([5], shape), axis=0), @@ -263,7 +263,7 @@ 
class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredSparseElement(None, @@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping.window_dataset(len(shapes))).apply( grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( self._structuredElement( @@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( @@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shapes, dtype).apply(grouping.window_dataset( len(shapes))).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredRaggedSparseElement(structure, shapes, dtype, padded_shape)) @@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, 
parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected = sess.run( self._structuredRaggedSparseElement(None, shapes, dtypes.int32, @@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index c603ecc5ab..867ee2ba37 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") def testWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ self.filename: self._createFile(), @@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteZLIB(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ @@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteGZIP(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ -- cgit v1.2.3 From e6cce55e57722d8ba587965b8ef511838c6d1391 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 10 Sep 2018 15:35:18 -0700 Subject: Fix some build breakage due to de-std::unique_ptr cleanup. PiperOrigin-RevId: 212347506 --- .../compiler/xla/service/cpu/sample_harness.cc | 30 ++++++++++------------ tensorflow/compiler/xla/tools/show_literal.cc | 4 +-- tensorflow/compiler/xla/tools/show_text_literal.cc | 16 ++++++------ 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 942e2ddd39..55d5925642 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // Transfer parameters. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. 
xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. xla::ExecutionProfile profile; - xla::StatusOr> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr actual = result.ConsumeValueOrDie(); + xla::StatusOr result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); LOG(INFO) << absl::StrFormat("computation took %dns", profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index 51909190a3..4f8852f8c1 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,8 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - std::unique_ptr literal = + xla::Literal literal = xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal->ToString().c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc 
b/tensorflow/compiler/xla/tools/show_text_literal.cc index 48c8374811..4b5c276bdf 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -36,16 +36,16 @@ int main(int argc, char **argv) { LOG(QFATAL) << "Usage: " << argv[0] << " "; } - std::unique_ptr literal = + xla::Literal literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << *literal; - fprintf(stderr, "%s\n", literal->ToString().c_str()); - if (literal->shape().element_type() == xla::F32) { - float min = *std::min_element(literal->data().begin(), - literal->data().end()); - float max = *std::max_element(literal->data().begin(), - literal->data().end()); + LOG(INFO) << "literal: " << literal; + fprintf(stderr, "%s\n", literal.ToString().c_str()); + if (literal.shape().element_type() == xla::F32) { + float min = *std::min_element(literal.data().begin(), + literal.data().end()); + float max = *std::max_element(literal.data().begin(), + literal.data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } -- cgit v1.2.3 From 6951e0646d7dc8931b6cbe4388dcc3921249d462 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 10 Sep 2018 15:38:23 -0700 Subject: Only keep alive outputs and inputs that are required to be kept alive. The backward function used to keep all or none of the inputs/outputs alive. This CL makes that a little more granular. 
PiperOrigin-RevId: 212348042 --- tensorflow/python/eager/pywrap_tfe_src.cc | 293 +++++++++++++++++++----------- 1 file changed, 182 insertions(+), 111 deletions(-) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 1ed814258b..c6a55949ab 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1740,117 +1740,167 @@ PyObject* MaybeGetDTypeForAttr(const string& attr, Py_RETURN_NONE; } -bool OpDoesntRequireOutput(const string& op_name) { - static tensorflow::gtl::FlatSet* ops_that_dont_require_outputs = - new tensorflow::gtl::FlatSet({ - "Identity", - "MatMul", - "Conv2DBackpropInput", - "Conv2DBackpropFilter", - "Conv3D", - "Conv3DBackpropInputV2", - "AvgPool3D", - "AvgPool3DGrad", - "MaxPool3D", - "MaxPool3DGrad", - "MaxPool3DGradGrad", - "BiasAdd", - "BiasAddV1", - "BiasAddGrad", - "Softplus", - "SoftplusGrad", - "Softsign", - "ReluGrad", - "Conv2D", - "DepthwiseConv2dNative", - "Dilation2D", - "AvgPool", - "AvgPoolGrad", - "BatchNormWithGlobalNormalization", - "L2Loss", - "Sum", - "Prod", - "SegmentSum", - "SegmentMean", - "SparseSegmentSum", - "SparseSegmentMean", - "SparseSegmentSqrtN", - "SegmentMin", - "SegmentMax", - "UnsortedSegmentSum", - "UnsortedSegmentMax", - "Abs", - "Neg", - "ReciprocalGrad", - "Square", - "Expm1", - "Log", - "Log1p", - "TanhGrad", - "SigmoidGrad", - "Sign", - "Sin", - "Cos", - "Tan", - "Add", - "Sub", - "Mul", - "Div", - "RealDiv", - "Maximum", - "Minimum", - "SquaredDifference", - "Select", - "SparseMatMul", - "BatchMatMul", - "Complex", - "Real", - "Imag", - "Angle", - "Conj", - "Cast", - "Cross", - "Cumsum", - "Cumprod", - "ReadVariableOp", - "VarHandleOp", - "Shape", - "StridedSlice", +// Returns a pair where the first value of the pair indicates whether or not all +// outputs are unused. If the first value is false, the second value is a +// set that identifies which of the output indices are unused. 
+bool OpGradientDoesntRequireOutputIndices( + const string& op_name, + std::pair>** output) { + static tensorflow::gtl::FlatMap< + string, std::pair>>* m = + new tensorflow::gtl::FlatMap< + string, std::pair>>({ + // Ops that don't require any outputs. + {"Identity", {true, {}}}, + {"MatMul", {true, {}}}, + {"Conv2DBackpropInput", {true, {}}}, + {"Conv2DBackpropFilter", {true, {}}}, + {"Conv3D", {true, {}}}, + {"Conv3DBackpropInputV2", {true, {}}}, + {"AvgPool3D", {true, {}}}, + {"AvgPool3DGrad", {true, {}}}, + {"MaxPool3D", {true, {}}}, + {"MaxPool3DGrad", {true, {}}}, + {"MaxPool3DGradGrad", {true, {}}}, + {"BiasAdd", {true, {}}}, + {"BiasAddV1", {true, {}}}, + {"BiasAddGrad", {true, {}}}, + {"Softplus", {true, {}}}, + {"SoftplusGrad", {true, {}}}, + {"Softsign", {true, {}}}, + {"ReluGrad", {true, {}}}, + {"Conv2D", {true, {}}}, + {"DepthwiseConv2dNative", {true, {}}}, + {"Dilation2D", {true, {}}}, + {"AvgPool", {true, {}}}, + {"AvgPoolGrad", {true, {}}}, + {"BatchNormWithGlobalNormalization", {true, {}}}, + {"L2Loss", {true, {}}}, + {"Sum", {true, {}}}, + {"Prod", {true, {}}}, + {"SegmentSum", {true, {}}}, + {"SegmentMean", {true, {}}}, + {"SparseSegmentSum", {true, {}}}, + {"SparseSegmentMean", {true, {}}}, + {"SparseSegmentSqrtN", {true, {}}}, + {"SegmentMin", {true, {}}}, + {"SegmentMax", {true, {}}}, + {"UnsortedSegmentSum", {true, {}}}, + {"UnsortedSegmentMax", {true, {}}}, + {"Abs", {true, {}}}, + {"Neg", {true, {}}}, + {"ReciprocalGrad", {true, {}}}, + {"Square", {true, {}}}, + {"Expm1", {true, {}}}, + {"Log", {true, {}}}, + {"Log1p", {true, {}}}, + {"TanhGrad", {true, {}}}, + {"SigmoidGrad", {true, {}}}, + {"Sign", {true, {}}}, + {"Sin", {true, {}}}, + {"Cos", {true, {}}}, + {"Tan", {true, {}}}, + {"Add", {true, {}}}, + {"Sub", {true, {}}}, + {"Mul", {true, {}}}, + {"Div", {true, {}}}, + {"RealDiv", {true, {}}}, + {"Maximum", {true, {}}}, + {"Minimum", {true, {}}}, + {"SquaredDifference", {true, {}}}, + {"Select", {true, {}}}, + {"SparseMatMul", {true, 
{}}}, + {"BatchMatMul", {true, {}}}, + {"Complex", {true, {}}}, + {"Real", {true, {}}}, + {"Imag", {true, {}}}, + {"Angle", {true, {}}}, + {"Conj", {true, {}}}, + {"Cast", {true, {}}}, + {"Cross", {true, {}}}, + {"Cumsum", {true, {}}}, + {"Cumprod", {true, {}}}, + {"ReadVariableOp", {true, {}}}, + {"VarHandleOp", {true, {}}}, + {"Shape", {true, {}}}, + {"StridedSlice", {true, {}}}, + {"Fill", {true, {}}}, + + // Ops that don't require a subset of outputs. + {"FusedBatchNorm", {false, {0, 1, 2}}}, }); - return ops_that_dont_require_outputs->find(op_name) != - ops_that_dont_require_outputs->end(); -} - -bool OpDoesntRequireInput(const string& op_name) { - static tensorflow::gtl::FlatSet* ops_that_dont_require_inputs = - new tensorflow::gtl::FlatSet({ - "Identity", - "Softmax", - "LogSoftmax", - "BiasAdd", - "Relu", - "Relu6", - "Elu", - "Selu", - "SparseSoftmaxCrossEntropyWithLogits", - "Neg", - "Inv", - "Reciprocal", - "Sqrt", - "Exp", - "Tanh", - "Sigmoid", - "Real", - "Imag", - "Conj", - "ReadVariableOp", - "VarHandleOp", - "Shape", + auto it = m->find(op_name); + + if (it == m->end()) return false; + + *output = &it->second; + return true; +} + +// Returns a pair where the first value of the pair indicates whether or not all +// inputs are unused. If the first value is false, the second value is a +// set that identifies which of the input indices are unused. +bool OpGradientDoesntRequireInputIndices( + const string& op_name, + std::pair>** output) { + static tensorflow::gtl::FlatMap< + string, std::pair>>* m = + new tensorflow::gtl::FlatMap< + string, std::pair>>({ + // Ops that don't require any inputs. 
+ {"Identity", {true, {}}}, + {"Softmax", {true, {}}}, + {"LogSoftmax", {true, {}}}, + {"BiasAdd", {true, {}}}, + {"Relu", {true, {}}}, + {"Relu6", {true, {}}}, + {"Elu", {true, {}}}, + {"Selu", {true, {}}}, + {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}}, + {"Neg", {true, {}}}, + {"Inv", {true, {}}}, + {"Reciprocal", {true, {}}}, + {"Sqrt", {true, {}}}, + {"Exp", {true, {}}}, + {"Tanh", {true, {}}}, + {"Sigmoid", {true, {}}}, + {"Real", {true, {}}}, + {"Imag", {true, {}}}, + {"Conj", {true, {}}}, + {"ReadVariableOp", {true, {}}}, + {"VarHandleOp", {true, {}}}, + {"Shape", {true, {}}}, + {"Fill", {true, {}}}, + + // Ops that don't require a subset of inputs. + {"FusedBatchNorm", {false, {2}}}, }); - return ops_that_dont_require_inputs->find(op_name) != - ops_that_dont_require_inputs->end(); + auto it = m->find(op_name); + + if (it == m->end()) return false; + + *output = &it->second; + return true; +} + +PyObject* CopySequenceSettingIndicesToNull( + PyObject* seq, const tensorflow::gtl::FlatSet& indices) { + tensorflow::Safe_PyObjectPtr fast_seq( + PySequence_Fast(seq, "unable to allocate")); + PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get())); + for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) { + PyObject* item; + if (indices.find(i) != indices.end()) { + item = Py_None; + } else { + item = PySequence_Fast_GET_ITEM(fast_seq.get(), i); + } + Py_INCREF(item); + PyTuple_SET_ITEM(result, i, item); + } + return result; } PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, @@ -1870,16 +1920,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, if (!should_record) Py_RETURN_NONE; string c_op_name = TFE_GetPythonString(op_name); + PyObject* op_outputs; - if (OpDoesntRequireOutput(c_op_name)) { - op_outputs = Py_None; + bool op_outputs_tuple_created = false; + std::pair>* outputs_not_required; + + if (OpGradientDoesntRequireOutputIndices(c_op_name, 
&outputs_not_required)) { + if (outputs_not_required->first) { + op_outputs = Py_None; + } else { + op_outputs_tuple_created = true; + op_outputs = CopySequenceSettingIndicesToNull( + results, outputs_not_required->second); + } } else { op_outputs = results; } PyObject* op_inputs; - if (OpDoesntRequireInput(c_op_name)) { - op_inputs = Py_None; + bool op_inputs_tuple_created = false; + std::pair>* inputs_not_required; + + if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) { + if (inputs_not_required->first) { + op_inputs = Py_None; + } else { + op_inputs_tuple_created = true; + op_inputs = + CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second); + } } else { op_inputs = inputs; } @@ -1922,6 +1991,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, }); Py_DECREF(num_inputs); + if (op_outputs_tuple_created) Py_DECREF(op_outputs); + if (op_inputs_tuple_created) Py_DECREF(op_inputs); Py_RETURN_NONE; } -- cgit v1.2.3 From e32029541ae270a021b266fcc3929b2528f8dff1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 15:43:51 -0700 Subject: Move from deprecated self.test_session() to self.cached_session(). self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. 
PiperOrigin-RevId: 212348850 --- .../python/ops/factorization_ops_test.py | 16 +++--- .../factorization/python/ops/gmm_ops_test.py | 6 +- .../factorization/python/ops/kmeans_test.py | 2 +- .../contrib/factorization/python/ops/wals_test.py | 8 +-- .../timeseries/python/timeseries/head_test.py | 2 +- .../python/timeseries/input_pipeline_test.py | 6 +- .../python/timeseries/math_utils_test.py | 23 ++++---- .../python/timeseries/model_utils_test.py | 2 +- .../python/timeseries/state_management_test.py | 6 +- tensorflow/python/framework/file_system_test.py | 2 +- tensorflow/python/framework/function_test.py | 10 ++-- tensorflow/python/framework/importer_test.py | 18 +++--- tensorflow/python/framework/meta_graph_test.py | 9 +-- tensorflow/python/framework/ops_test.py | 50 ++++++++--------- tensorflow/python/framework/sparse_tensor_test.py | 6 +- tensorflow/python/framework/subscribe_test.py | 14 ++--- tensorflow/python/framework/tensor_util_test.py | 2 +- tensorflow/python/keras/engine/saving_test.py | 38 ++++++------- tensorflow/python/keras/engine/sequential_test.py | 4 +- tensorflow/python/keras/engine/topology_test.py | 19 ++++--- tensorflow/python/keras/engine/training_test.py | 64 +++++++++++----------- 21 files changed, 155 insertions(+), 152 deletions(-) diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index bb5140aeb3..6aa62fb82e 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase): observed *= num_rows / 3. if test_rows else num_cols / 2. 
want_weight_sum = unobserved + observed - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: wals_model = factorization_ops.WALSModel( input_rows=num_rows, input_cols=num_cols, @@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input_transposed(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase): # trigger the more efficient ALS updates. # Here we test that those two give identical results. 
def _run_test_als(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase): atol=1e-2) def _run_test_als_transposed(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase): rows = 15 cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase): cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase): def keep_index(x): return not (x[0] + x[1]) % 4 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): row_wts = 0.1 + np.random.rand(rows) col_wts = 0.1 + np.random.rand(cols) data = np.dot(np.random.rand(rows, 3), np.random.rand( diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py index 888c3c238c..112e4d289b 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py @@ -99,7 +99,7 @@ class 
GmmOpsTest(test.TestCase): logging.info('Numpy took %f', time.time() - start_time) start_time = time.time() - with self.test_session() as sess: + with self.cached_session() as sess: op = gmm_ops._covariance( constant_op.constant( data.T, dtype=dtypes.float32), False) @@ -120,7 +120,7 @@ class GmmOpsTest(test.TestCase): graph = ops.Graph() with graph.as_default() as g: g.seed = 5 - with self.test_session() as sess: + with self.cached_session() as sess: data = constant_op.constant(self.data, dtype=dtypes.float32) loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( data, 'random', num_classes, random_seed=self.seed) @@ -144,7 +144,7 @@ class GmmOpsTest(test.TestCase): def testParams(self): """Tests that the params work as intended.""" num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: # Experiment 1. Update weights only. data = constant_op.constant(self.data, dtype=dtypes.float32) gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 88eb9cf692..1ab5418fe4 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -232,7 +232,7 @@ class KMeansTest(KMeansTestBase): self.assertEqual(features.shape, parsed_feature_dict.shape) self.assertEqual(features.dtype, parsed_feature_dict.dtype) # Then check that running the tensor yields the original list of points. 
- with self.test_session() as sess: + with self.cached_session() as sess: parsed_points = sess.run(parsed_feature_dict) self.assertAllEqual(self.points, parsed_points) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 31820a18b4..9bdbd05015 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -336,7 +336,7 @@ class WALSMatrixFactorizationTest(test.TestCase): loss = self._model.evaluate( input_fn=eval_input_fn_row, steps=self._num_rows)['loss'] - with self.test_session(): + with self.cached_session(): true_loss = self.calculate_loss() self.assertNear( @@ -354,7 +354,7 @@ class WALSMatrixFactorizationTest(test.TestCase): loss = self._model.evaluate( input_fn=eval_input_fn_col, steps=self._num_cols)['loss'] - with self.test_session(): + with self.cached_session(): true_loss = self.calculate_loss() self.assertNear( @@ -440,7 +440,7 @@ class SweepHookTest(test.TestCase): math_ops.logical_not(is_row_sweep_var))) mark_sweep_done = state_ops.assign(is_sweep_done_var, True) - with self.test_session() as sess: + with self.cached_session() as sess: sweep_hook = wals_lib._SweepHook( is_row_sweep_var, is_sweep_done_var, @@ -491,7 +491,7 @@ class StopAtSweepHookTest(test.TestCase): train_op = state_ops.assign_add(completed_sweeps, 1) hook.begin() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([variables.global_variables_initializer()]) mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess.run(train_op) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index e65e7b74d4..647455ae42 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -122,7 +122,7 @@ class 
EvaluationMetricsTests(test.TestCase): metric[1] for metric in outputs.eval_metric_ops.values()] loss_mean, loss_update = metrics.mean(outputs.loss) metric_update_ops.append(loss_update) - with self.test_session() as sess: + with self.cached_session() as sess: coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(sess, coord=coordinator) variables.local_variables_initializer().run() diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py index 703537abf0..f92148b788 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py @@ -88,7 +88,7 @@ class RandomWindowInputFnTests(test.TestCase): window_size=window_size, batch_size=batch_size) result, _ = input_fn() init_op = variables.local_variables_initializer() - with self.test_session() as session: + with self.cached_session() as session: coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) session.run(init_op) @@ -261,7 +261,7 @@ class WholeDatasetInputFnTests(test.TestCase): def _whole_dataset_input_fn_test_template( self, time_series_reader, num_features, num_samples): result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)() - with self.test_session() as session: + with self.cached_session() as session: session.run(variables.local_variables_initializer()) coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) @@ -340,7 +340,7 @@ class AllWindowInputFnTests(test.TestCase): window_size=window_size) features, _ = input_fn() init_op = variables.local_variables_initializer() - with self.test_session() as session: + with self.cached_session() as session: coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) 
session.run(init_op) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py index 02d2524b66..c0de42b15b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py @@ -55,7 +55,7 @@ class MathUtilsTest(test.TestCase): running_sum = running_sum + current_contribution # pylint: enable=g-no-augmented-assignment transition_power = numpy.dot(transition, transition_power) - with self.test_session(): + with self.cached_session(): self.assertAllClose(result, math_utils.power_sums_tensor( array_size, transition, addition).eval()) @@ -66,7 +66,7 @@ class MathUtilsTest(test.TestCase): result = [] for i in range(powers.shape[0]): result.append(numpy.linalg.matrix_power(matrix, powers[i])) - with self.test_session(): + with self.cached_session(): self.assertAllClose(result, math_utils.matrix_to_powers(matrix, powers).eval(), rtol=1e-5, @@ -78,7 +78,7 @@ class MathUtilsTest(test.TestCase): result = [] for i in range(batch.shape[0]): result.append(numpy.linalg.matrix_power(batch[i], powers[i])) - with self.test_session(): + with self.cached_session(): # TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be # made slightly more stable? 
self.assertAllClose(result, @@ -91,7 +91,7 @@ class MathUtilsTest(test.TestCase): left_transpose = numpy.transpose(left, [0, 2, 1]) right = numpy.random.normal(size=[2, 3]).astype(numpy.float32) expected_result = numpy.dot(left, right) - with self.test_session(): + with self.cached_session(): self.assertAllClose(expected_result, math_utils.batch_times_matrix( left, right).eval()) @@ -114,7 +114,7 @@ class MathUtilsTest(test.TestCase): right_transpose = numpy.transpose(right, [0, 2, 1]) expected_result = numpy.transpose(numpy.dot(right_transpose, left.T), [0, 2, 1]) - with self.test_session(): + with self.cached_session(): self.assertAllClose(expected_result, math_utils.matrix_times_batch( left, right).eval()) @@ -132,7 +132,7 @@ class MathUtilsTest(test.TestCase): adj_x=True, adj_y=True).eval()) def test_make_diagonal_undefined_shapes(self): - with self.test_session(): + with self.cached_session(): completely_undefined = array_ops.placeholder(dtype=dtypes.float32) partly_undefined = array_ops.placeholder( shape=[None, None], dtype=dtypes.float32) @@ -152,7 +152,7 @@ class MathUtilsTest(test.TestCase): [5., 6.]]})) def test_make_diagonal_mostly_defined_shapes(self): - with self.test_session(): + with self.cached_session(): mostly_defined = array_ops.placeholder( shape=[None, 2], dtype=dtypes.float32) blocked = math_utils.block_diagonal([[[2.]], @@ -192,7 +192,7 @@ class TestMakeToeplitzMatrix(test.TestCase): def _test_make_toeplitz_matrix(self, inputs, output_expected): output_tf = math_utils.make_toeplitz_matrix(inputs) - with self.test_session() as sess: + with self.cached_session() as sess: output_tf_np = sess.run(output_tf) self.assertAllClose(output_tf_np, output_expected) @@ -201,13 +201,13 @@ class TestMakeCovarianceMatrix(test.TestCase): def test_zero_size_matrix(self): raw = numpy.zeros([0, 0]) - with self.test_session(): + with self.cached_session(): constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval() self.assertEqual((0, 0), 
constructed.shape) def test_sign_magnitude_positive_definite(self): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): matrix_tensor = math_utils.sign_magnitude_positive_definite( raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype), off_diagonal_scale=constant_op.constant(-1., dtype=dtype), @@ -230,7 +230,8 @@ class TestLookupTable(test.TestCase): name="test_lookup") def stack_tensor(base_tensor): return array_ops.stack([base_tensor + 1, base_tensor + 2]) - with self.test_session() as session: + + with self.cached_session() as session: ((float_output, double_output), int_output) = session.run( hash_table.lookup([2, 1, 0])) def expected_output_before_insert(base_tensor): diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py index cfd31cc70d..a049dbe773 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py @@ -29,7 +29,7 @@ class ModelUtilsTest(test.TestCase): def test_parameter_switching(self): parameter = array_ops.constant(5) overridden_parameter = array_ops.constant(3) - with self.test_session(): + with self.cached_session(): getter = model_utils.parameter_switch({overridden_parameter: 4}) self.assertEqual(5, getter(parameter)) self.assertEqual(4, getter(overridden_parameter)) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py index 5f7e3da2db..42ba6e1c25 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py @@ -127,7 +127,7 @@ class ChainingStateManagerTest(test.TestCase): chainer.initialize_graph(model=stub_model) model_outputs = chainer.define_loss( model=stub_model, features=features, 
mode=estimator_lib.ModeKeys.TRAIN) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) @@ -178,7 +178,7 @@ class ChainingStateManagerTest(test.TestCase): result_model_outputs = chainer.define_loss( model=stub_model, features=result_input_fn()[0], mode=estimator_lib.ModeKeys.TRAIN) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) @@ -221,7 +221,7 @@ class ChainingStateManagerTest(test.TestCase): chainer.initialize_graph(model=stub_model) model_outputs = chainer.define_loss( model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() coordinator = coordinator_lib.Coordinator() queue_runner_impl.start_queue_runners(session, coord=coordinator) diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py index 5eb59141a2..6901715e5d 100644 --- a/tensorflow/python/framework/file_system_test.py +++ b/tensorflow/python/framework/file_system_test.py @@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase): load_library.load_file_system_library(file_system_library) def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: reader = io_ops.WholeFileReader("test_reader") queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) queue.enqueue_many([["test://foo"]]).run() diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index ee723bacaf..903768a039 100644 --- a/tensorflow/python/framework/function_test.py +++ 
b/tensorflow/python/framework/function_test.py @@ -419,7 +419,7 @@ class FunctionTest(test.TestCase): with ops.control_dependencies([z]): return x * 2 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): z = Foo(constant_op.constant(3.0)) self.assertAllEqual(z.eval(), 6.0) @@ -434,7 +434,7 @@ class FunctionTest(test.TestCase): # Foo contains a stateful op (Assert). self.assertEqual([("Assert", "Assert")], Foo.stateful_ops) g = ops.Graph() - with g.as_default(), self.test_session(): + with g.as_default(), self.cached_session(): self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "assertion failed.*-3"): @@ -448,7 +448,7 @@ class FunctionTest(test.TestCase): [control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]): return array_ops.identity(x) - with self.test_session(): + with self.cached_session(): self.assertEqual(1.0, MyFn(1.0).eval()) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "assertion"): @@ -667,7 +667,7 @@ class FunctionTest(test.TestCase): with ops.Graph().as_default(): z = CubeXPlusY(3.0, -2.0) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(z.eval(), 25.0) def testNestedDefinedFunction(self): @@ -683,7 +683,7 @@ class FunctionTest(test.TestCase): with ops.Graph().as_default(): z = CubeXPlusY(3.0, -2.0) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(z.eval(), 25.0) def testUnusedFunction(self): diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 18e7d8aa14..2b4d8e7299 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase): # Run the imported graph. 
# TODO(b/76173421): make this work (currently DCHECKS) - # with self.test_session() as sess: + # with self.cached_session() as sess: # sess.run(imported_init) # self.assertEqual(sess.run(imported_var), 1.0) # self.assertEqual(sess.run(imported_assign), 2.0) @@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase): imported_r, = importer.import_graph_def(graph_def, return_elements=[r.name]) self.assertEqual(imported_r.name, "import/" + r.name) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(imported_r), 10) def testImportWhileLoopInCond(self): @@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase): pred = array_ops.placeholder(dtypes.bool) out = control_flow_ops.cond(pred, ImportFn, lambda: constant_op.constant(1)) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(out, {pred: True}), 10) self.assertEqual(sess.run(out, {pred: False}), 1) @@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase): out = control_flow_ops.while_loop( lambda i: i < 2, ImportFn, [0], shape_invariants=[tensor_shape.TensorShape(None)]) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(out), 10) def testTypeMismatchInGraphDef(self): @@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase): input_map={"a:0": constant_op.constant(5.0)}, name="", return_elements=["id:0"]) - with self.test_session(): + with self.cached_session(): self.assertEqual(5.0, t.eval()) def testInvalidInputForReturnOperations(self): @@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase): array_ops.stack([c, c], name="pack") gdef = g.as_graph_def() - with self.test_session(): + with self.cached_session(): pack, = importer.import_graph_def(gdef, return_elements=["pack"]) self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0]) @@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual([10], biases_grad.get_shape()) def 
testLargeGraph(self): - with self.test_session(): + with self.cached_session(): # The default message byte limit is 64M. Ours is 2G with a warning at 512. # Adding a 130M entries float32 tensor should exceed the warning, but not # the hard limit. @@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase): z = TestFunc() - with self.test_session(): + with self.cached_session(): z_val = z.eval() self.assertEqual(z_val, -2.0) @@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase): z2 = importer.import_graph_def(gdef, return_elements=["z:0"], input_map=input_map)[0] - with self.test_session() as sess: + with self.cached_session() as sess: z1_val, z2_val = sess.run((z1, z2)) self.assertAllEqual(z1_val, z2_val) diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 6e5f7aafac..fc98b91a01 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertEqual(new_output_value, output_value) def testStrippedOpListNestedFunctions(self): - with self.test_session(): + with self.cached_session(): # Square two levels deep @function.Defun(dtypes.int32) def f0(x): @@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase): # and "Tout" maps to complex64. Since these attr values map to their # defaults, they must be stripped unless stripping of default attrs is # disabled. 
- with self.test_session(): + with self.cached_session(): real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real") imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") @@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase): def testDefaultAttrStrippingNestedFunctions(self): """Verifies that default attributes are stripped from function node defs.""" - with self.test_session(): + with self.cached_session(): + @function.Defun(dtypes.float32, dtypes.float32) def f0(i, j): return math_ops.complex(i, j, name="double_nested_complex") @@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase): meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() meta_info_def.stripped_op_list.op.add() - with self.test_session(): + with self.cached_session(): meta_graph_def = meta_graph.create_meta_graph_def( meta_info_def=meta_info_def, graph_def=graph_def, strip_default_attrs=True) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index ced0581402..d59adf3d48 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) class ResourceTest(test_util.TensorFlowTestCase): def testBuildGraph(self): - with self.test_session(): + with self.cached_session(): pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") test_ops.resource_create_op(pt).run() def testInitialize(self): - with self.test_session(): + with self.cached_session(): handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") resources.register_resource( handle=handle, @@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): pass def testAddShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.zeros([2, 3]) b = array_ops.ones([1, 3]) c = a + b self.assertEqual([2, 3], c.shape) def 
testUnknownDim(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) c = a + b self.assertEqual([2, None, 3], c.shape.as_list()) def testUnknownShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=None) b = array_ops.ones([1, 3]) c = a + b self.assertEqual(tensor_shape.unknown_shape(), c.shape) def testScalarShape(self): - with self.test_session(): + with self.cached_session(): a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) b = array_ops.ones([]) c = a + b self.assertEqual(tensor_shape.scalar(), c.shape) def testShapeFunctionError(self): - with self.test_session(): + with self.cached_session(): a = array_ops.ones([1, 2, 3]) b = array_ops.ones([4, 5, 6]) with self.assertRaisesRegexp( @@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): class IndexedSlicesTest(test_util.TensorFlowTestCase): def testToTensor(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) dense_shape = constant_op.constant([3, 2]) @@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]]) def testNegation(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = -ops.IndexedSlices(values, indices) @@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(x.indices.eval(), [0, 2]) def testScalarMul(self): - with self.test_session(): + with self.cached_session(): values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) indices = constant_op.constant([0, 2]) x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) @@ 
-307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) def testConvertToTensorNestedArray(self): - with self.test_session(): + with self.cached_session(): values = [[2], [3], [5], [7]] tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) def testShapeTuple(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(1) self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access @@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(converted, ops.EagerTensor)) def testConvertToTensorNestedTuple(self): - with self.test_session(): + with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, ops.convert_to_tensor(values).eval()) def testConvertToTensorNestedTensors(self): - with self.test_session(): + with self.cached_session(): values = ((2,), (3,), (5,), (7,)) tensor = ops.convert_to_tensor( [constant_op.constant(row) for row in values]) @@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertAllEqual(values, tensor.eval()) def testConvertToTensorNestedMix(self): - with self.test_session(): + with self.cached_session(): values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) tensor = ops.convert_to_tensor(values) self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval()) def testConvertToTensorPreferred(self): - with self.test_session(): + with self.cached_session(): values = [2, 3, 5, 7] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) self.assertEqual(dtypes.float32, tensor.dtype) - with self.test_session(): + with self.cached_session(): # Convert empty tensor to 
anything. values = [] tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) self.assertEqual(dtypes.int64, tensor.dtype) - with self.test_session(): + with self.cached_session(): # The preferred dtype is a type error and will convert to # float32 instead. values = [1.23] @@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase): self.assertEqual("bar_2", g.unique_name("bar")) def testNameAndVariableScope(self): - with self.test_session() as sess: + with self.cached_session() as sess: with sess.graph.name_scope("l0"): with variable_scope.variable_scope("l1"): with sess.graph.name_scope("l1") as scope: @@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.as_default(): - with self.test_session(): + with self.cached_session(): # First ensure that graphs that are not building functions are # not escaped. function_with_variables("foo") @@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase): return (a, b) def testNoLabel(self): - with self.test_session(): + with self.cached_session(): self.assertAllEqual((None, None), self._get_test_attrs()) def testLabelMap(self): - with self.test_session() as sess: + with self.cached_session() as sess: a1 = self._get_test_attrs() with sess.graph._attr_scope({ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) @@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) class KernelLabelTest(test_util.TensorFlowTestCase): def testNoLabel(self): - with self.test_session(): + with self.cached_session(): self.assertAllEqual(b"My label is: default", test_ops.kernel_label().eval()) def testLabelMap(self): - with self.test_session() as sess: + with self.cached_session() as sess: default_1 = test_ops.kernel_label() # pylint: disable=protected-access with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): @@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase): class 
TracebackTest(test_util.TensorFlowTestCase): def testTracebackWithStartLines(self): - with self.test_session() as sess: + with self.cached_session() as sess: a = constant_op.constant(2.0) sess.run( a, diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index 2bcfbc17df..22423c4f58 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -45,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase): self.assertEqual(sp.dense_shape.dtype, dtypes.int64) self.assertEqual(sp.get_shape(), (4, 5)) - with self.test_session() as sess: + with self.cached_session() as sess: value = sp.eval() self.assertAllEqual(indices, value.indices) self.assertAllEqual(values, value.values) @@ -81,14 +81,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase): class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): def test_convert_dense(self): - with self.test_session(): + with self.cached_session(): value = [42, 43] from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor( value) self.assertAllEqual(value, from_value.eval()) def test_convert_sparse(self): - with self.test_session(): + with self.cached_session(): indices = [[0, 1], [1, 0]] values = [42, 43] shape = [2, 2] diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py index d6de45fdc4..1d594e4078 100644 --- a/tensorflow/python/framework/subscribe_test.py +++ b/tensorflow/python/framework/subscribe_test.py @@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertFalse(c0.op in d.op.control_inputs) self.assertTrue(c.op in d.op.control_inputs) - with self.test_session() as sess: + with self.cached_session() as sess: c_out = sess.run([c]) n_out = sess.run([n]) d_out = sess.run([d]) @@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): b = subscribe.subscribe(b, lambda t: 
script_ops.py_func(sub, [t], [t.dtype])) - with self.test_session() as sess: + with self.cached_session() as sess: c_out = sess.run([c]) d_out = sess.run([d]) @@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertIs(c_sub, c_sub3) # Expect the three side effect graphs to have been evaluated. - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([c_sub]) self.assertIn('graph1', shared) self.assertIn('graph2', shared) @@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): v1, lambda t: script_ops.py_func(sub, [t], [t.dtype])) self.assertTrue(subscribe._is_subscribed_identity(v1_sub)) - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize the variables first. sess.run([v1.initializer]) sess.run([v2.initializer]) @@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertIs(tensor_array_sub, tensor_array.handle) self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([reader]) self.assertEqual(0, len(shared)) @@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): subscribe.subscribe(sparse_add.op.outputs, lambda t: script_ops.py_func(sub, [t], [t.dtype])) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([neg]) # All three ops have been processed. @@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase): # Verify that sub(x1) and sub(branch) are not. 
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1])) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(cond) self.assertEqual(3, len(results)) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 395cf43b3f..bdf759f220 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase): def __array__(self, dtype=None): return np.asarray(self.array, dtype) - with self.test_session() as sess: + with self.cached_session() as sess: ma = MockArray(np.array([10, 20, 30])) t = ops.convert_to_tensor(ma) a = sess.run(t) diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 441f3f4948..148dd23be7 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -48,7 +48,7 @@ except ImportError: class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): def test_weight_loading(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(2,)) x = keras.layers.Dense(3)(a) b = keras.layers.Dense(1)(x) @@ -208,7 +208,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): })) def test_preprocess_weights_for_loading_rnn_should_be_idempotent( self, layer_class, layer_args): - with self.test_session(): + with self.cached_session(): layer = layer_class(**layer_args) layer.build(input_shape=layer_args.get('input_shape')) weights1 = layer.get_weights() @@ -232,7 +232,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): batch_size = 5 num_classes = 2 - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) model.add(keras.layers.Dense(num_classes)) @@ -261,7 +261,7 @@ class 
TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): num_hidden = 5 input_dim = 3 num_classes = 2 - with self.test_session(): + with self.cached_session(): ref_model = keras.models.Sequential() ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, name='d1')) @@ -298,7 +298,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): num_hidden = 5 input_dim = 3 num_classes = 2 - with self.test_session(): + with self.cached_session(): ref_model = keras.models.Sequential() ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, name='d1')) @@ -333,7 +333,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.RepeatVector(3)) @@ -378,7 +378,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.RepeatVector(3)) @@ -402,7 +402,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): # test with custom optimizer, loss class CustomOp(keras.optimizers.RMSprop): @@ -438,7 +438,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): inputs = keras.layers.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) output = keras.layers.Dense(3)(x) @@ -474,7 +474,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() 
model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -490,7 +490,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -508,7 +508,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.Dense(3)) @@ -522,7 +522,7 @@ class TestWholeModelSaving(test.TestCase): os.remove(fname) def test_saving_lambda_numpy_array_arguments(self): - with self.test_session(): + with self.cached_session(): if h5py is None: self.skipTest('h5py required to run this test') @@ -548,7 +548,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): # This layer name will make the `layers_name` HDF5 attribute blow # out of proportion. 
Note that it fits into the internal HDF5 # attribute memory limit on its own but because h5py converts @@ -589,7 +589,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): x = keras.Input(shape=(2,), name='nested_model_input') f = x for i in range(4): @@ -634,7 +634,7 @@ class TestWholeModelSaving(test.TestCase): if h5py is None: self.skipTest('h5py required to run this test') - with self.test_session(): + with self.cached_session(): inputs = keras.Input(shape=(3,)) x = keras.layers.Dense(2)(inputs) outputs = keras.layers.Dense(3)(x) @@ -703,7 +703,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_tensorflow_format_overwrite(self): - with self.test_session() as session: + with self.cached_session() as session: model = SubclassedModel() temp_dir = self.get_temp_dir() prefix = os.path.join(temp_dir, 'ckpt') @@ -760,7 +760,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self.assertEqual(len(graph.get_operations()), op_count) def _weight_loading_test_template(self, make_model_fn): - with self.test_session(): + with self.cached_session(): model = make_model_fn() model.compile( loss='mse', @@ -822,7 +822,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): def _new_layer_weight_loading_test_template( self, first_model_fn, second_model_fn, restore_init_fn): - with self.test_session() as session: + with self.cached_session() as session: model = first_model_fn() temp_dir = self.get_temp_dir() prefix = os.path.join(temp_dir, 'ckpt') diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 28af8d61bc..9d615c9b0c 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -132,7 +132,7 @@ class TestSequential(test.TestCase, parameterized.TestCase): 
@parameterized.parameters((True,), (False,)) def test_training_and_eval_methods_on_symbolic_tensors(self, deferred): - with self.test_session(): + with self.cached_session(): def get_model(): if deferred: @@ -222,7 +222,7 @@ class TestSequential(test.TestCase, parameterized.TestCase): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.BatchNormalization(input_shape=(4,))) assert model.updates diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 1fcd77d7f6..061db8ee34 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -342,7 +342,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual(model.non_trainable_weights, weights) def test_learning_phase(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(32,), name='input_a') b = keras.layers.Input(shape=(32,), name='input_b') @@ -458,7 +458,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(dense.get_output_mask_at(1), None) def test_multi_input_layer(self): - with self.test_session(): + with self.cached_session(): # test multi-input layer a = keras.layers.Input(shape=(32,), name='input_a') b = keras.layers.Input(shape=(32,), name='input_b') @@ -530,7 +530,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) def test_recursion(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(32,), name='input_a') b = keras.layers.Input(shape=(32,), name='input_b') @@ -591,7 +591,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)]) def test_multi_input_multi_output_recursion(self): - with self.test_session(): + with 
self.cached_session(): # test multi-input multi-output a = keras.layers.Input(shape=(32,), name='input_a') b = keras.layers.Input(shape=(32,), name='input_b') @@ -816,7 +816,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(loss, 4.) def test_layer_sharing_at_heterogenous_depth(self): - with self.test_session(): + with self.cached_session(): x_val = np.random.random((10, 5)) x = input_layer_lib.Input(shape=(5,)) @@ -837,7 +837,7 @@ class TopologyConstructionTest(test.TestCase): self.assertAllClose(output_val, output_val_2, atol=1e-6) def test_layer_sharing_at_heterogenous_depth_with_concat(self): - with self.test_session(): + with self.cached_session(): input_shape = (16, 9, 3) input_layer = input_layer_lib.Input(shape=input_shape) @@ -864,7 +864,7 @@ class TopologyConstructionTest(test.TestCase): self.assertAllClose(output_val, output_val_2, atol=1e-6) def test_explicit_training_argument(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(2,)) b = keras.layers.Dropout(0.5)(a) base_model = keras.models.Model(a, b) @@ -887,7 +887,8 @@ class TopologyConstructionTest(test.TestCase): def test_multi_output_model_with_none_masking(self): - with self.test_session(): + with self.cached_session(): + def func(x): return [x * 0.2, x * 0.3] @@ -1186,7 +1187,7 @@ class GraphUtilsTest(test.TestCase): def testGetReachableFromInputs(self): - with self.test_session(): + with self.cached_session(): pl_1 = array_ops.placeholder(shape=None, dtype='float32') pl_2 = array_ops.placeholder(shape=None, dtype='float32') pl_3 = array_ops.placeholder(shape=None, dtype='float32') diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 1d0d113e40..8938333b1a 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -366,7 +366,7 @@ class TrainingTest(test.TestCase): if scipy_sparse is None: return - with 
self.test_session(): + with self.cached_session(): test_inputs = [ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2) ] @@ -389,7 +389,7 @@ class TrainingTest(test.TestCase): model.evaluate(test_inputs, test_outputs, batch_size=2) def test_compile_with_sparse_placeholders(self): - with self.test_session(): + with self.cached_session(): input_layer = keras.layers.Input(shape=(10,), sparse=True) weights = variables_lib.Variable( np.ones((10, 1)).astype(np.float32), name='weights') @@ -405,7 +405,7 @@ class TrainingTest(test.TestCase): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(4,)) layer = keras.layers.BatchNormalization(input_shape=(4,)) b = layer(a) @@ -441,7 +441,7 @@ class TrainingTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_compile_warning_for_loss_missing_output(self): - with self.test_session(): + with self.cached_session(): inp = keras.layers.Input(shape=(16,), name='input_a') out_1 = keras.layers.Dense(8, name='dense_1')(inp) out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1) @@ -654,7 +654,7 @@ class LossWeightingTest(test.TestCase): timesteps = 3 learning_rate = 0.001 - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.TimeDistributed( @@ -741,7 +741,7 @@ class LossWeightingTest(test.TestCase): timesteps = 3 learning_rate = 0.001 - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.TimeDistributed( @@ -810,7 +810,7 @@ class LossWeightingTest(test.TestCase): timesteps = 3 learning_rate = 0.001 - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add( keras.layers.TimeDistributed( @@ -854,7 +854,7 @@ class LossMaskingTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def 
test_masking_graph_sequential(self): - with self.test_session(): + with self.cached_session(): x = np.array([[[1], [1]], [[0], [0]]]) model = keras.models.Sequential() model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) @@ -868,7 +868,7 @@ class LossMaskingTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_masking_deferred_sequential(self): - with self.test_session(): + with self.cached_session(): x = np.array([[[1], [1]], [[0], [0]]]) model = keras.models.Sequential() model.add(keras.layers.Masking(mask_value=0)) @@ -882,7 +882,7 @@ class LossMaskingTest(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_masking_functional(self): - with self.test_session(): + with self.cached_session(): x = np.array([[[1], [1]], [[0], [0]]]) inputs = keras.layers.Input((2, 1)) outputs = keras.layers.Masking(mask_value=0)(inputs) @@ -912,7 +912,7 @@ class LossMaskingTest(test.TestCase): def compute_output_shape(self, input_shape): return input_shape - with self.test_session(): + with self.cached_session(): x = np.random.random((5, 3)) inputs = keras.layers.Input((3,)) masked = keras.layers.Masking(mask_value=0)(inputs) @@ -924,7 +924,7 @@ class LossMaskingTest(test.TestCase): model.train_on_batch(x, y) def test_loss_masking(self): - with self.test_session(): + with self.cached_session(): weighted_loss = weighted_masked_objective(keras.losses.get('mae')) shape = (3, 4, 2) x = np.arange(24).reshape(shape) @@ -945,12 +945,12 @@ class LossMaskingTest(test.TestCase): class LearningPhaseTest(test.TestCase): def test_empty_model_no_learning_phase(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() self.assertFalse(model.uses_learning_phase) def test_dropout_has_learning_phase(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(2, input_dim=3)) model.add(keras.layers.Dropout(0.5)) @@ -961,7 +961,7 @@ class 
LearningPhaseTest(test.TestCase): class TestDynamicTrainability(test.TestCase): def test_trainable_warning(self): - with self.test_session(): + with self.cached_session(): x = np.random.random((5, 3)) y = np.random.random((5, 2)) @@ -974,7 +974,7 @@ class TestDynamicTrainability(test.TestCase): self.assertRaises(Warning) def test_trainable_argument(self): - with self.test_session(): + with self.cached_session(): x = np.random.random((5, 3)) y = np.random.random((5, 2)) @@ -997,7 +997,7 @@ class TestDynamicTrainability(test.TestCase): self.assertAllClose(out, out_2) def test_layer_trainability_switch(self): - with self.test_session(): + with self.cached_session(): # with constructor argument, in Sequential model = keras.models.Sequential() model.add(keras.layers.Dense(2, trainable=False, input_dim=1)) @@ -1027,7 +1027,7 @@ class TestDynamicTrainability(test.TestCase): self.assertListEqual(model.trainable_weights, []) def test_model_trainability_switch(self): - with self.test_session(): + with self.cached_session(): # a non-trainable model has no trainable weights x = keras.layers.Input(shape=(1,)) y = keras.layers.Dense(2)(x) @@ -1042,7 +1042,7 @@ class TestDynamicTrainability(test.TestCase): self.assertListEqual(model.trainable_weights, []) def test_nested_model_trainability(self): - with self.test_session(): + with self.cached_session(): # a Sequential inside a Model inner_model = keras.models.Sequential() inner_model.add(keras.layers.Dense(2, input_dim=1)) @@ -1121,7 +1121,7 @@ class TestGeneratorMethods(test.TestCase): y = arr_labels[start: end] yield x, y - with self.test_session(): + with self.cached_session(): x = keras.Input((2,)) y = keras.layers.Dense(1)(x) fn_model = keras.models.Model(x, y) @@ -1207,7 +1207,7 @@ class TestGeneratorMethods(test.TestCase): w = arr_sample_weights[start: end] yield x, y, w - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(2,))) 
model.compile( @@ -1244,7 +1244,7 @@ class TestGeneratorMethods(test.TestCase): while 1: yield 0 - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(2,))) model.compile(loss='mse', optimizer='sgd') @@ -1302,7 +1302,7 @@ class TestGeneratorMethods(test.TestCase): w = arr_sample_weights[start: end] yield x, y, w - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(2,))) model.compile(loss='mse', optimizer='sgd') @@ -1360,7 +1360,7 @@ class TestTrainingUtils(test.TestCase): class TestTrainingWithDataTensors(test.TestCase): def test_training_and_eval_methods_on_symbolic_tensors_single_io(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -1400,7 +1400,7 @@ class TestTrainingWithDataTensors(test.TestCase): validation_data=(inputs, targets), validation_steps=2) def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -1501,7 +1501,7 @@ class TestTrainingWithDataTensors(test.TestCase): by only passing them data for the placeholder inputs in the model. 
""" - with self.test_session(): + with self.cached_session(): input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 3)) @@ -1632,7 +1632,7 @@ class TestTrainingWithDataTensors(test.TestCase): self.assertEqual(out.shape, (10 * 3, 4)) def test_model_with_partial_loss(self): - with self.test_session(): + with self.cached_session(): a = keras.Input(shape=(3,), name='input_a') a_2 = keras.layers.Dense(4, name='dense_1')(a) dp = keras.layers.Dropout(0.5, name='dropout') @@ -1673,7 +1673,7 @@ class TestTrainingWithDataTensors(test.TestCase): _ = model.evaluate(input_a_np, [output_a_np]) def test_model_with_external_loss(self): - with self.test_session(): + with self.cached_session(): # None loss, only regularization loss. a = keras.Input(shape=(3,), name='input_a') a_2 = keras.layers.Dense(4, name='dense_1', @@ -1803,7 +1803,7 @@ class TestTrainingWithDataTensors(test.TestCase): self.assertEqual(out[1].shape, (10 * 3, 4)) def test_target_tensors(self): - with self.test_session(): + with self.cached_session(): # single-output, as list model = keras.models.Sequential() model.add(keras.layers.Dense(4, input_shape=(4,), name='dense')) @@ -1864,7 +1864,7 @@ class TestTrainingWithDataTensors(test.TestCase): sample_weight={'dense_a': np.random.random((10,))}) def test_model_custom_target_tensors(self): - with self.test_session(): + with self.cached_session(): a = keras.Input(shape=(3,), name='input_a') b = keras.Input(shape=(3,), name='input_b') @@ -2154,7 +2154,7 @@ class TestTrainingWithDataset(test.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) def test_dataset_input_shape_validation(self): - with self.test_session(): + with self.cached_session(): model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') @@ -2333,7 +2333,7 @@ class TestTrainingWithMetrics(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_metrics_masking(self): - with 
self.test_session(): + with self.cached_session(): np.random.seed(1337) model = keras.models.Sequential() model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) -- cgit v1.2.3 From 700297614b694ece80b35753ecbc451a5e15fa77 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 15:44:20 -0700 Subject: Deterministic ordering of the hyperparameters in optimizer_v2 PiperOrigin-RevId: 212348918 --- tensorflow/contrib/optimizer_v2/optimizer_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f6ecaba834..6af59dcfbf 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -214,7 +214,8 @@ class _OptimizerV2State(object): # with that Tensor cast to that dtype. with ops.init_scope(): self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} - for name, (dynamic, value) in hyper.items() if not dynamic} + for name, (dynamic, value) in sorted(hyper.items()) + if not dynamic} self._slots = {} self._non_slot_dict = {} # Extra state to help Optimizers implement Checkpointable. Holds information @@ -231,7 +232,8 @@ class _OptimizerV2State(object): ret._deferred_dependencies = self._deferred_dependencies ret._deferred_slot_restorations = self._deferred_slot_restorations ret._hyper = {name: {None: _resolve(value, name)} - for name, (dynamic, value) in hyper.items() if dynamic} + for name, (dynamic, value) in sorted(hyper.items()) + if dynamic} ret._hyper.update(self._hyper) ret._non_slot_devices = non_slot_devices ret._distribution = distribution -- cgit v1.2.3 From 3253b87d2a79efe8b8ea83c70cbf94285b17ea64 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Mon, 10 Sep 2018 16:13:40 -0700 Subject: Convert layout_assignment_test to use HloVerifiedTestBase. 
PiperOrigin-RevId: 212353819 --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/layout_assignment_test.cc | 105 ++++++++++----------- 2 files changed, 53 insertions(+), 53 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 1965ba1204..f4e24bff34 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2505,6 +2505,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index f8baba03c3..752a61476d 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, @@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, 
&computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. 
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto 
module = CreateNewModule(); @@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - 
auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) 
.layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(&module(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - 
ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { @@ -873,11 +872,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -901,11 +900,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -932,11 +931,11 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -963,11 +962,11 @@ TEST_F(LayoutAssignmentTest, } )"; - auto module = 
ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -985,11 +984,11 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = -- cgit v1.2.3 From 7b8ffbe4c1da2c53551645fd023df577c43fa16c Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Mon, 10 Sep 2018 16:15:34 -0700 Subject: Fix model_to_estimator bug where subclassed model receives input list from estimator model_fn. 
PiperOrigin-RevId: 212354111 --- tensorflow/python/estimator/BUILD | 5 ----- tensorflow/python/keras/models.py | 2 ++ 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 00da335fef..4001ffdd6b 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -684,12 +684,7 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = [ - "manual", # b/112769036, b/113907597 - "no_oss", # b/112769036, b/113907597 "no_windows", - "noasan", # b/114304340 - "nomsan", - "notsan", # b/67510291 ], deps = [ ":keras", diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index f0733a9105..41c5e3cccf 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -444,6 +444,8 @@ def clone_and_build_model( clone = model _in_place_subclassed_model_reset(clone) if input_tensors is not None: + if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1: + input_tensors = input_tensors[0] clone._set_inputs(input_tensors) # Compile/Build model -- cgit v1.2.3 From 5b853d4b2ca622fb038733e435d964c8f5b78edd Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Mon, 10 Sep 2018 16:22:20 -0700 Subject: Replace global starter flags with call-specific flags The earlier version of convenient default flags mistakenly applied --build_tests_only to normal "bazel build" calls, which broke pip.sh (and probably invalidated some other things). This resolves that problem by setting flags specific to "test" and "build" commands. 
PiperOrigin-RevId: 212355193 --- tensorflow/tools/ci_build/ci_parameterized_build.sh | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index c8472102cb..cc09784c1d 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -127,17 +127,19 @@ NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone" DO_DOCKER=1 -BAZEL_CMD="bazel test" -BAZEL_BUILD_ONLY_CMD="bazel build" -BAZEL_CLEAN_CMD="bazel clean" -# Default flags: +# Helpful flags: # --test_summary=detailed: Tell us more about which targets are being built # --keep_going: Don't stop at the first failure; tell us all the failures # --build_tests_only: Don't build targets depended on by tests if the test is # disabled. Also saves some compilation time. Otherwise, # tries to build everything. -DEFAULT_BAZEL_CONFIGS="--test_summary=detailed --build_tests_only --keep_going" +BAZEL_TEST_FLAGS="--test_summary=detailed --build_tests_only --keep_going" +BAZEL_BUILD_FLAGS="--keep_going" + +BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS}" +BAZEL_BUILD_ONLY_CMD="bazel build ${BAZEL_BUILD_FLAGS}" +BAZEL_CLEAN_CMD="bazel clean" PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" PIP_TEST_TUTORIALS_FLAG="--test_tutorials" @@ -393,7 +395,7 @@ fi EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false" if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] && - [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then + [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD} fi -- cgit v1.2.3 From 10ebeba9d4617f612bf9b714ed51d44f1d332c5d Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Mon, 10 Sep 2018 16:28:30 -0700 Subject: Move tf.scan benchmark from contrib/eager/examples to eager/benchmarks_test.py Eager execution is over 10x slower than defun/graph execution. 
bazel run -c opt benchmarks_test -- --benchmarks=MicroBenchmarks.benchmarkScan.* entry { name: "MicroBenchmarks.benchmarkScan" iters: 100 wall_time: 176364.049911 extras { key: "examples_per_sec" value { double_value: 5.67008979722 } } } entry { name: "MicroBenchmarks.benchmarkScanDefun" iters: 100 wall_time: 15466.0701752 extras { key: "examples_per_sec" value { double_value: 64.6576660182 } } } The benchmark deleted by this CL measured graph construction time, whereas this CL does not. PiperOrigin-RevId: 212356196 --- .../contrib/eager/python/examples/scan/BUILD | 25 ---------- .../eager/python/examples/scan/scan_graph_test.py | 54 ---------------------- .../eager/python/examples/scan/scan_test.py | 54 ---------------------- tensorflow/python/eager/benchmarks_test.py | 20 ++++++++ 4 files changed, 20 insertions(+), 133 deletions(-) delete mode 100644 tensorflow/contrib/eager/python/examples/scan/BUILD delete mode 100644 tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py delete mode 100644 tensorflow/contrib/eager/python/examples/scan/scan_test.py diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD deleted file mode 100644 index 638c57d1c9..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -cuda_py_test( - name = "scan_test", - size = "small", - srcs = ["scan_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "scan_graph_test", - size = "small", - srcs = ["scan_graph_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py 
deleted file mode 100644 index d4b8c8941e..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under graph mode execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - with tf.Session() as sess: - sess.run(sum_op) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py deleted file mode 100644 index a02fc24c79..0000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit test for tf.scan under eager execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - - -if __name__ == '__main__': - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 3bdaf0b214..3fe79ef244 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from 
tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -717,6 +718,25 @@ class MicroBenchmarks(test.Benchmark): assert np.equal(func(), make_keras_model()(data)).all() self._run(func, 30000) + def benchmarkScan(self): + elems = math_ops.range(1600) + + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + + def benchmarkScanDefun(self): + elems = math_ops.range(1600) + + @function.defun + def scan(): + return functional_ops.scan( + lambda a, x: a + x, elems, parallel_iterations=1) + + self._run(scan, 100) + if __name__ == "__main__": test.main() -- cgit v1.2.3 From c277998e9f82660b1573fd5587780a97db761a65 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Mon, 10 Sep 2018 16:34:28 -0700 Subject: Allow keras.models.load_model to load models that were saved before weighted metrics was added. PiperOrigin-RevId: 212357216 --- tensorflow/python/keras/engine/saving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index a2eed7cb46..a2f31fda8f 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -248,7 +248,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= loss = convert_custom_objects(training_config['loss']) metrics = convert_custom_objects(training_config['metrics']) weighted_metrics = convert_custom_objects( - training_config['weighted_metrics']) + training_config.get('weighted_metrics', None)) sample_weight_mode = training_config['sample_weight_mode'] loss_weights = training_config['loss_weights'] -- cgit v1.2.3 From fea74706aaa314cc77ec66c2c986365590e8df27 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Mon, 10 Sep 2018 16:59:51 -0700 Subject: Cleanup cudnn_convolution_runner's interface. 
Use a struct to pack most of the parameters, so that it's easier to toss them around. PiperOrigin-RevId: 212361326 --- .../compiler/xla/service/gpu/convolution_thunk.cc | 7 +- .../gpu/cudnn_convolution_algorithm_picker.cc | 8 +-- .../xla/service/gpu/cudnn_convolution_runner.cc | 81 +++++++++------------- .../xla/service/gpu/cudnn_convolution_runner.h | 44 ++++++------ 4 files changed, 65 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 05448d863d..9b567cf4a8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -72,9 +72,10 @@ Status ConvolutionThunk::ExecuteOnStream( auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + {convolution_kind_, &input_shape_, &filter_shape_, &output_shape_, + input_data, filter_data, output_data, &window_, &dim_nums_, + feature_group_count_, algorithm_config}, + scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. 
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5c2555148a..8fcff84173 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -295,10 +295,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << instr->ToString(); bool launch_ok = - RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, - filter_buf, output_buf, &scratch_allocator, window, dnums, - feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) + RunCudnnConvolution({kind, &input_shape, &filter_shape, &output_shape, + input_buf, filter_buf, output_buf, &window, &dnums, + feature_group_count, AlgorithmConfig(alg)}, + &scratch_allocator, &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 05125e9d1f..2a86ac265e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory filter_buf, DeviceMemory output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = 
params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, - output_buf, &scratch_allocator, window, dnums, feature_group_count, - algorithm, stream, profile_result); + return RunCudnnConvolution(params, &scratch_allocator, stream, + profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& 
window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = params.output_shape->element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, + stream, profile_result); case F32: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index a1b4fc71d0..381aa37a1b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -47,6 +47,20 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter }; +struct CudnnConvParams { + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; +}; + // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); @@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. +// At the moment convolution with half data type is implemented with cudnn +// PSEUDO_HALF configuration, that is, the input values are half and the +// internal computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. 
-Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla -- cgit v1.2.3 From bfc1897518063bfa1d62d9a3cfe5e6362c0d09d9 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 10 Sep 2018 17:34:29 -0700 Subject: [XLA:GPU] Don't canonicalize forward convs with constant filters to backwards conv. There's no right answer between these two choices, and our benchmarks show no performance difference. But canonicalizing to forward conv makes later pattern-matching passes work properly. 
PiperOrigin-RevId: 212366534 --- .../xla/service/gpu/cudnn_convolution_rewriter.cc | 87 +++++++++------------- 1 file changed, 37 insertions(+), 50 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 4a6a84d87d..3d1266355b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -234,51 +234,38 @@ MatchBackwardInput(HloInstruction* conv) { // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else if (reverse_filter->IsConstant()) { - // If the filter is a constant, we're willing to pattern-match to a - // backwards-input conv, on the theory that - // - // a) reversing a constant is free, and - // b) even if the user specified this filter as reverse(constant), we would - // long ago have constant-folded away the reverse. - // - // If the constant has any other uses, reversing it isn't entirely free, - // since we'd now have two constants to keep in memory. But hopefully it's - // free enough. - // - // TODO(jlebar): Should we do this even if the filter is not a constant? - // Reversing a non-constant filter is probably cheaper than padding the - // input! - - // Nothing to do, just fall through. 
- } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dialted conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. 
@@ -417,12 +404,12 @@ MatchBackwardInput(HloInstruction* conv) { reverse_filter->IsConstant()) { // Create a double-reverse, which is a nop. HloComputation* c = conv->parent(); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } -- cgit v1.2.3 From c300a579be9c4adb3736f3551b35826f3f27b0f8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 17:37:11 -0700 Subject: Adds listdiff_op to android_extended_ops_group1 set. 
PiperOrigin-RevId: 212366879 --- tensorflow/contrib/makefile/tf_op_files.txt | 1 + tensorflow/core/kernels/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 676620e544..08de54b8e1 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -130,6 +130,7 @@ tensorflow/core/kernels/immutable_constant_op.cc tensorflow/core/kernels/in_topk_op.cc tensorflow/core/kernels/initializable_lookup_table.c tensorflow/core/kernels/inplace_ops.cc +tensorflow/core/kernels/listdiff_op.cc tensorflow/core/kernels/logging_ops.cc tensorflow/core/kernels/lookup_table_init_op.cc tensorflow/core/kernels/lookup_table_op.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 972fb9efa9..c3c6013d83 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5184,6 +5184,7 @@ filegroup( "fifo_queue.cc", "fifo_queue_op.cc", "fused_batch_norm_op.cc", + "listdiff_op.cc", "population_count_op.cc", "population_count_op.h", "winograd_transform.h", -- cgit v1.2.3 From 6bbe31c5f5d42f646cb5080d955e9ee91bdb6d93 Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 11 Sep 2018 09:05:12 +0800 Subject: fix typos --- tensorflow/python/ops/rnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index dcc17db632..5a3a5cc225 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -710,9 +710,9 @@ def _dynamic_rnn_loop(cell, ValueError: If the input depth cannot be inferred via shape inference from the inputs. ValueError: If time_step is not the same for all the elements in the - input. + inputs. ValueError: If batch_size is not the same for all the elements in the - input. + inputs. 
""" state = initial_state assert isinstance(parallel_iterations, int), "parallel_iterations must be int" -- cgit v1.2.3 From de683c50d039676e36b6a718e4cc7ed2170a8a2f Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Mon, 10 Sep 2018 18:05:03 -0700 Subject: Simplify convolution_thunk's interface. PiperOrigin-RevId: 212370999 --- tensorflow/compiler/xla/service/gpu/BUILD | 3 ++ .../compiler/xla/service/gpu/convolution_thunk.cc | 54 +++++-------------- .../compiler/xla/service/gpu/convolution_thunk.h | 55 ++++++++----------- .../compiler/xla/service/gpu/ir_emission_utils.cc | 38 ++++++++++++++ .../compiler/xla/service/gpu/ir_emission_utils.h | 7 +++ .../xla/service/gpu/ir_emitter_unnested.cc | 61 ++++++---------------- 6 files changed, 96 insertions(+), 122 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index af953a2a16..aab8d0fdca 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -174,6 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", @@ -371,6 +372,8 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", + ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 9b567cf4a8..3a23ac1d63 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -20,6 
+20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,63 +31,32 @@ namespace gpu { using se::dnn::AlgorithmDesc; -ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, - int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - feature_group_count_(feature_group_count), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} - Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + CudnnConvParams params; + + 
params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); + params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); + params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - {convolution_kind_, &input_shape_, &filter_shape_, &output_shape_, - input_data, filter_data, output_data, &window_, &dim_nums_, - feature_group_count_, algorithm_config}, - scratch, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. void* result_ptr = [&] { - switch (convolution_kind_) { + switch (params.kind) { case CudnnConvKind::kForward: - return output_data.opaque(); + return params.output_buf.opaque(); case CudnnConvKind::kBackwardInput: - return input_data.opaque(); + return params.input_buf.opaque(); case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); + return params.filter_buf.opaque(); } }(); void* ptrs[] = {result_ptr, scratch.opaque()}; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 68d67c40c5..d7d1f91fba 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // // Note that "output" here doesn't refer to the output from running this // thunk, but rather to the "output" of a hypothetical forward convolution // that corresponds to this input+filter+output triple. That is, the result // generated by this thunk is "output" for forward convs, "input" for // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. 
- ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, - int64 feature_group_count, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + BufferAllocation::Slice input_slice, + BufferAllocation::Slice filter_slice, + BufferAllocation::Slice output_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + input_buffer_(std::move(input_slice)), + filter_buffer_(std::move(filter_slice)), + output_buffer_(std::move(output_slice)), + scratch_buffer_(std::move(scratch_slice)), + tuple_result_buffer_(std::move(tuple_result_slice)) {} ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 feature_group_count_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice 
filter_buffer_; + BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 20d523abe0..22f43bc08b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params) { + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); + const auto& target = custom_call->custom_call_target(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + + params->window = &custom_call->window(); + params->dnums = &custom_call->convolution_dimension_numbers(); + params->feature_group_count = custom_call->feature_group_count(); + params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + + if (target == kCudnnConvForwardCallTarget) { + params->kind = CudnnConvKind::kForward; + params->input_shape = &lhs_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &conv_result_shape; + } else if (target == kCudnnConvBackwardInputCallTarget) { + 
params->kind = CudnnConvKind::kBackwardInput; + params->input_shape = &conv_result_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &lhs_shape; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params->kind = CudnnConvKind::kBackwardFilter; + params->input_shape = &lhs_shape; + params->filter_shape = &conv_result_shape; + params->output_shape = &rhs_shape; + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59c65fc268..09c455cc1e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,9 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Populates params using conv, which must be a custom-call to a cudnn +// convolution. Does not modify any buffers in the params. 
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f91cc00d71..b669881026 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); - std::unique_ptr thunk; + BufferAllocation::Slice input_slice, filter_slice, output_slice; + if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kForward, - 
/*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = rhs_slice; + output_slice = conv_result_slice; } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = conv_result_slice; + filter_slice = rhs_slice; + output_slice = lhs_slice; } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = conv_result_slice; + output_slice = rhs_slice; } else { LOG(FATAL) << "Unexpected custom call target: " << 
custom_call->custom_call_target(); } - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique( + Cast(custom_call), input_slice, filter_slice, + output_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } -- cgit v1.2.3 From 497715e0a9bbb3c844a1902e319778cc30819f77 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 10 Sep 2018 18:25:37 -0700 Subject: [XLA:GPU] Don't canonicalize forward convs with constant filters to backwards conv. No functional change. PiperOrigin-RevId: 212373345 --- .../compiler/xla/service/algebraic_simplifier.cc | 302 ++++++++++++--------- 1 file changed, 167 insertions(+), 135 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 2a0823aeca..c88a3a3b4b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr FoldConvInputPad(HloInstruction* convolution); + StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. 
bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -2212,169 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()))), - {})); - } - + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr { - if (lhs->opcode() != HloOpcode::kPad) { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. 
- if (!IsAll(lhs->operand(1), 0)) { + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { return false; } - const auto& padding = lhs->padding_config(); - - // Can't pad batch or feature dims. - for (int64 dim : - {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); } + } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. 
- Window new_window = window; - for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); - // Edge padding composes with itself in the straightforward way, but - // composing interior padding is nontrivial, and we cowardly refuse to - // think about it. If we see interior padding in either the kPad or conv, - // bail if there's any sort of padding in the other. - if (p.interior_padding() != 0 && - (w.padding_low() != 0 || w.padding_high() != 0 || - w.base_dilation() != 1)) { - return false; - } - if (w.base_dilation() != 1 && - (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0)) { - return false; - } + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - w.set_padding_low(w.padding_low() + p.edge_padding_low()); - w.set_padding_high(w.padding_high() + p.edge_padding_high()); - if (p.interior_padding() != 0) { - CHECK_EQ(w.base_dilation(), 1); - w.set_base_dilation(1 + p.interior_padding()); - } - } +StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs->mutable_operand(0), rhs}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } - if (folded_input_pad) { - return Status::OK(); + // Convolution's padding is always zero, so bail if the kPad is 
adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; } - // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr { - if (rhs->opcode() != HloOpcode::kPad) { - return false; - } + const auto& padding = rhs->padding_config(); - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(rhs->operand(1), 0)) { + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - const auto& padding = rhs->padding_config(); + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - // Can't pad or dilate feature dims. - for (int64 dim : {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. 
- Window new_window = convolution->window(); - for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - - // We can only do this transformation if p adds dilation to the filter -- - // edge padding on the filter is not supported in conv. - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { - return false; - } - - // Nothing to do if the kPad for this dim is entirely a nop. - if (p.interior_padding() == 0) { - continue; - } + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } - // We cowardly refuse to think about how dilation composes with itself; - // bail if both the kPad and conv have dilation on this dimension. - if (w.window_dilation() > 1) { - return false; - } - CHECK_EQ(w.window_dilation(), 1); - w.set_window_dilation(1 + p.interior_padding()); - w.set_size(rhs->operand(0)->shape().dimensions( - dnums.kernel_spatial_dimensions(dim))); + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. 
+ if (w.window_dilation() > 1) { + return false; } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs, rhs->mutable_operand(0)}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - if (folded_filter_pad) { - return Status::OK(); - } +StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } const Shape& input_shape = lhs->shape(); @@ -2387,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. 
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2398,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2424,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2466,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2478,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, convolution->precision_config())); - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. 
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( -- cgit v1.2.3 From e6830cdb06efe6f4cea2e4f30aa98f66ee1b305a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 19:50:28 -0700 Subject: Resolving a bug where regex pattern for errors was not matching in case the error message had multiple newline characters. 
PiperOrigin-RevId: 212381070 --- tensorflow/contrib/data/python/kernel_tests/test_utils.py | 7 +++++-- tensorflow/python/framework/error_interpolation.py | 2 +- tensorflow/python/framework/error_interpolation_test.py | 7 ++++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1def07179a..4c3353fe40 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -52,8 +52,11 @@ class DatasetTestBase(test.TestCase): dataset2, exception_class, replacements=None): - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() + # We are defining next1 and next2 in the same line so that we get identical + # file:line_number in the error messages + # pylint: disable=line-too-long + next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() + # pylint: enable=line-too-long with self.cached_session() as sess: try: sess.run(next1) diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index 46bda2e621..bc3c81b2a2 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -34,7 +34,7 @@ from tensorflow.python.util import tf_stack _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" 
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) _INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) -_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) +_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL) _ParseTag = collections.namedtuple("_ParseTag", ["type", "name"]) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index d312b825d2..1b77548592 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -184,9 +184,14 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): interpolated_string = error_interpolation.interpolate( two_tags_with_seps, self.graph) expected_regex = ( - r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$") + r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$") self.assertRegexpMatches(interpolated_string, expected_regex) + def testNewLine(self): + newline = "\n\n{{node One}}" + interpolated_string = error_interpolation.interpolate(newline, self.graph) + self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*") + class InterpolateDeviceSummaryTest(test.TestCase): -- cgit v1.2.3 From 0b176e9e45d391b2e6da5199fc6c5e8000a772a4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 20:39:11 -0700 Subject: Give a warning about partitioned variable on TPU and set it to None, instead of erring out. PiperOrigin-RevId: 212385555 --- tensorflow/contrib/tpu/python/tpu/tpu.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index c1f90c3963..0f9f7cd91b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -654,13 +654,16 @@ def split_compile_and_replicate(computation, # variables. 
# Partitioned variables is not supported (b/112311320). def custom_getter(getter, name, *args, **kwargs): + """Variables on TPU have a few restrictions.""" partitioner = kwargs["partitioner"] - if partitioner is None: - return getter(name, *args, **kwargs) - else: - raise ValueError( + if partitioner is not None: + kwargs["partitioner"] = None + logging.warning( "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {}.".format(partitioner)) + "`partitioner` that is {} for variable {}. " + "Setting `partitioner` to `None`." + .format(partitioner, name)) + return getter(name, *args, **kwargs) vscope = variable_scope.get_variable_scope() -- cgit v1.2.3 From 786ebb25ea3cd5d69d04bf63838d8dfbf13e6e37 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Mon, 10 Sep 2018 20:49:36 -0700 Subject: Simplify algorithm picker's internal interface. PiperOrigin-RevId: 212386412 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../gpu/cudnn_convolution_algorithm_picker.cc | 83 ++++++++-------------- .../gpu/cudnn_convolution_algorithm_picker.h | 6 +- 3 files changed, 33 insertions(+), 57 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index aab8d0fdca..64b9683628 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -399,6 +399,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 8fcff84173..c607aea1a8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/mutex.h" @@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // caching would speed up compilation a lot. StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr) { + const HloCustomCallInstruction* instr) { + CudnnConvParams params; + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params)); + + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); CHECK_EQ(input_shape.element_type(), output_shape.element_type()); // TODO(timshen): for now only check fp16. It can be expanded to other types, @@ -220,13 +225,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. 
ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + TF_ASSIGN_OR_RETURN(params.input_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + TF_ASSIGN_OR_RETURN(params.filter_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + TF_ASSIGN_OR_RETURN(params.output_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(output_shape))); @@ -253,32 +258,32 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( static_cast(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); }; - initialize_f16(input_buf); - initialize_f16(filter_buf); - initialize_f16(output_buf); + initialize_f16(params.input_buf); + initialize_f16(params.filter_buf); + initialize_f16(params.output_buf); } else { // Although we don't have evidence this matters, zero out the buffers before // autotuning. It's conceivable that using uninitialized memory as the // inputs might affect performance if e.g. the inputs contain denormals, and // this is easy enough. 
- stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()); + stream.ThenMemZero(&params.input_buf, params.input_buf.size()) + .ThenMemZero(&params.filter_buf, params.filter_buf.size()) + .ThenMemZero(&params.output_buf, params.output_buf.size()); } DeviceMemoryBase* result_buf = [&] { - switch (kind) { + switch (params.kind) { case CudnnConvKind::kBackwardFilter: - return &filter_buf; + return &params.filter_buf; case CudnnConvKind::kBackwardInput: - return &input_buf; + return &params.input_buf; case CudnnConvKind::kForward: - return &output_buf; + return &params.output_buf; } }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); + input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; @@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. 
optional first_algorithm; for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution({kind, &input_shape, &filter_shape, &output_shape, - input_buf, filter_buf, output_buf, &window, &dnums, - feature_group_count, AlgorithmConfig(alg)}, - &scratch_allocator, &stream, &profile_result) - .ok(); + params.algorithm = AlgorithmConfig(alg); + bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, + &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { const bool crash_on_checking_failure = @@ -374,34 +377,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - StatusOr> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = - PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr->feature_group_count(), - instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - 
alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } + StatusOr> alg_scratch_and_tc = + PickBestAlgorithm(Cast(instr)); if (!alg_scratch_and_tc.ok()) { LOG(ERROR) << alg_scratch_and_tc.status(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 0cb01161b0..f79b113f8f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); StatusOr> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr); + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null -- cgit v1.2.3 From 34ef46ca948440fa034c7b29cf1a516750eb02d3 
Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 21:38:54 -0700 Subject: internal change only. PiperOrigin-RevId: 212390532 --- tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0345a2a5f8..d52f4e5a61 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -123,6 +123,10 @@ class NodeFilter { // We arbitrarily set this as the boundary between "large" and "small" // instructions. bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::IsOpaque(instr->shape()) || + ShapeUtil::IsToken(instr->shape())) { + return true; + } return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; } -- cgit v1.2.3 From 45965cfd8b54fb113275ffdaced5366e28aa3553 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 00:50:04 -0700 Subject: Graph optimization pass that creates XlaLaunch ops for the computations that have been explicitly marked to be compiled via xla.compile() PiperOrigin-RevId: 212407112 --- tensorflow/compiler/jit/BUILD | 6 + .../compiler/jit/encapsulate_subgraphs_pass.cc | 17 + .../compiler/jit/encapsulate_subgraphs_pass.h | 6 + .../jit/encapsulate_xla_computations_pass.cc | 360 +++++++++++++++++++++ .../jit/encapsulate_xla_computations_pass.h | 61 ++++ .../jit/encapsulate_xla_computations_pass_test.cc | 346 ++++++++++++++++++++ .../jit/jit_compilation_pass_registration.cc | 7 + tensorflow/compiler/jit/ops/xla_ops.cc | 19 ++ tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/cc/BUILD | 4 +- tensorflow/compiler/tf2xla/test_util.cc | 8 + tensorflow/compiler/tf2xla/test_util.h | 16 + .../core/common_runtime/graph_execution_state.cc | 4 + .../core/grappler/optimizers/meta_optimizer.cc | 23 ++ 14 files changed, 877 insertions(+), 1 deletion(-) create mode 100644 
tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc create mode 100644 tensorflow/compiler/jit/encapsulate_xla_computations_pass.h create mode 100644 tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a989f15a1c..352f63bc98 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -362,6 +362,7 @@ cc_library( "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -370,6 +371,7 @@ cc_library( "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -396,6 +398,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -474,6 +477,7 @@ tf_cc_test( size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], @@ -489,7 +493,9 @@ tf_cc_test( "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index ae7a22f451..e0632ff7e4 100644 --- 
a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546f..90354a801a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structually equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. 
+void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000..97ef8cd3cb --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,360 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. 
+void AddControlOutputs(const Node& node, gtl::FlatSet* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector args; + std::vector retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. 
+ StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the argument indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. DeterministicSerialization is a + // best-effort deterministic serialization. 
+ string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. 
+ for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector> data_inputs(num_inputs); + gtl::FlatSet control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. 
+ const int num_outputs = launch->output_types().size(); + gtl::FlatSet control_outputs; + std::vector>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + def.set_op("XlaLaunch"); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. 
+ control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000..c8bb4dc114 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up a XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. 
+ static Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000..f643fb0cfe --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,346 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + 
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. 
+static std::unique_ptr MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // (cluster name) generated by TPU encapsulate pass. 
+ auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. 
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); + auto out1 = 
ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map index = BuildNodeIndex(*graph); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. Encapsulation should be deterministic to avoid recompilation. 
+ TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0"), std::initializer_list{}, + std::initializer_list{a, b, c, d}, + std::initializer_list{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); + auto consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto 
consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc..315fcb2fa7 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -15,12 +15,19 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + +// The following POST_REWRITE passes support auto-clustering to enable XLA. REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ff..1a29c3caab 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. + .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ab289a2b6c..74b131e07e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -594,6 +594,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14..8ac5eb5df9 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -31,7 +31,9 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", out_ops_file = "ops/xla_jit_op", - deps = ["//tensorflow/compiler/jit/ops:xla_ops"], + deps = [ + "//tensorflow/compiler/jit/ops:xla_ops", + ], ) cc_library( diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6..f31bfb45a2 100644 --- 
a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map BuildNodeIndex(const Graph& graph) { + std::unordered_map index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed..350a868568 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. +std::unordered_map BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. 
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 7f260b3139..4475fa979e 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -561,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph( grappler::GrapplerItem item; item.id = "tf_graph"; graph_->ToGraphDef(&item.graph); + // TODO(b/114748242): Add a unit test to test this bug fix. + if (flib_def_) { + *item.graph.mutable_library() = flib_def_->ToProto(); + } item.fetch.insert(item.fetch.end(), options.callable_options.fetch().begin(), diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index a5fd33d28b..b75d6303b4 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -72,6 +72,16 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer"; } +// Check if the graphdef contains nodes that indicate TPU execution. +bool IsTPUGraphDef(const GraphDef& def) { + for (auto node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} + } // namespace #define MK_OPT(NAME, VALUE) \ @@ -336,6 +346,19 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // 1. Optimize main graph TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); + // Skip optimizing functions if this is a TPU graph. 
Currently, Grappler + // passes do not handle TPU functions correctly in a variety of ways (Note + // that due to the pre-placement TPU graph rewriting passes, the TPU-related + // ops are encapsulated away into functions). For example, TPU graphs contain + // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler + // passes could prune that away. Grappler passes could also cause issues + // around shape inference. Since the desired and existing behavior is to not + // optimize TPU functions with Grappler, this check preserves that. + if (IsTPUGraphDef(*optimized_graph)) { + VLOG(2) << "Skipping optimizing funcs for TPU graphs"; + return Status::OK(); + } + // 2. Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), optimized_graph->library()); -- cgit v1.2.3 From e18f84a394bcbde62b344a3b32e8d8fd248fea58 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 02:01:18 -0700 Subject: compat: Update forward compatibility horizon to 2018-09-11 PiperOrigin-RevId: 212414205 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index af58a6f841..60ebae19ab 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 10) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 11) @tf_export("compat.forward_compatible") -- cgit v1.2.3 From 9fd56039064871a736bb7cff398b2a8e08454bee Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 11 Sep 2018 05:34:31 -0700 Subject: Fix a typo in cudnn_convolution_rewriter. 
PiperOrigin-RevId: 212436340 --- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 3d1266355b..228379a248 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -263,7 +263,7 @@ MatchBackwardInput(HloInstruction* conv) { !(window_util::HasBaseDilation(conv->window()) && (reverse_filter->IsConstant() || is_1x1_filter))) { VLOG(1) << "Can't match to backwards convolution. Either filter is not " - "kReverse, or it's not a base-dialted conv with a 1x1 or " + "kReverse, or it's not a base-dilated conv with a 1x1 or " "constant filter."; return no_match_result; } -- cgit v1.2.3 From 87d440506547d5c549261922c268aa55badf0bc4 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 11 Sep 2018 06:09:38 -0700 Subject: Fix 31 ClangTidy - Readability findings in //tensorflow/compiler/xla/. 
* redundant string conversion * using decl 'Eq' is unused * using decl 'HasSubstr' is unused * redundant StrCat calls * please use StrAppend instead of StrCat when appending to an existing string (4 times) * parameters of type 'absl::Span<...>' should be taken by value (23 times) PiperOrigin-RevId: 212439742 --- tensorflow/compiler/xla/client/xla_builder.cc | 2 +- tensorflow/compiler/xla/reference_util.cc | 47 +++++++++----------- tensorflow/compiler/xla/reference_util.h | 50 ++++++++++------------ .../xla/service/gpu/while_transformer_test.cc | 3 -- .../compiler/xla/service/hlo_graph_dumper.cc | 5 +-- .../compiler/xla/tests/reduce_window_test.cc | 8 ++-- 6 files changed, 49 insertions(+), 66 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 4e1ff9e5c0..8951e93ee6 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2419,7 +2419,7 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode())); + instr.set_name(instr.opcode()); } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 05325367f5..ceb5e74db7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{static_cast(operand.size())}; 
std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, - float init, - const absl::Span& window, - const absl::Span& stride, +ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, + absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; @@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -332,8 
+327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + absl::Span window, absl::Span stride, + Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, + absl::Span window, + absl::Span stride, bool same_padding) { Padding padding = same_padding ? 
Padding::kSame : Padding::kValid; auto result = absl::make_unique>(operand.n1(), operand.n2(), diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 9ce098029d..8654fbb9b5 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -177,47 +177,41 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const absl::Span& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span operand, float init, + absl::Span window, absl::Span stride, + Padding padding); static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); // Windowed reductions with a generic reduce function. 
static std::unique_ptr> ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span window, absl::Span stride, + Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -230,8 +224,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, bool same_padding); + absl::Span window, absl::Span stride, + bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. 
if concatenate_dimension is 0, the "n1"/height dimension is @@ -332,8 +326,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D(const absl::Span& input, - int64 start, int64 size) { + static std::vector ClampSlice1D(absl::Span input, int64 start, + int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 40183de96e..9a61f8ac5a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -26,9 +26,6 @@ limitations under the License. namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d52f4e5a61..4826bff19e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -469,9 +469,8 @@ stylesheet=< string graph_label = StrCat(label_, "
Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index d5de9650f1..63491a90bf 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -588,7 +588,7 @@ string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -980,7 +980,7 @@ string R3ReduceWindowTestDataToString( param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1121,7 +1121,7 @@ string R2ReduceWindowTestDataToString( param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1322,7 +1322,7 @@ string R1ReduceWindowTestDataToString( "__pad_high_", absl::StrJoin(param.pad_high, "x"), "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } -- cgit v1.2.3 From de5ddd51e32c4630e63c0cb3e960c69f9ac77662 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 11 Sep 2018 09:10:11 -0700 Subject: Add more description for a common use case of SequenceExample. PiperOrigin-RevId: 212462406 --- tensorflow/core/example/example.proto | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto index e7142a4ef9..e36e51d8d5 100644 --- a/tensorflow/core/example/example.proto +++ b/tensorflow/core/example/example.proto @@ -199,7 +199,13 @@ message Example { // to determine if all features within the FeatureList must // have the same size. The same holds for this FeatureList across multiple // examples. -// +// - For sequence modeling, e.g.: +// http://colah.github.io/posts/2015-08-Understanding-LSTMs/ +// https://github.com/tensorflow/nmt +// the feature lists represent a sequence of frames. +// In this scenario, all FeatureLists in a SequenceExample have the same +// number of Feature messages, so that the ith element in each FeatureList +// is part of the ith frame (or time step). // Examples of conformant and non-conformant examples' FeatureLists: // // Conformant FeatureLists: -- cgit v1.2.3 From 847b38406a28546991b62193278ee87910cd3d74 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 11 Sep 2018 09:31:42 -0700 Subject: TFTS: Fix an input statistics race condition The fix is straightforward enough, although the triggering circumstances are still a bit mysterious. The unit test did fail with ubsan prior to this CL, so I'm going to leave it at that for now. 
PiperOrigin-RevId: 212465732 --- .../contrib/timeseries/python/timeseries/estimators_test.py | 9 +++++++++ tensorflow/contrib/timeseries/python/timeseries/math_utils.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index 461fe22210..83260fc59a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -216,6 +216,15 @@ class TimeSeriesRegressorTest(test.TestCase): exogenous_feature_columns=exogenous_feature_columns) self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype) + def test_structural_ensemble_numpy_input(self): + numpy_data = {"times": numpy.arange(50), + "values": numpy.random.normal(size=[50])} + estimators.StructuralEnsembleRegressor( + num_features=1, periodicities=[], model_dir=self.get_temp_dir(), + config=_SeedRunConfig()).train( + input_pipeline.WholeDatasetInputFn( + input_pipeline.NumpyReader(numpy_data)), + steps=1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 9b593fecbb..03da2b82e5 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object): statistics.total_observation_count, math_ops.cast( gen_math_ops.round( - math_ops.cast(auxiliary_variables.max_time_seen - - statistics.start_time + 1, self._dtype) / + math_ops.cast(max_time_seen_assign - + start_time_update + 1, self._dtype) / inter_observation_duration_estimate), dtypes.int64)) per_chunk_stat_updates = control_flow_ops.group( overall_feature_mean_update, overall_feature_var_update, -- cgit v1.2.3 From ac60b46e2c5962fd8099a4406c1788d826ad3c0d Mon 
Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 11 Sep 2018 09:33:04 -0700 Subject: Automated rollback of commit 45965cfd8b54fb113275ffdaced5366e28aa3553 PiperOrigin-RevId: 212465918 --- tensorflow/compiler/jit/BUILD | 6 - .../compiler/jit/encapsulate_subgraphs_pass.cc | 17 - .../compiler/jit/encapsulate_subgraphs_pass.h | 6 - .../jit/encapsulate_xla_computations_pass.cc | 360 --------------------- .../jit/encapsulate_xla_computations_pass.h | 61 ---- .../jit/encapsulate_xla_computations_pass_test.cc | 346 -------------------- .../jit/jit_compilation_pass_registration.cc | 7 - tensorflow/compiler/jit/ops/xla_ops.cc | 19 -- tensorflow/compiler/tf2xla/BUILD | 1 - tensorflow/compiler/tf2xla/cc/BUILD | 4 +- tensorflow/compiler/tf2xla/test_util.cc | 8 - tensorflow/compiler/tf2xla/test_util.h | 16 - .../core/common_runtime/graph_execution_state.cc | 4 - .../core/grappler/optimizers/meta_optimizer.cc | 23 -- 14 files changed, 1 insertion(+), 877 deletions(-) delete mode 100644 tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc delete mode 100644 tensorflow/compiler/jit/encapsulate_xla_computations_pass.h delete mode 100644 tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 352f63bc98..a989f15a1c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -362,7 +362,6 @@ cc_library( "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", - "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -371,7 +370,6 @@ cc_library( "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", - "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -398,7 +396,6 @@ cc_library( "//tensorflow/core:protos_all_cc", 
"//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -477,7 +474,6 @@ tf_cc_test( size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", - "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], @@ -493,9 +489,7 @@ tf_cc_test( "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index e0632ff7e4..ae7a22f451 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -59,22 +58,6 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; -void SortControlInputs(GraphDef* gdef) { - int64 num_nodes = gdef->node_size(); - for (int64 i = 0; i < num_nodes; ++i) { - NodeDef* node = gdef->mutable_node(i); - // Stable sort control inputs and leave the order of data inputs unchanged. 
- std::stable_sort(node->mutable_input()->begin(), - node->mutable_input()->end(), - [](const string& a, const string& b) { - bool a_is_control = absl::StartsWith(a, "^"); - bool b_is_control = absl::StartsWith(b, "^"); - return (!a_is_control && b_is_control) || - (a_is_control && b_is_control && a < b); - }); - } -} - namespace { bool AreAllParentsGuaranteedConst( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 90354a801a..926589546f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,12 +102,6 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; -// Sorts each node's control inputs by their names. This guarantees that for two -// structually equivalent GraphDefs, we get the same traversal ordering on -// node's control input fields. -// TODO(hpucha): Move the utilities to a more appropriate place. -void SortControlInputs(GraphDef* gdef); - class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc deleted file mode 100644 index 97ef8cd3cb..0000000000 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ /dev/null @@ -1,360 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/fingerprint.h" - -namespace tensorflow { - -const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = - "_xla_compile_id"; - -namespace { - -const char* const kXlaClusterOutput = "XlaClusterOutput"; - -// Checks if a graph node is marked to be a guaranteed constant. -bool is_guaranteed_constant(const Node& n) { - bool guaranteed_constant = false; - if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) - .ok()) { - return false; - } - return guaranteed_constant; -} - -// Finds the `index` of an _Arg or _Retval node. -Status GetIndexAttr(const Node& n, int num_args, int* index) { - TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); - if (*index < 0 || *index >= num_args) { - return errors::InvalidArgument("Invalid ", n.type_string(), " number ", - *index); - } - return Status::OK(); -} - -// Returns the data type of the destination of an edge. 
-DataType EdgeType(const Edge* edge) { - return edge->dst()->input_type(edge->dst_input()); -} - -// Adds the control inputs of `node` to `*deps`. -void AddControlInputs(const Node& node, gtl::FlatSet* deps) { - for (const Edge* edge : node.in_edges()) { - if (edge->IsControlEdge()) { - deps->insert(edge->src()); - } - } -} - -// Adds the control outputs of `node` to `*deps`. -void AddControlOutputs(const Node& node, gtl::FlatSet* deps) { - for (const Edge* edge : node.out_edges()) { - if (edge->IsControlEdge()) { - deps->insert(edge->dst()); - } - } -} - -// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts -// the arguments into the order expected by XlaLaunch computations: -// 1) arguments -// 2) resource variable arguments -// See the documentation of EncapsulateSubgraphsInFunctions for the meaning -// of the arguments. -// -// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. -Status RewriteSubgraph(const std::vector& arg_source_tensors, - std::unique_ptr* graph_ptr, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* call_def) { - Graph* graph = graph_ptr->get(); - const int num_args = input_permutation->size(); - const int num_retvals = output_permutation->size(); - - std::vector args; - std::vector retvals; - args.reserve(num_args); - retvals.reserve(num_retvals); - for (Node* n : graph->nodes()) { - if (n->type_string() == "_Arg") { - // Check if this is a guaranteed constant. - if (is_guaranteed_constant(*n)) { - return errors::InvalidArgument( - "Guaranteed constants are not supported (", n->name(), ")"); - } - args.push_back(n); - } else if (n->type_string() == "_Retval") { - retvals.push_back(n); - } - } - - if (std::find(args.begin(), args.end(), nullptr) != args.end()) { - return errors::InvalidArgument("Missing or non-consecutive arguments"); - } - - // Reorders the arguments. 
- std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { - // Non-resources appear before resources - bool a_is_resource = (a->output_type(0) == DT_RESOURCE); - bool b_is_resource = (b->output_type(0) == DT_RESOURCE); - // Uses the name as a tiebreaker so the output is deterministic. - StringPiece a_name(a->name()); - StringPiece b_name(b->name()); - return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); - }); - - // Sorts the retvals by name so the order is deterministic. - std::sort(retvals.begin(), retvals.end(), - [](Node* a, Node* b) { return a->name() < b->name(); }); - - // Computes the permutation to produce the correct argument order, and update - // the argument indices. - int variable_start_index = num_args; - for (int i = 0; i < num_args; ++i) { - int index; - TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); - if (args[i]->output_type(0) == DT_RESOURCE && - variable_start_index == num_args) { - variable_start_index = i; - } - (*input_permutation)[index] = i; - args[i]->AddAttr("index", i); - } - VLOG(4) << "variable_start_index: " << variable_start_index; - - // Computes the permutation to produce the correct retval order, and update - // the argument indices. - for (int i = 0; i < num_retvals; ++i) { - int index; - TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); - (*output_permutation)[index] = i; - retvals[i]->AddAttr("index", i); - } - - AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), - call_def); - AddNodeAttr("_variable_start_index", variable_start_index, call_def); - - // Uniquify the function name. - GraphDef gdef; - graph->ToGraphDef(&gdef); - - // Before serialization, sort each node's control inputs to achieve - // determinism. Sorting control inputs could help (but not necessarily) create - // a deterministic serialization and fingerprint. Other sources of - // nondeterminism include unstable node ordering. 
- SortControlInputs(&gdef); - // Fingerprint the function. - // Nondeterminism in serialization would not lead to incorrect results, but - // may cause spurious cache misses. DeterministicSerialization is a - // best-effort deterministic serialization. - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; - call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); - return Status::OK(); -} - -} // namespace - -/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { - // Check for undeclared outputs before Encapsulation, so we can give a better - // error message. - // TODO(phawkins): merge this with the encapsulation code to avoid the extra - // O(n) pass over the edges. - for (const Edge* e : (*graph)->edges()) { - if (!e->IsControlEdge() && - e->src()->attrs().Find(kXlaClusterAttr) != nullptr && - e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && - e->dst()->type_string() != kXlaClusterOutput) { - return errors::InvalidArgument( - "Undeclared output of XLA computation. A common cause of this error " - "is variable initializers that depend on the XLA computation. Edge: ", - e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", - e->dst_input()); - } - } - - auto output = absl::make_unique((*graph)->op_registry()); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, "", **graph, RewriteSubgraph, - /*reuse_existing_functions=*/true, &output, flib_def), - "EncapsulateXlaComputationsPass failed"); - graph->swap(output); - return Status::OK(); -} - -/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( - Graph* graph) { - // Finds all of the XlaLaunch function calls, to avoid mutating the graph - // while iterating. 
- std::vector launch_nodes; - for (Node* n : graph->nodes()) { - string name; - if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { - launch_nodes.push_back(n); - } - } - - // Replaces each launch function call together with its neighboring - // XlaClusterOutput nodes with a XlaLaunch node. - for (Node* launch : launch_nodes) { - int variable_start_index; - TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", - &variable_start_index)); - - std::vector in_edges; - TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); - - const int num_inputs = in_edges.size(); - const int num_variables = num_inputs - variable_start_index; - const int num_args = variable_start_index; - - VLOG(4) << "Launch node '" << launch->name() << "'" - << " input edges: " << in_edges.size() << " num_args: " << num_args - << " num_variables: " << num_variables; - - std::vector nodes_to_remove = {launch}; - - // Data and control inputs to the new XlaLaunch node. - std::vector> data_inputs(num_inputs); - gtl::FlatSet control_inputs; - DataTypeVector arg_types(num_args); - - AddControlInputs(*launch, &control_inputs); - - for (int i = 0; i < num_args; ++i) { - const Edge* edge = in_edges[i]; - data_inputs[i] = {edge->src(), edge->src_output()}; - arg_types[i] = EdgeType(edge); - } - - // Appends the variable inputs. - for (int i = 0; i < num_variables; ++i) { - int pos = variable_start_index + i; - const Edge* edge = in_edges[pos]; - data_inputs[pos] = {edge->src(), edge->src_output()}; - } - - // Outputs. 
- const int num_outputs = launch->output_types().size(); - gtl::FlatSet control_outputs; - std::vector>> data_outputs(num_outputs); - DataTypeVector output_types(num_outputs); - - for (const Edge* le : launch->out_edges()) { - if (le->IsControlEdge()) { - control_outputs.insert(le->dst()); - } else { - TF_RET_CHECK(le->src_output() < num_outputs); - Node* output_node = le->dst(); - - TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) - << le->DebugString(); - nodes_to_remove.push_back(output_node); - - for (const Edge* oe : output_node->out_edges()) { - TF_RET_CHECK(!oe->IsControlEdge()); - data_outputs[le->src_output()].push_back( - {oe->dst(), oe->dst_input()}); - } - output_types[le->src_output()] = output_node->input_type(0); - - AddControlOutputs(*output_node, &control_outputs); - } - } - - NodeDef def; - def.set_name(launch->name()); - - // Target the XLA CPU/GPU backends. - VLOG(2) << "Replacing with XlaLaunch"; - def.set_op("XlaLaunch"); - AddNodeAttr("Tconstants", DataTypeVector{}, &def); - AddNodeAttr("Targs", arg_types, &def); - AddNodeAttr("Nresources", num_variables, &def); - AddNodeAttr("Tresults", output_types, &def); - NameAttrList function; - function.set_name(launch->type_string()); - AddNodeAttr("function", function, &def); - - for (Node* node : nodes_to_remove) { - VLOG(2) << "Deleting node " << node->DebugString(); - // Ensure that we do not attempt to add control edges to nodes that are - // deleted. 
- control_inputs.erase(node); - control_outputs.erase(node); - graph->RemoveNode(node); - } - - Status status; - Node* xla_launch = graph->AddNode(def, &status); - if (!status.ok()) { - return status; - } - for (int i = 0; i < data_inputs.size(); ++i) { - graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, - i); - } - for (Node* n : control_inputs) { - graph->AddControlEdge(n, xla_launch); - } - for (int i = 0; i < data_outputs.size(); ++i) { - for (const auto& successor : data_outputs[i]) { - graph->AddEdge(xla_launch, i, successor.first, successor.second); - } - } - for (Node* n : control_outputs) { - graph->AddControlEdge(xla_launch, n); - } - } - return Status::OK(); -} - -Status EncapsulateXlaComputationsPass::Run( - const GraphOptimizationPassOptions& options) { - VLOG(1) << "EncapsulateXlaComputations(): " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", - **options.graph, options.flib_def); - - TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); - VLOG(1) << "EncapsulateXlaComputations() half-way: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", - **options.graph, options.flib_def); - - TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); - VLOG(1) << "EncapsulateXlaComputations() finished: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", - **options.graph, options.flib_def); - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h deleted file mode 100644 index c8bb4dc114..0000000000 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Rewrites computations generated by the xla.compile() Python code into -// XlaLaunch nodes. -// -// xla.compile() does two main things: -// a) marks operators that make up a XLA computation with the attribute -// _xla_compile_id=XYZ, where XYZ is a unique key. -// b) adds XlaClusterOutput nodes to represent outputs of the computation. -// These nodes are not marked with the _xla_compile_id attribute. - -#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ -#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ - -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/platform/env.h" - -namespace tensorflow { - -// Encapsulates nodes marked with the _xla_compile_id attribute into -// XlaLaunch operators. -class EncapsulateXlaComputationsPass : public GraphOptimizationPass { - public: - static const char* const kXlaClusterAttr; // _xla_compile_id - - Status Run(const GraphOptimizationPassOptions& options) override; - - // The following methods are public only for unit tests. - - // This pass has two stages: - // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes - // marked with the same _xla_compile_id attribute into functions. These - // functions contain the computations to be passed to XlaLaunch. During - // encapsulation, we sort the arguments into the order expected by - // XlaLaunch. 
- static Status Encapsulate(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def); - - // b) we rewrite the function calls generated in phase (a) into XlaLaunch - // operators. We also convert the XlaClusterOutput output nodes of the - // function call into the outputs of the XlaLaunch operator. - static Status BuildXlaLaunchOps(Graph* graph); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc deleted file mode 100644 index f643fb0cfe..0000000000 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ /dev/null @@ -1,346 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" - -#include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" -#include "tensorflow/compiler/tf2xla/test_util.h" -#include "tensorflow/core/framework/graph_to_functiondef.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/equal_graph_def.h" -#include "tensorflow/core/util/ptr_util.h" - -namespace tensorflow { - -static std::unique_ptr MakeOuterGraph( - const FunctionLibraryDefinition& flib_def, const string& function) { - Scope scope = Scope::NewRootScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - NodeDef def; - TF_CHECK_OK( - NodeDefBuilder("launch0", function, &flib_def) - .Input(a.node()->name(), 0, DT_INT32) - .Input(b.node()->name(), 0, DT_FLOAT) - .Input(c.node()->name(), 0, DT_INT32) - .Input(d.node()->name(), 0, DT_FLOAT) - .Input(u.node()->name(), 0, DT_RESOURCE) - .Input(v.node()->name(), 0, DT_RESOURCE) - .Input(w.node()->name(), 0, DT_RESOURCE) - 
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") - .Attr("_variable_start_index", 4) - .Finalize(&def)); - - Status status; - Node* launch = scope.graph()->AddNode(def, &status); - TF_CHECK_OK(status); - TF_CHECK_OK(scope.DoShapeInference(launch)); - scope.graph()->AddEdge(a.node(), 0, launch, 0); - scope.graph()->AddEdge(b.node(), 0, launch, 1); - scope.graph()->AddEdge(c.node(), 0, launch, 2); - scope.graph()->AddEdge(d.node(), 0, launch, 3); - scope.graph()->AddEdge(u.node(), 0, launch, 4); - scope.graph()->AddEdge(v.node(), 0, launch, 5); - scope.graph()->AddEdge(w.node(), 0, launch, 6); - - auto out0 = - ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); - auto out1 = - ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); - auto out2 = - ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); - auto out3 = - ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); - - auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); - auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); - auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); - auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); - auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); - auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -// Makes an encapsulate body graph for use in tests. 
-static std::unique_ptr MakeBodyGraph() { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); - auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); - - auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); - auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); - auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); - - auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); - }; - - auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - - auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); - add_attrs(read_u.node()); - auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); - add_attrs(read_v.node()); - auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); - add_attrs(read_w.node()); - - auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); - add_attrs(e.node()); - auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); - add_attrs(f.node()); - auto g = ops::Add(scope.WithOpName("G"), f, arg3); - add_attrs(g.node()); - - auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), - b_identity, 0); - auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); - auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); - auto out3 = - ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { - // Test that control edge insertion order doesn't affect the cache key - // (cluster name) generated by TPU encapsulate pass. 
- auto get_serialized_graph = [](bool control_input_reversed, - bool operand_reversed) -> string { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - std::unique_ptr graph(new Graph(&flib_def)); - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); - auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); - - ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) - : ops::Add(scope.WithOpName("E"), a1, a0); - - auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, - "launch0"); - }; - add_attrs(e.node()); - - TF_CHECK_OK(scope.ToGraph(graph.get())); - auto get_node_in_graph = [&graph](Node* node) { - return graph->FindNodeId(node->id()); - }; - // Insert control edge in different order. The order should not affect - // the encapsulated or serialized graph. - if (!control_input_reversed) { - graph->AddControlEdge(get_node_in_graph(a0.node()), - get_node_in_graph(e.node()), true); - graph->AddControlEdge(get_node_in_graph(a1.node()), - get_node_in_graph(e.node()), true); - } else { - graph->AddControlEdge(get_node_in_graph(a1.node()), - get_node_in_graph(e.node()), true); - graph->AddControlEdge(get_node_in_graph(a0.node()), - get_node_in_graph(e.node()), true); - } - } - TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - GraphDef gdef; - graph->ToGraphDef(&gdef); - // Before serialization, sort control inputs first to remove - // nondeterminism. - SortControlInputs(&gdef); - string serialized; - SerializeToStringDeterministic(gdef, &serialized); - return serialized; - }; - - // Changing the order of control input shouldn't affect the graph generated. 
- EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, - /*operand_reversed=*/false), - get_serialized_graph(/*control_input_reversed=*/false, - /*operand_reversed=*/false)); - - // Changing the order of data input should affect the graph generated. - EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, - /*operand_reversed=*/true), - get_serialized_graph(/*control_input_reversed=*/false, - /*operand_reversed=*/false)); -} - -TEST(EncapsulateXlaComputations, Encapsulate) { - FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); - std::unique_ptr graph(new Graph(&flib_def)); - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - auto add_attrs = [](Node* node) { - node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); - }; - - auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); - add_attrs(b_identity.node()); - - auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); - add_attrs(read_u.node()); - auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); - add_attrs(read_v.node()); - auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); - add_attrs(read_w.node()); - - auto e = ops::Add(scope.WithOpName("E"), a, c); - add_attrs(e.node()); - auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); - add_attrs(f.node()); - auto g = ops::Add(scope.WithOpName("G"), f, d); - add_attrs(g.node()); - - auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); - auto out1 = 
ops::XlaClusterOutput(scope.WithOpName("Out1"), e); - auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); - auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); - - auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); - auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); - auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); - auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); - auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); - auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); - TF_ASSERT_OK(scope.ToGraph(graph.get())); - } - - std::unique_ptr graph_copy(new Graph(&flib_def)); - CopyGraph(*graph, graph_copy.get()); - - TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - - std::unordered_map index = BuildNodeIndex(*graph); - string function = index.at("launch0")->type_string(); - - // Tests the outer graph is as expected. - { - std::unique_ptr outer = MakeOuterGraph(flib_def, function); - GraphDef expected_def; - outer->ToGraphDef(&expected_def); - - GraphDef actual_def; - graph->ToGraphDef(&actual_def); - TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); - } - - // Tests the encapsulated body graph is as expected. - { - std::unique_ptr body = MakeBodyGraph(); - GraphDef expected_body_def; - body->ToGraphDef(&expected_body_def); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, - DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), - result.ret_types); - TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); - } - - // Encapsulates the same computation again, verifies we reuse the same - // function. Encapsulation should be deterministic to avoid recompilation. 
- TF_ASSERT_OK( - EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); - std::unordered_map index_copy = BuildNodeIndex(*graph_copy); - string function_copy = index_copy.at("launch0")->type_string(); - EXPECT_EQ(function, function_copy); -} - -TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { - std::unique_ptr body_graph = MakeBodyGraph(); - FunctionDefLibrary flib; - TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); - - FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - - std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); - TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); - - Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); - TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); - - auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); - auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); - auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); - auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); - auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); - auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); - auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); - - NameAttrList function; - function.set_name("launch0"); - auto launch = ops::XlaLaunch( - scope.WithOpName("launch0"), std::initializer_list{}, - std::initializer_list{a, b, c, d}, - std::initializer_list{u, v, w}, - DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); - - auto consumer0_a = - ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); - auto consumer0_b = - ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); - auto consumer0_c = - ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); - auto consumer1 = - ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); - auto consumer2 = - ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); - auto 
consumer3 = - ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); - - GraphDef expected_def; - TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); - - GraphDef actual_def; - graph->ToGraphDef(&actual_def); - TF_EXPECT_GRAPH_EQ(expected_def, actual_def); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 315fcb2fa7..c37b6112cc 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -15,19 +15,12 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { -// EncapsulateXlaComputationsPass rewrites computations generated by the -// xla.compile() Python code into XlaLaunch nodes. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, - EncapsulateXlaComputationsPass); - -// The following POST_REWRITE passes support auto-clustering to enable XLA. REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 1a29c3caab..f2473d98ff 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,14 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { -using shape_inference::InferenceContext; - REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -36,19 +32,4 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); -REGISTER_OP("XlaClusterOutput") - .Input("input: T") - // Note: when replication is supported, this op will have N outputs. - .Output("outputs: T") - .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(0)); - } - return Status::OK(); - }) - .Doc( - "Operator that connects the output of an XLA computation to other " - "consumer graph nodes."); - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 74b131e07e..ab289a2b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -594,7 +594,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 8ac5eb5df9..ea8d1b3d14 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -31,9 +31,7 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", out_ops_file = "ops/xla_jit_op", - deps = [ - "//tensorflow/compiler/jit/ops:xla_ops", - ], + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index f31bfb45a2..3c6c9a91b6 100644 --- 
a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,12 +40,4 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } -std::unordered_map BuildNodeIndex(const Graph& graph) { - std::unordered_map index; - for (Node* node : graph.nodes()) { - index[node->name()] = node; - } - return index; -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 350a868568..e6e4ae92ed 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,10 +24,8 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -44,20 +42,6 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); -// Builds a map from node name to Node* for `graph`. -std::unordered_map BuildNodeIndex(const Graph& graph); - } // namespace tensorflow -// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for -// equality. 
-#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ - do { \ - string diff; \ - EqualGraphDefOptions eq_options; \ - eq_options.ignore_internal_attrs = false; \ - EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ - << diff << "\nActual: " << SummarizeGraphDef(actual); \ - } while (false) - #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 4475fa979e..7f260b3139 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -561,10 +561,6 @@ Status GraphExecutionState::OptimizeGraph( grappler::GrapplerItem item; item.id = "tf_graph"; graph_->ToGraphDef(&item.graph); - // TODO(b/114748242): Add a unit test to test this bug fix. - if (flib_def_) { - *item.graph.mutable_library() = flib_def_->ToProto(); - } item.fetch.insert(item.fetch.end(), options.callable_options.fetch().begin(), diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index b75d6303b4..a5fd33d28b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -72,16 +72,6 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer"; } -// Check if the graphdef contains nodes that indicate TPU execution. -bool IsTPUGraphDef(const GraphDef& def) { - for (auto node : def.node()) { - if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") { - return true; - } - } - return false; -} - } // namespace #define MK_OPT(NAME, VALUE) \ @@ -346,19 +336,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // 1. Optimize main graph TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); - // Skip optimizing functions if this is a TPU graph. 
Currently, Grappler - // passes do not handle TPU functions correctly in a variety of ways (Note - // that due to the pre-placement TPU graph rewriting passes, the TPU-related - // ops are encapsulated away into functions). For example, TPU graphs contain - // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler - // passes could prune that away. Grappler passes could also cause issues - // around shape inference. Since the desired and existing behavior is to not - // optimize TPU functions with Grappler, this check preserves that. - if (IsTPUGraphDef(*optimized_graph)) { - VLOG(2) << "Skipping optimizing funcs for TPU graphs"; - return Status::OK(); - } - // 2. Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), optimized_graph->library()); -- cgit v1.2.3 From 624ff13fdf4e54e255d23971ef2beec3c48c3bb2 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 11 Sep 2018 09:35:09 -0700 Subject: PR #21826: merge_repeated option is confusing Please approve this CL. It will be submitted automatically, and its GitHub pull request will be marked as merged. Imported from GitHub PR #21826 I have the same question with [WIP: Remove invalid merge_repeated option from CTC beam decoder](#15586), it's a pity I haven't seen any changes for so long. Generally I will use the default value of merge_repeated: True, but I found it's confusing, that is, I got the wrong anser, it has been explained well in [WIP: Remove invalid merge_repeated option from CTC beam decoder](#15586). And the top path in ctc_beam_search_decoder is similar with sequence in ctc_greedy_decoder, this is confusing, I have found the project [CRNN](https://github.com/Belval/CRNN/blob/master/CRNN/crnn.py)(line 167) and some other projects use the wrong settings. So I think it's better to give a explain here, this has no conflict with the existing code. 
Copybara import of the project: - e357bcea4b10d5e5cbc3a4ba59385e832401ba8d merge_repeated option is confusing by Dao Zhang - a0467d35cc19293fa16918658a7f98e18ead7f87 Merge e357bcea4b10d5e5cbc3a4ba59385e832401ba8d into 34ef4... by Dao Zhang(??) PiperOrigin-RevId: 212466200 --- tensorflow/python/ops/ctc_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 908e793902..32d455bdad 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, If `merge_repeated` is `True`, merge repeated classes in the output beams. This means that if consecutive entries in a beam are the same, - only the first of these is emitted. That is, when the top path - is `A B B B B`, the return value is: + only the first of these is emitted. That is, when the sequence is + `A B B * B * B` (where '*' is the blank label), the return value is: * `A B` if `merge_repeated = True`. - * `A B B B B` if `merge_repeated = False`. + * `A B B B` if `merge_repeated = False`. 
Args: inputs: 3-D `float` `Tensor`, size -- cgit v1.2.3 From 7cfed353d9eb8344d20cd65ecfb5740cff48304c Mon Sep 17 00:00:00 2001 From: Olivia Nordquist Date: Tue, 11 Sep 2018 09:45:29 -0700 Subject: disable tsan for failing test PiperOrigin-RevId: 212467900 --- tensorflow/contrib/saved_model/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index b897224c6d..f687b56ea3 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -123,6 +123,7 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", -- cgit v1.2.3 From b566170b29c41b0da4c23bf5ce0fdfe19b8bcb14 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Tue, 11 Sep 2018 10:35:30 -0700 Subject: Block tsan for keras_test PiperOrigin-RevId: 212477605 --- tensorflow/python/estimator/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 4001ffdd6b..bfcc019dd5 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -685,6 +685,7 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_windows", + "notsan", # b/67510291 ], deps = [ ":keras", -- cgit v1.2.3 From 36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 11 Sep 2018 10:41:44 -0700 Subject: [TF] Variant improvements. 1. Change Variant Decode to accept VariantTensorData (non-ref). This should allow some optimization in the future. In the meantime it means removing the variant.h include from tensor.h, since variant_encode_decode.h now relies on tensor.h and variant.h now relies on that. 
It also means we found a bunch of places where tensor.proto.h, variant.h, and mutex.h were being imported through tensor.h (along with a bunch of other crap); so now we directly import them in order to compile. 2. Move Variant registry to use TypeIndex instead of a TypeName string; this should speed up registry lookups. PiperOrigin-RevId: 212478896 --- tensorflow/c/c_api.cc | 1 + tensorflow/c/c_api_experimental.cc | 1 + tensorflow/c/c_api_function.cc | 1 + .../contrib/lite/toco/import_tensorflow_test.cc | 1 + tensorflow/contrib/nccl/BUILD | 24 +-- tensorflow/contrib/nccl/kernels/nccl_rewrite.cc | 1 + tensorflow/core/BUILD | 1 + tensorflow/core/common_runtime/copy_tensor.cc | 2 +- tensorflow/core/common_runtime/rendezvous_util.cc | 1 + .../common_runtime/single_threaded_cpu_device.h | 1 + tensorflow/core/framework/allocator.cc | 9 + tensorflow/core/framework/allocator.h | 11 +- tensorflow/core/framework/allocator_registry.h | 1 + tensorflow/core/framework/attr_value_util_test.cc | 1 + tensorflow/core/framework/tensor.h | 3 +- tensorflow/core/framework/tensor_test.cc | 1 + tensorflow/core/framework/tensor_util.h | 1 + tensorflow/core/framework/types.h | 3 +- tensorflow/core/framework/variant.cc | 25 +-- tensorflow/core/framework/variant.h | 60 ++---- tensorflow/core/framework/variant_encode_decode.h | 32 +-- tensorflow/core/framework/variant_op_copy_test.cc | 6 +- tensorflow/core/framework/variant_op_registry.cc | 85 ++++---- tensorflow/core/framework/variant_op_registry.h | 216 +++++++++++---------- .../core/framework/variant_op_registry_test.cc | 96 ++++----- tensorflow/core/framework/variant_tensor_data.cc | 22 ++- tensorflow/core/framework/variant_tensor_data.h | 10 +- tensorflow/core/framework/variant_test.cc | 15 +- tensorflow/core/kernels/data/iterator_ops.cc | 4 +- tensorflow/core/kernels/data/optional_ops.cc | 7 +- tensorflow/core/kernels/gather_functor.h | 1 + tensorflow/core/kernels/list_kernels.cc | 12 +- tensorflow/core/kernels/list_kernels.cu.cc | 3 +- 
tensorflow/core/kernels/shape_op_test.cc | 10 +- tensorflow/core/platform/abi.cc | 4 +- tensorflow/core/platform/abi.h | 3 +- 36 files changed, 344 insertions(+), 331 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 173bbea596..79811ceae5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index c046bd66cd..c195c9e01c 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a2c5a42c11..f68f8a3e90 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/base64.h" diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 90e6f698ef..a00e136dd6 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 62996d1fd8..225025e995 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -25,15 +25,17 @@ tf_custom_op_library( name = "python/ops/_nccl_ops.so", srcs = [ "ops/nccl_ops.cc", - ], + ] + if_cuda(["kernels/nccl_rewrite.cc"]), gpu_srcs = if_not_windows_cuda([ "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", ]), - deps = if_cuda([ + deps = [] + if_cuda([ "@local_config_nccl//:nccl", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:protos_all_proto_text", ]), ) @@ -57,32 +59,30 @@ tf_cuda_cc_test( "notap", ], deps = - [ + if_cuda([ + "@local_config_nccl//:nccl", "//tensorflow/core:cuda", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "@local_config_nccl//:nccl", - ], + ]), ) tf_kernel_library( name = "nccl_kernels", - srcs = [ + srcs = if_cuda([ "kernels/nccl_manager.cc", 
"kernels/nccl_manager.h", "kernels/nccl_ops.cc", - "kernels/nccl_rewrite.cc", - ], - deps = [ + ]), + deps = if_cuda([ + "@local_config_nccl//:nccl", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", - "//tensorflow/core:proto_text", "//tensorflow/core:stream_executor", - "@local_config_nccl//:nccl", - ], + ]), alwayslink = 1, ) diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc index 4676e937e5..06ff86e6d8 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/node_builder.h" namespace tensorflow { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 79ad3b8e54..957aa254e5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -720,6 +720,7 @@ cc_library( name = "abi", srcs = ["platform/abi.cc"], hdrs = ["platform/abi.h"], + deps = [":platform_base"], ) cc_library( diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index f8cb854b52..cf3d1f0b79 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -358,7 +358,7 @@ static Status WrappedTensorDeviceCopy( #define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy) + Tensor, DIRECTION, WrappedTensorDeviceCopy) REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 
1e3fed0d6f..43ca3f1e3e 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/rendezvous_util.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/reffed_status_callback.h" diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h index 04d5af9087..22650b0d83 100644 --- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h +++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 888ed0c57b..2a7ee16a16 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } +void Allocator::RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); +} + +void Allocator::RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); +} + // If true, cpu allocator collects more stats. static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 774b1fe137..ded120b704 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -23,12 +23,13 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { +class Variant; + // Attributes for a single allocation call. Different calls to the same // allocator could potentially have different allocation attributes. 
struct AllocationAttributes { @@ -228,13 +229,9 @@ class Allocator { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } - virtual void RunVariantCtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); - } + virtual void RunVariantCtor(Variant* p, size_t n); - virtual void RunVariantDtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); - } + virtual void RunVariantDtor(Variant* p, size_t n); // TODO(jeff): Maybe provide some interface to give info about // current allocation state (total number of bytes available for diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index 24f282ce84..e907c52ba9 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" namespace tensorflow { diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 1a3994736c..4ffd732f8e 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 1b19ab5da3..696fd277cd 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -37,11 +37,12 @@ namespace tensorflow { class AllocationDescription; class Allocator; class OpKernelContext; +class Tensor; class TensorBuffer; class TensorCApi; class TensorDescription; class TensorProto; -class VariantTensorData; + namespace batch_util { Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 84a373c196..9a78cdc91e 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/math/math_util.h" diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 4bda8f9eb8..a7cf600bab 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 15b1add2c1..2e96b05787 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -39,6 +38,8 @@ limitations under the License. namespace tensorflow { +class Variant; + // MemoryType is used to describe whether input or output Tensors of // an OpKernel should reside in "Host memory" (e.g., CPU memory) or // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 5a507804b0..d43e3c72ec 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,11 +23,11 @@ limitations under the License. 
namespace tensorflow { -bool Variant::TryDecode(Variant* out) const { - const VariantTensorDataProto* p = get(); - if (p == nullptr) return false; - VariantTensorData data(*p); - return out->Decode(data); +bool Variant::Decode(VariantTensorData data) { + if (!is_empty()) { + return value_->Decode(std::move(data)); + } + return true; } template <> @@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) { template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { - data->FromProto(value); + data->FromConstProto(value); } template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value) { - data.ToProto(value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { + data->ToProto(value); return true; } @@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) { } template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { - return value->ParseFromString(buf); +bool DecodeVariant(string* buf, VariantTensorDataProto* value) { + return value->ParseFromString(*buf); } void EncodeVariantList(const Variant* variant_array, int64 n, @@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr d, if (variant_array[i].is_empty()) { variant_array[i] = VariantTensorDataProto(); } + // TODO(ebrevdo): Replace with StringPiece? Any way to make this a + // zero-copy operation that keeps a reference to the data in d? 
string str(d->Data(sizes[i]), sizes[i]); - if (!variant_array[i].Decode(str)) return false; + if (!variant_array[i].Decode(std::move(str))) return false; if (!DecodeUnaryVariant(&variant_array[i])) { LOG(ERROR) << "Could not decode variant with type_name: \"" << variant_array[i].TypeName() diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index 52732801a0..10eabbc85f 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -23,7 +23,6 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/status.h" @@ -38,17 +37,19 @@ string TypeNameVariant(const T& value); template string DebugStringVariant(const T& value); +// Allows for specializations of Variant Decoding. `data` may be modified in +// the process of decoding to `value`. template -void EncodeVariant(const T& value, VariantTensorData* data); +bool DecodeVariant(VariantTensorData* data, T* value); template -bool DecodeVariant(const VariantTensorData& data, T* value); +bool DecodeVariant(string* buf, T* value); template -void EncodeVariant(const T& value, string* buf); +void EncodeVariant(const T& value, VariantTensorData* data); template -bool DecodeVariant(const string& buf, T* value); +void EncodeVariant(const T& value, string* buf); // This is an implementation of a type-erased container that can store an // object of any type. The implementation is very similar to std::any, but has @@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value); // // string TypeName() const; // void Encode(VariantTensorData* data) const; -// void Decode(const VariantTensorData& data); +// void Decode(VariantTensorData data); // // Simple POD types can elide the Encode/Decode functions, they are provided by // helper methods. 
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value); // x.Encode(&serialized_f); // // Variant y = Foo(); // default constructed Foo. -// y.Decode(&serialized_f); +// y.Decode(std::move(serialized_f)); // EXPECT_EQ(*x.get(), *y.get()); // // @@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value); // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. // EXPECT_EQ(MakeTypeIndex(), // y_type_unknown.TypeId()); -// // Decode and get y_type_unknown; compare to value in x. -// Foo f_decoded; -// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); -// EXPECT_EQ(f_decoded, f); // class Variant { public: @@ -241,12 +238,7 @@ class Variant { } // Deserialize `data` and update the stored object. - bool Decode(const VariantTensorData& data) { - if (!is_empty()) { - return value_->Decode(data); - } - return true; - } + bool Decode(VariantTensorData data); // Helper methods to directly serialize/deserialize from strings. void Encode(string* buf) const { @@ -254,31 +246,13 @@ class Variant { value_->Encode(buf); } } - bool Decode(const string& buf) { + bool Decode(string buf) { if (!is_empty()) { - return value_->Decode(buf); + return value_->Decode(std::move(buf)); } return true; } - template - bool MaybeDecodeAndCopy(T* out) const { - const T* ret = get(); - if (ret != nullptr) { - *out = std::move(*ret); - return true; - }; - Variant decoded = T(); - if (!TryDecode(&decoded)) return false; - T* decoded_ret = decoded.get(); - CHECK_NOTNULL(decoded_ret); - *out = std::move(*decoded_ret); - return true; - } - - private: - bool TryDecode(Variant* out) const; - private: struct in_place_t {}; static constexpr in_place_t in_place{}; @@ -292,9 +266,9 @@ class Variant { virtual string TypeName() const = 0; virtual string DebugString() const = 0; virtual void Encode(VariantTensorData* data) const = 0; - virtual bool Decode(const VariantTensorData& data) = 0; + virtual bool Decode(VariantTensorData data) = 0; virtual void Encode(string* buf) 
const = 0; - virtual bool Decode(const string& data) = 0; + virtual bool Decode(string data) = 0; }; template @@ -325,15 +299,13 @@ class Variant { EncodeVariant(value, data); } - bool Decode(const VariantTensorData& data) override { - return DecodeVariant(data, &value); + bool Decode(VariantTensorData data) override { + return DecodeVariant(&data, &value); } void Encode(string* buf) const override { EncodeVariant(value, buf); } - bool Decode(const string& buf) override { - return DecodeVariant(buf, &value); - } + bool Decode(string buf) override { return DecodeVariant(&buf, &value); } T value; }; diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f155aa4892..5e08e5a7a6 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" @@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value, // Specialization for POD type template -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { @@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for tensorflow::Tensor template -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { @@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for protobuf template -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { @@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for other 
types template -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { - return value->Decode(data); + return value->Decode(std::move(data)); } template @@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) { } template -bool DecodeVariant(const VariantTensorData& data, T* value) { - return DecodeVariantImpl(data, TypeResolver(), value); +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), TypeResolver(), value); } template @@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) { } template -bool DecodeVariant(const string& buf, T* value) { +bool DecodeVariant(string* buf, T* value) { VariantTensorData data; - if (!data.ParseFromString(buf)) return false; - if (!DecodeVariantImpl(data, TypeResolver(), value)) return false; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver(), value)) { + return false; + } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); + template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); + template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); + template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value); +bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. 
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 60fa7bd559..daa744e877 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyCPUToGPU); + StoredTensorValue::CopyCPUToGPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST, - "StoredTensorValue", StoredTensorValue::CopyGPUToCPU); + StoredTensorValue::CopyGPUToCPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyGPUToGPU); + StoredTensorValue::CopyGPUToGPU); REGISTER_OP("CreateTestVariant") .Input("input: T") diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index ee07db1aee..ef5b240aea 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - StringPiece type_name) { - auto found = shape_fns.find(type_name); + const TypeIndex& type_index) { + auto found = shape_fns.find(type_index); if (found == shape_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, +void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index, const VariantShapeFn& shape_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; - VariantShapeFn* existing = GetShapeFn(type_name); + 
VariantShapeFn* existing = GetShapeFn(type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantShapeFn for type_name: " << type_name - << " already registered"; - shape_fns.insert(std::pair( - GetPersistentStringPiece(type_name), shape_fn)); + << "Unary VariantShapeFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered"; + shape_fns.insert(std::pair(type_index, shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { CHECK_EQ(variant_tensor.dims(), 0); const Variant& v = variant_tensor.scalar()(); UnaryVariantOpRegistry::VariantShapeFn* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId()); if (shape_fn == nullptr) { return errors::Internal( - "No unary variant shape function found for Variant type_name: ", - v.TypeName()); + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(v.TypeId().name())); } return (*shape_fn)(v, shape); } @@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) { } // namespace #define REGISTER_VARIANT_SHAPE_TYPE(T) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape); + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape); // No encode/shape registered for std::complex<> and Eigen::half // objects yet. 
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double); UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* UnaryVariantOpRegistry::GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name) { - auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); if (found == device_copy_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterDeviceCopyFn( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; - AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); CHECK_EQ(existing, nullptr) << "UnaryVariantDeviceCopy for direction: " << direction - << " and type_name: " << type_name << " already registered"; + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; device_copy_fns.insert( - std::pair, - AsyncVariantDeviceCopyFn>( - std::make_pair(direction, GetPersistentStringPiece(type_name)), - device_copy_fn)); + std::pair, + AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index), + device_copy_fn)); } Status VariantDeviceCopy( @@ -170,35 +167,34 @@ Status VariantDeviceCopy( const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, - from.TypeName()); + from.TypeId()); if (device_copy_fn == nullptr) { return errors::Internal( "No unary variant device copy function found for direction: ", - direction, " and Variant type_name: ", from.TypeName()); + 
direction, " and Variant type_index: ", + port::MaybeAbiDemangle(from.TypeId().name())); } return (*device_copy_fn)(from, to, copy_fn); } // Special casing UnaryOpFn per op and per device. UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( - VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find({op, device, type_name}); + VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterUnaryOpFn( - VariantUnaryOp op, const string& device, const string& type_name, + VariantUnaryOp op, const string& device, const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; - VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantUnaryOpFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; unary_op_fns.insert(std::pair, VariantUnaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - unary_op_fn)); + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); } namespace { @@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ - DEVICE_CPU, T, TF_STR(T), \ + DEVICE_CPU, T, \ ZerosLikeVariantPrimitiveType); // No zeros_like registered for std::complex<> or Eigen::half objects yet. @@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); // Special casing BinaryOpFn per op and per device. 
UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name) { - auto found = binary_op_fns.find({op, device, type_name}); + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterBinaryOpFn( - VariantBinaryOp op, const string& device, const string& type_name, + VariantBinaryOp op, const string& device, const TypeIndex& type_index, const VariantBinaryOpFn& add_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; - VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantBinaryOpFn for type_name: " << type_name + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; binary_op_fns.insert(std::pair, VariantBinaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - add_fn)); + {op, GetPersistentStringPiece(device), type_index}, add_fn)); } namespace { @@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, #define REGISTER_VARIANT_ADD_TYPE(T) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ - T, TF_STR(T), \ - AddVariantPrimitiveType); + T, AddVariantPrimitiveType); // No add registered for std::complex<> or Eigen::half objects yet. REGISTER_VARIANT_ADD_TYPE(int); diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index e6a2665a56..7eb37e859f 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -22,10 +22,14 @@ limitations under the License. 
#define EIGEN_USE_THREADS +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/abi.h" namespace tensorflow { @@ -90,10 +94,11 @@ class UnaryVariantOpRegistry { AsyncVariantDeviceCopyFn; // Add a shape lookup function to the registry. - void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); + void RegisterShapeFn(const TypeIndex& type_index, + const VariantShapeFn& shape_fn); - // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(StringPiece type_name); + // Returns nullptr if no shape function was found for the given TypeIndex. + VariantShapeFn* GetShapeFn(const TypeIndex& type_index); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, @@ -104,33 +109,33 @@ class UnaryVariantOpRegistry { // Add a copy-to-GPU function to the registry. void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, - const string& type_name, + const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn); // Returns nullptr if no copy function was found for the given // TypeName and direction. AsyncVariantDeviceCopyFn* GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index); // Add a unary op function to the registry. void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn); // Returns nullptr if no unary op function was found for the given // op, device, and TypeName. 
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Add a binary op function to the registry. void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantBinaryOpFn& add_fn); // Returns nullptr if no binary op function was found for the given // op, device and TypeName. VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); @@ -145,24 +150,26 @@ class UnaryVariantOpRegistry { static std::unordered_set* PersistentStringStorage(); private: - std::unordered_map shape_fns; - std::unordered_map - decode_fns; + struct TypeIndexHash { + std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } + }; + + gtl::FlatMap shape_fns; + gtl::FlatMap decode_fns; // Map std::pair to function. struct PairHash { template - std::size_t operator()(const std::pair& x) const { + std::size_t operator()(const std::pair& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast(std::get<0>(x)); - ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, std::get<1>(x).hash_code()); return ret; } - StringPieceHasher sp_hasher_; }; - std::unordered_map, - AsyncVariantDeviceCopyFn, PairHash> + gtl::FlatMap, + AsyncVariantDeviceCopyFn, PairHash> device_copy_fns; // Map std::tuple to function. 
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry { // and references therein template struct FuncTuple { - FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) - : op_type_(op), device_(dev), typename_(tname){}; + FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index) + : op_type_(op), device_(dev), type_index_(type_index) {} Op op_type_; - StringPiece device_, typename_; + StringPiece device_; + TypeIndex type_index_; }; // friend declaration for operator== // needed for clang @@ -184,11 +192,11 @@ class UnaryVariantOpRegistry { struct TupleHash { template std::size_t operator()( - const std::tuple& x) const { + const std::tuple& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast(std::get<0>(x)); ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); - ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); + ret = Hash64Combine(ret, std::get<2>(x).hash_code()); return ret; } @@ -197,14 +205,14 @@ class UnaryVariantOpRegistry { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast(x.op_type_); ret = Hash64Combine(ret, sp_hasher_(x.device_)); - ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + ret = Hash64Combine(ret, x.type_index_.hash_code()); return ret; } StringPieceHasher sp_hasher_; }; - std::unordered_map, VariantUnaryOpFn, TupleHash> + gtl::FlatMap, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map, VariantBinaryOpFn, TupleHash> + gtl::FlatMap, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage @@ -225,7 +233,7 @@ template inline bool operator==(const UnaryVariantOpRegistry::FuncTuple& lhs, const UnaryVariantOpRegistry::FuncTuple& rhs) { return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && - (lhs.typename_ == rhs.typename_); + (lhs.type_index_ == rhs.type_index_); } // Gets a TensorShape from a Tensor containing a scalar Variant. 
// Returns an Internal error if the Variant does not have a registered shape @@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, Variant* v_out) { const string& device = DeviceName::value; UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = - UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); if (unary_op_fn == nullptr) { return errors::Internal( "No unary variant unary_op function found for unary variant op enum: ", @@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, template Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, const Variant& a, const Variant& b, Variant* out) { - if (a.TypeName() != b.TypeName()) { + if (a.TypeId() != b.TypeId()) { return errors::Internal( "BianryOpVariants: Variants a and b have different " - "type names: '", + "type ids. Type names: '", a.TypeName(), "' vs. 
'", b.TypeName(), "'"); } const string& device = DeviceName::value; UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = - UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); if (binary_op_fn == nullptr) { return errors::Internal( "No unary variant binary_op function found for binary variant op " @@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration { public: typedef std::function LocalVariantShapeFn; - UnaryVariantShapeRegistration(const string& type_name, + UnaryVariantShapeRegistration(const TypeIndex& type_index, const LocalVariantShapeFn& shape_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterShapeFn( - type_name, - [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status { + type_index, + [type_index_name, shape_fn](const Variant& v, + TensorShape* s) -> Status { const T* t = v.get(); if (t == nullptr) { return errors::Internal( - "VariantShapeFn: Could not access object, type_name: ", - type_name); + "VariantShapeFn: Could not access object, type_index: ", + type_index_name); } return shape_fn(*t, s); }); @@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration { return false; } Variant decoded = T(); - VariantTensorData data(*t); - if (!decoded.Decode(data)) { + VariantTensorData data(std::move(*t)); + if (!decoded.Decode(std::move(data))) { return false; } - *v = std::move(decoded); + std::swap(decoded, *v); return true; }); } @@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration { UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)> LocalVariantDeviceCopyFn; UnaryVariantDeviceCopyRegistration( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const LocalVariantDeviceCopyFn& device_copy_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); 
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( - direction, type_name, - [type_name, device_copy_fn]( + direction, type_index, + [type_index_name, device_copy_fn]( const Variant& from, Variant* to, UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn device_copy_tensor_fn) -> Status { @@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration { *to = T(); if (from.get() == nullptr) { return errors::Internal( - "VariantCopyToGPUFn: Could not access object, type_name: ", - type_name); + "VariantCopyToGPUFn: Could not access object, type_index: ", + type_index_name); } const T& t = *from.get(); T* t_out = to->get(); @@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration { public: UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantUnaryOpFn& unary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( - op, device, type_name, - [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, - Variant* v_out) -> Status { + op, device, type_index, + [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> Status { DCHECK_NE(v_out, nullptr); *v_out = T(); if (v.get() == nullptr) { return errors::Internal( - "VariantUnaryOpFn: Could not access object, type_name: ", - type_name); + "VariantUnaryOpFn: Could not access object, type_index: ", + type_index_name); } const T& t = *v.get(); T* t_out = v_out->get(); @@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration { public: UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantBinaryOpFn& binary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( - op, device, type_name, - [type_name, binary_op_fn](OpKernelContext* 
ctx, const Variant& a, - const Variant& b, Variant* out) -> Status { + op, device, type_index, + [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, + Variant* out) -> Status { DCHECK_NE(out, nullptr); *out = T(); if (a.get() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'a', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'a', type_index: ", + type_index_name); } if (b.get() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'b', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'b', type_index: ", + type_index_name); } const T& t_a = *a.get(); const T& t_b = *b.get(); @@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration { // Register a unary shape variant function with the signature: // Status ShapeFn(const T& t, TensorShape* s); -// to Variants having TypeName type_name. -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \ - shape_function) +// to Variants having TypeIndex type_index. 
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, MakeTypeIndex(), shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \ - shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function) +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \ + shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \ +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \ shape_function) \ static variant_op_registry_fn_registration::UnaryVariantShapeRegistration \ - register_unary_variant_op_shape_registration_fn_##ctr(type_name, \ + register_unary_variant_op_shape_registration_fn_##ctr(type_index, \ shape_function) // Register a unary decode variant function for the given type. @@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration { // ****** NOTE ****** // FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. 
// ****** NOTE ****** -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - T, direction, type_name, device_copy_fn) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, T, direction, type_name, device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \ + device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, MakeTypeIndex(), device_copy_fn) #define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - ctr, T, direction, type_name, device_copy_fn) \ + ctr, T, direction, type_index, device_copy_fn) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) + ctr, T, direction, type_index, device_copy_fn) -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantDeviceCopyRegistration \ - register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \ - device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration \ + register_unary_variant_op_device_copy_fn_##ctr( \ + direction, type_index, device_copy_fn) // Register a unary unary_op variant function with the signature: // Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for UnaryVariantOp enum op. 
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ - unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex(), unary_op_function) -#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ - unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \ + type_index, unary_op_function) #define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, unary_op_function) \ + ctr, op, device, T, type_index, unary_op_function) \ static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ unary_op_function) // Register a binary_op variant function with the signature: // Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for BinaryVariantOp enum OP. 
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ - binary_op_function) \ - REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, binary_op_function) +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex(), binary_op_function) #define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, binary_op_function) \ + ctr, op, device, T, type_index, binary_op_function) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) + ctr, op, device, T, type_index, binary_op_function) -#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantBinaryOpRegistration \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ binary_op_function) } // end namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 7055e62c0e..b2443e8676 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -89,41 +89,37 @@ struct VariantValue { int value; }; -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", - VariantValue::ShapeFn); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); 
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + VariantValue::CPUToGPUCopyFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, VariantValue, - "TEST VariantValue", VariantValue::CPUZerosLikeFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, VariantValue, - "TEST VariantValue", VariantValue::GPUZerosLikeFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - VariantValue, "TEST VariantValue", - VariantValue::CPUAddFn); + VariantValue, VariantValue::CPUAddFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - VariantValue, "TEST VariantValue", - VariantValue::GPUAddFn); + VariantValue, VariantValue::GPUAddFn); } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + class Blah {}; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex()), nullptr); - auto* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn( + MakeTypeIndex()); EXPECT_NE(shape_fn, nullptr); TensorShape shape; @@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - string kTypeName = "fjfjfj"; - registry.RegisterShapeFn(kTypeName, f); - EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), - "fjfjfj already registered"); + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex(); + registry.RegisterShapeFn(kTypeIndex, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpDecodeRegistryTest, TestBasic) { @@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) { 
TEST(VariantOpCopyToGPURegistryTest, TestBasic) { // No registered copy fn for GPU<->GPU. - EXPECT_EQ( - UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), - nullptr); + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + MakeTypeIndex()), + nullptr); auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, + MakeTypeIndex()); EXPECT_NE(copy_to_gpu_fn, nullptr); VariantValue vv{true /* early_exit */}; @@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex(); registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, - kTypeName, f); + kTypeIndex, f); EXPECT_DEATH(registry.RegisterDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), - "fjfjfj already registered"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex()), nullptr); 
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantUnaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex(); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_CPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_CPU, kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_GPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_GPU, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpAddRegistryTest, TestBasicCPU) { - return; + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpAddRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) { TEST(VariantOpAddRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; 
UnaryVariantOpRegistry::VariantBinaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex(); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); } } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc index 99712dc114..3e67e4a864 100644 --- a/tensorflow/core/framework/variant_tensor_data.cc +++ b/tensorflow/core/framework/variant_tensor_data.cc @@ -22,8 +22,8 @@ namespace tensorflow { VariantTensorData::VariantTensorData() {} -VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) { - FromProto(proto); +VariantTensorData::VariantTensorData(VariantTensorDataProto proto) { + FromProto(std::move(proto)); } VariantTensorData::~VariantTensorData() {} @@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { } } -bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { +bool VariantTensorData::FromProto(VariantTensorDataProto proto) { + // TODO(ebrevdo): Do this lazily. 
+ set_type_name(proto.type_name()); + set_metadata(proto.metadata()); + for (const auto& tensor : proto.tensors()) { + Tensor tmp; + if (!tmp.FromProto(tensor)) return false; + tensors_.push_back(tmp); + } + return true; +} + +bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) { set_type_name(proto.type_name()); set_metadata(proto.metadata()); for (const auto& tensor : proto.tensors()) { @@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) { return proto.SerializeToString(buf); } -bool VariantTensorData::ParseFromString(const string& s) { +bool VariantTensorData::ParseFromString(string s) { VariantTensorDataProto proto; const bool status = proto.ParseFromString(s); - if (status) FromProto(proto); + if (status) FromProto(std::move(proto)); return status; } diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index 7500e77d43..8a240ee1e3 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -19,13 +19,13 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { class VariantTensorDataProto; -class Tensor; // The serialization format for Variant objects. Objects with references to // other Tensors can simply store those tensors in the `tensors` field, and @@ -38,7 +38,7 @@ class Tensor; class VariantTensorData { public: VariantTensorData(); - VariantTensorData(const VariantTensorDataProto& proto); + VariantTensorData(VariantTensorDataProto proto); ~VariantTensorData(); // Name of the type of objects being serialized. 
@@ -68,12 +68,14 @@ class VariantTensorData { // Conversion to and from VariantTensorDataProto void ToProto(VariantTensorDataProto* proto) const; - bool FromProto(const VariantTensorDataProto& proto); + // This allows optimizations via std::move. + bool FromProto(VariantTensorDataProto proto); + bool FromConstProto(const VariantTensorDataProto& proto); // Serialization via VariantTensorDataProto string SerializeAsString() const; bool SerializeToString(string* buf); - bool ParseFromString(const string& s); + bool ParseFromString(string s); string DebugString() const; diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc index eef5c47d15..08d09de7b8 100644 --- a/tensorflow/core/framework/variant_test.cc +++ b/tensorflow/core/framework/variant_test.cc @@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) { struct TensorList { void Encode(VariantTensorData* data) const { data->tensors_ = vec; } - bool Decode(const VariantTensorData& data) { - vec = data.tensors_; + bool Decode(VariantTensorData data) { + vec = std::move(data.tensors_); return true; } @@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) { x.Encode(&serialized); Variant y = TensorList(); - y.Decode(serialized); + y.Decode(std::move(serialized)); const TensorList& decoded_vec = *y.get(); for (int i = 0; i < 4; ++i) { @@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) { EXPECT_EQ(y_unknown.DebugString(), strings::StrCat( "Variant")); - - TensorList unknown_decoded_vec; - EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec)); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i].flat()(0), i); - } - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat()(0), 2 * i); - } } TEST(VariantTest, VariantArray) { diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index fe6d705eab..30c6585ba2 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ 
b/tensorflow/core/kernels/data/iterator_ops.cc @@ -403,12 +403,12 @@ class IteratorStateVariant { } string TypeName() const { return kIteratorVariantTypeName; } void Encode(VariantTensorData* data) const { *data = *data_; } - bool Decode(const VariantTensorData& data) { + bool Decode(VariantTensorData data) { if (data.type_name() != TypeName()) { return false; } std::unique_ptr tensor_data(new VariantTensorData); - *tensor_data = data; + std::swap(*tensor_data, data); std::unique_ptr reader( new VariantTensorDataReader(tensor_data.get())); status_ = reader->status(); diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index b372d31a93..6180df5af2 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -231,10 +231,9 @@ static Status OptionalDeviceCopy( return Status::OK(); } -#define REGISTER_OPTIONAL_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - OptionalVariant, DIRECTION, kOptionalVariantTypeName, \ - OptionalDeviceCopy) +#define REGISTER_OPTIONAL_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + OptionalVariant, DIRECTION, OptionalDeviceCopy) REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h index cd2873bdca..7710cf93d6 100644 --- a/tensorflow/core/kernels/gather_functor.h +++ b/tensorflow/core/kernels/gather_functor.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/prefetch.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index bca1cff41c..2088c13586 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -77,9 +77,9 @@ static Status TensorListDeviceCopy( return Status::OK(); } -#define REGISTER_LIST_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy) +#define REGISTER_LIST_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \ + TensorListDeviceCopy) REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); @@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) { return Status::OK(); } -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName, - TensorListShape); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape); bool TensorList::Decode(const VariantTensorData& data) { tensors = data.tensors(); @@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TensorList, - TensorList::kTypeName, TensorListZerosLike); } // namespace tensorflow diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc index c591226b76..a00bf700ca 100644 
--- a/tensorflow/core/kernels/list_kernels.cu.cc +++ b/tensorflow/core/kernels/list_kernels.cu.cc @@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, TensorList, - TensorList::kTypeName, TensorListZerosLike); } // namespace tensorflow diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc index 9cd590ae61..30cb1e0a7f 100644 --- a/tensorflow/core/kernels/shape_op_test.cc +++ b/tensorflow/core/kernels/shape_op_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE"); -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE", - GetShapeFromKnownVecSize); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize); static void ExpectHasError(const Status& s, StringPiece substr) { EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) @@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) { Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs); EXPECT_FALSE(s.ok()); ExpectHasError( - s, - "No unary variant shape function found for Variant type_name: " - "NO KNOWN SHAPE"); + s, strings::StrCat( + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(MakeTypeIndex().name()))); } { diff --git 
a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc index e597a490d6..d7a13a3528 100644 --- a/tensorflow/core/platform/abi.cc +++ b/tensorflow/core/platform/abi.cc @@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name, namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name) { +string MaybeAbiDemangle(const char* name) { #if defined(_MSC_VER) std::unique_ptr demangled{__unDName(nullptr, name, 0, std::malloc, std::free, static_cast(0))}; - return std::string(demangled.get() != nullptr ? demangled.get() : name); + return string(demangled.get() != nullptr ? demangled.get() : name); #else int status = 0; std::unique_ptr res{ diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h index 591e83b0c4..d1498a6a64 100644 --- a/tensorflow/core/platform/abi.h +++ b/tensorflow/core/platform/abi.h @@ -17,11 +17,12 @@ limitations under the License. #define TENSORFLOW_CORE_PLATFORM_ABI_H_ #include +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name); +string MaybeAbiDemangle(const char* name); } // namespace port } // namespace tensorflow -- cgit v1.2.3 From 232fcbb6fcf8c5ab3713261a0ef9a771b270753e Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 11 Sep 2018 10:49:24 -0700 Subject: Add basic logging to metagraph transform PiperOrigin-RevId: 212480467 --- .../contrib/meta_graph_transform/meta_graph_transform.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index c35e60a554..b1c852c2c6 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -31,6 +31,7 @@ from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as _graph_util from tensorflow.python.framework import importer as _importer from tensorflow.python.framework import ops as _ops +from tensorflow.python.platform import tf_logging as _logging from tensorflow.python.saved_model import constants as _saved_model_constants from tensorflow.python.training import saver as _saver_lib from tensorflow.python.util import compat as _compat @@ -476,6 +477,12 @@ def _add_pruned_collection(base_meta_graph_def, meta_graph_def, collection.bytes_list.value[:] = [ s for s in base_collection.bytes_list.value if not _is_removed_mentioned(s, removed_op_names)] + _logging.info( + 'In collection %s, nodes excluded are: %s', collection_name, + sorted([ + s for s in base_collection.bytes_list.value + if _is_removed_mentioned(s, removed_op_names) + ])) elif base_collection.HasField('node_list'): collection.node_list.value[:] = [ s for s in base_collection.node_list.value @@ -745,6 +752,9 @@ def meta_graph_transform( retained_op_names = [_compat.as_str(node.name) for node in meta_graph_def.graph_def.node] removed_op_names = set(base_op_names) - set(retained_op_names) + _logging.info('Node names in base graph: %s', sorted(base_op_names)) + _logging.info('Node names retained: %s', sorted(retained_op_names)) + _logging.info('Node names removed: %s', 
sorted(removed_op_names)) # Copy saver, excluding any pruned nodes if graph was not frozen. # TODO(b/63447631): Revisit this once the problem is addressed. Currently -- cgit v1.2.3 From 7e5ae7109f558cafaa87e3bcebabfc0e1f67aabc Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 11 Sep 2018 11:12:34 -0700 Subject: Handle control dependencies from switch nodes as nonreachable. In DeleteReachableNodes all the nodes reachable from nodes deleted from the graph during extraction was considered. But if a node had a control dependency on a switch, then that node doesn't conditionally execute based on the switch predicate and is not part of the conditional extracted, so it should be considered reachable for deletion. Additionally perform sweep of graph for dead nodes together with deleting the reachable nodes to keep all dead node deletion together. Also delete a dead function and ensure all graph dumps from functionalize_cond has that as prefix. PiperOrigin-RevId: 212485183 --- tensorflow/compiler/tf2xla/functionalize_cond.cc | 71 +++++++++++++++++------- tensorflow/compiler/tf2xla/functionalize_cond.h | 13 ++--- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 0911550f1f..3ad1d1d5b4 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -217,10 +217,6 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { added_node_ancestorid_mapping_[node->id()] = id; } -const StateMap::CondState& StateMap::LookupState(const Node* node) const { - return *LookupCondId(node); -} - void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } string StateMap::CondStateToString(const Node* node) const { @@ -791,7 +787,6 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); 
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -1056,7 +1051,6 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { " has no non-dead inputs."); } state_map_.MarkDead(node); - delete_nodes_.push_back(node->id()); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { const Edge* oe = *node->out_edges().begin(); @@ -1132,7 +1126,6 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { } } else if (BranchType(switch_branch) != b) { state_map_.MarkDead(dst_node); - delete_nodes_.push_back(dst_node->id()); continue; } graph_->AddEdge( @@ -1154,7 +1147,7 @@ Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) << " @ " << state_map_.AncestorStateToString(dst); - if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it"); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } return Status::OK(); } @@ -1184,23 +1177,62 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { return Status::OK(); } -void FunctionalizeCond::DeleteReachableNodes() { +void FunctionalizeCond::DeleteReachableAndDeadNodes( + const std::vector& switch_ids, const std::vector& merge_order) { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been // removed. + std::deque delete_nodes; std::vector deleted(graph_->num_node_ids(), false); // Don't try to delete source or sink nodes. deleted[graph_->kSourceId] = true; deleted[graph_->kSinkId] = true; - while (!delete_nodes_.empty()) { - int d_id = delete_nodes_.front(); - delete_nodes_.pop_front(); + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. 
+ for (int s_id : switch_ids) { + Node* s = graph_->FindNodeId(s_id); + if (s == nullptr) continue; + for (const Edge* e : s->out_edges()) { + // Control outputs of switch nodes (which are unconditionally executed if + // the switch is) are not removed as they need not be part of a + // conditional. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[s_id] = true; + graph_->RemoveNode(s); + } + + // All merge nodes should have been transformed at this point and we remove + // them from the graph here. + for (Node* m : merge_order) { + for (const Edge* e : m->out_edges()) { + // Similar to control outputs of switch nodes don't remove control + // outputs of merge nodes. + // TODO(jpienaar): Check cases where output edges still exist here vs + // being removed in AddOutputEdges. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[m->id()] = true; + graph_->RemoveNode(m); + } + + // Enqueue all the dead nodes. + for (Node* n : graph_->nodes()) { + if (state_map_.IsDead(state_map_.LookupCondId(n))) { + delete_nodes.push_back(n->id()); + } + } + + while (!delete_nodes.empty()) { + int d_id = delete_nodes.front(); + delete_nodes.pop_front(); if (deleted[d_id]) continue; Node* d = graph_->FindNodeId(d_id); // Switch and Merge nodes could have been deleted already. if (d == nullptr) continue; for (const Edge* e : d->out_edges()) { - delete_nodes_.push_back(e->dst()->id()); + delete_nodes.push_back(e->dst()->id()); } deleted[d_id] = true; graph_->RemoveNode(d); @@ -1274,7 +1306,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { } TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); - if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + if (VLOG_IS_ON(4)) DumpGraphWithCondState("id"); // Sort the merge nodes from innermost outwards. 
SortMergeNodes(&merge_order); @@ -1312,11 +1344,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } - // All remaining Switch nodes are not reachable from a Merge node and - // removed. This is to account for dead Switch nodes. - for (int s_id : switch_ids) delete_nodes_.push_back(s_id); - for (Node* m : merge_order) delete_nodes_.push_back(m->id()); - DeleteReachableNodes(); + DeleteReachableAndDeadNodes(switch_ids, merge_order); return Status::OK(); } @@ -1331,8 +1359,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " - << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name), - *graph_, library_); + << dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_cond_", name), *graph_, + library_); } Status FunctionalizeCond::Functionalize(Graph* graph, diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 28301150ea..1899808940 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -91,10 +91,6 @@ class StateMap { // Resets the AncestorId for a given node. void ResetAncestorId(const Node* node, AncestorId id); - // Returns the CondState for a Node. - // REQUIRES: node has a non-empty CondState. - const CondState& LookupState(const Node* node) const; - // Marks `node` as dead. void MarkDead(const Node* node); @@ -221,8 +217,10 @@ class FunctionalizeCond { // nesting depth. void SortMergeNodes(std::vector* merge_order); - // Deletes all nodes in/consumers of `delete_nodes_`. - void DeleteReachableNodes(); + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. 
+ void DeleteReachableAndDeadNodes(const std::vector& switch_ids, + const std::vector& merge_order); // Member used to unique the CondState to a unique CondId (AncestorState to a // unique AncestorId) and keep track of CondState/CondId @@ -232,9 +230,6 @@ class FunctionalizeCond { // Mapping from merge nodes to predicate. std::unordered_map merge_to_predicate_; - // Nodes to be deleted. - std::deque delete_nodes_; - FunctionLibraryDefinition* library_; Graph* graph_; -- cgit v1.2.3 From ded099749d4f987b404b9d5fd7169baf1671582b Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 11 Sep 2018 11:16:06 -0700 Subject: Add missing spaces to error message. PiperOrigin-RevId: 212485820 --- tensorflow/core/graph/graph_constructor.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index ee10194142..7399613f6a 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() { } if (processed < node_defs_.size()) { - LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed) + LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed) << " NODES IN A CYCLE"; for (int64 i = 0; i < node_defs_.size(); i++) { if (pending_count_[i] != 0) { LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) - << "WITH PENDING COUNT = " << pending_count_[i]; + << " WITH PENDING COUNT = " << pending_count_[i]; } } return errors::InvalidArgument(node_defs_.size() - processed, -- cgit v1.2.3 From a346aa260d32eb83621bb7ed501a2b07ba186480 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 11 Sep 2018 11:22:27 -0700 Subject: Automated rollback of commit 624ff13fdf4e54e255d23971ef2beec3c48c3bb2. Revert #21826. 
PiperOrigin-RevId: 212487142 --- tensorflow/python/ops/ctc_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 32d455bdad..908e793902 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, If `merge_repeated` is `True`, merge repeated classes in the output beams. This means that if consecutive entries in a beam are the same, - only the first of these is emitted. That is, when the sequence is - `A B B * B * B` (where '*' is the blank label), the return value is: + only the first of these is emitted. That is, when the top path + is `A B B B B`, the return value is: * `A B` if `merge_repeated = True`. - * `A B B B` if `merge_repeated = False`. + * `A B B B B` if `merge_repeated = False`. Args: inputs: 3-D `float` `Tensor`, size -- cgit v1.2.3 From 6cb9189c567397b0779f1c52604e2ea6255a9183 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Tue, 11 Sep 2018 11:25:23 -0700 Subject: Removes option of pass-through runner on eager execution. It is possible it will deadlock by running code in the GPU event manager thread. 
PiperOrigin-RevId: 212487862 --- tensorflow/core/common_runtime/eager/context.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 37fc031985..263467a5b6 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -66,13 +66,9 @@ EagerContext::EagerContext(const SessionOptions& opts, local_unowned_device_manager_ = device_mgr; } InitDeviceMapAndAsync(); - if (opts.config.inter_op_parallelism_threads() > 0) { - runner_ = [this](std::function closure) { - this->thread_pool_->Schedule(closure); - }; - } else { - runner_ = [](std::function closure) { closure(); }; - } + runner_ = [this](std::function closure) { + this->thread_pool_->Schedule(closure); + }; } void EagerContext::InitDeviceMapAndAsync() { -- cgit v1.2.3 From 9b8c30fb0abf42f34c17050ff455d36166fa0e24 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Sep 2018 11:26:28 -0700 Subject: Contraction mapper for cuboid convolutions. Directly pack rhs memory for the gebp kernes with a gemm_pack_rhs specialization. It's similar to optimized memory packing in eigen_spatial_convolutions. Works for: 1. CuboidConvolution 2. CuboidConvolutionBackwardInput ~2x-4x speedup when compiled with AVX (depends on tensor&patch dimensions). PiperOrigin-RevId: 212488060 --- tensorflow/core/kernels/eigen_cuboid_convolution.h | 1356 ++++++++++++++++++++ 1 file changed, 1356 insertions(+) diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h index 62e9f9123d..c41fbc42d3 100644 --- a/tensorflow/core/kernels/eigen_cuboid_convolution.h +++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -21,6 +21,1362 @@ limitations under the License. namespace Eigen { +namespace internal { + +// WARNING: Most of the code here implicitly assumes that the matrix is in +// ColMajor layout. 
This is guaranteed by the tensor contraction (see +// TensorContraction.h). +// +// Inside Eigen a tensor contraction is represented by a matrix multiplication. +// We don't want to actually extract volume patches and reshape the result into +// a matrix (this involves allocating huge extra memory), so the patch +// extraction and reshape operations are implicit. +// +// TensorContractionInputMapper takes a matrix index and returns the coefficient +// (or the packet) of the "virtual tensor", that would be at that index if we +// were to actually reshape the result of patch extraction. +// +// TensorContractionSubMapper provides a similar view into the "virtual matrix" +// at the given vertical and horizontal offsets. +// +// "Virtual matrix" dimensions: +// *0: kernelChannels * kernelDepth * kernelRows * kernelCols; +// 1: out_depth * out_height * out_width; * OTHERS (e.g batches, etc...) +// +// *) extracted patches are continuous in memory (innermost dimension assuming +// col major layout) +// +// With this dimensions: +// row - offset within a single patch (in code: patchId) +// col - index of the extracted patch (in code: patchIndex) +// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) +// +template +class TensorContractionInputMapper< + Scalar_, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef Scalar_ Scalar; + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper VectorMapper; + typedef SubMapper LinearMapper; + typedef typename packet_traits::type Packet; + + 
EIGEN_DEVICE_FUNC + TensorContractionInputMapper( + const TensorEvaluator< + const TensorReshapingOp< + NewDimension, + const TensorVolumePatchOp >, + Device>& tensor, + const nocontract_t&, const nocontract_t&, const contract_t&, + const contract_t&) + : m_impl(tensor.impl().impl()) { + if (internal::traits::Layout == ColMajor) { + m_patch_depth = tensor.impl().dimensions()[0]; + m_patch_planes = tensor.impl().dimensions()[1]; + m_patch_rows = tensor.impl().dimensions()[2]; + m_patch_cols = tensor.impl().dimensions()[3]; + m_num_patches = tensor.impl().dimensions()[4]; + } else { + const int NumDims = tensor.impl().dimensions().size(); + m_patch_depth = tensor.impl().dimensions()[NumDims - 1]; + m_patch_planes = tensor.impl().dimensions()[NumDims - 2]; + m_patch_rows = tensor.impl().dimensions()[NumDims - 3]; + m_patch_cols = tensor.impl().dimensions()[NumDims - 4]; + m_num_patches = tensor.impl().dimensions()[NumDims - 5]; + } + + // Strides for the output tensor. + // IMPORTANT: These strides are used to locate an element in a patch at a + // depth zero (channel), which is not quite the same as "traditional" + // stride. 
+ m_rowStride = m_patch_planes; + m_colStride = m_patch_rows * m_rowStride; + m_patchStride = m_colStride * m_patch_cols * m_patch_depth; + m_otherStride = m_patchStride * m_num_patches; + + m_outputPlanes = tensor.impl().outputPlanes(); + m_outputRows = tensor.impl().outputRows(); + m_outputCols = tensor.impl().outputCols(); + + m_outputPlanesRows = m_outputPlanes * m_outputRows; + + m_plane_strides = tensor.impl().userPlaneStride(); + m_row_strides = tensor.impl().userRowStride(); + m_col_strides = tensor.impl().userColStride(); + + m_in_plane_strides = tensor.impl().userInPlaneStride(); + m_in_row_strides = tensor.impl().userInRowStride(); + m_in_col_strides = tensor.impl().userInColStride(); + + m_patch_plane_inflate_strides = tensor.impl().planeInflateStride(); + m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); + m_patch_col_inflate_strides = tensor.impl().colInflateStride(); + + if (internal::traits::Layout == ColMajor) { + m_inputDepth = tensor.impl().impl().dimensions()[0]; + m_inputPlanes = tensor.impl().impl().dimensions()[1]; + m_inputRows = tensor.impl().impl().dimensions()[2]; + m_inputCols = tensor.impl().impl().dimensions()[3]; + } else { + const int NumDims = tensor.impl().impl().dimensions().size(); + m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1]; + m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2]; + m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3]; + m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4]; + } + + // Strides for navigating through the input tensor. 
+ m_planeInputStride = m_inputDepth; + m_rowInputStride = m_inputDepth * m_inputPlanes; + m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes; + m_patchInputStride = + m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes; + + m_planePaddingTop = tensor.impl().planePaddingTop(); + m_rowPaddingTop = tensor.impl().rowPaddingTop(); + m_colPaddingLeft = tensor.impl().colPaddingLeft(); + + m_fastNumPatches = internal::TensorIntDivisor(m_num_patches); + + m_fastInputPlaneStride = + internal::TensorIntDivisor(m_patch_plane_inflate_strides); + m_fastInputRowStride = + internal::TensorIntDivisor(m_patch_row_inflate_strides); + m_fastInputColStride = + internal::TensorIntDivisor(m_patch_col_inflate_strides); + + m_fastRowStride = internal::TensorIntDivisor(m_rowStride); + m_fastColStride = internal::TensorIntDivisor(m_colStride); + + m_fastDimZero = internal::TensorIntDivisor(m_patch_depth); + m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); + m_fastOutputPlanes = internal::TensorIntDivisor(m_outputPlanes); + m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); + m_fastOutputCols = internal::TensorIntDivisor(m_outputCols); + + m_fastOutputPlanesRows = + internal::TensorIntDivisor(m_outputPlanesRows); + } + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) + : m_impl(base_mapper.m_impl) { + m_patch_depth = base_mapper.m_patch_depth; + m_patch_planes = base_mapper.m_patch_planes; + m_patch_rows = base_mapper.m_patch_rows; + m_patch_cols = base_mapper.m_patch_cols; + m_num_patches = base_mapper.m_num_patches; + + m_rowStride = base_mapper.m_rowStride; + m_colStride = base_mapper.m_colStride; + m_patchStride = base_mapper.m_patchStride; + m_otherStride = base_mapper.m_otherStride; + + m_planeInputStride = base_mapper.m_planeInputStride; + m_rowInputStride = base_mapper.m_rowInputStride; + m_colInputStride = base_mapper.m_colInputStride; + m_patchInputStride = base_mapper.m_patchInputStride; + 
m_otherInputStride = base_mapper.m_otherInputStride; + + m_inputDepth = base_mapper.m_inputDepth; + m_inputPlanes = base_mapper.m_inputPlanes; + m_inputRows = base_mapper.m_inputRows; + m_inputCols = base_mapper.m_inputCols; + + m_outputPlanes = base_mapper.m_outputPlanes; + m_outputRows = base_mapper.m_outputRows; + m_outputCols = base_mapper.m_outputCols; + + m_plane_strides = base_mapper.m_plane_strides; + m_row_strides = base_mapper.m_row_strides; + m_col_strides = base_mapper.m_col_strides; + + m_in_plane_strides = base_mapper.m_in_plane_strides; + m_in_row_strides = base_mapper.m_in_row_strides; + m_in_col_strides = base_mapper.m_in_col_strides; + + m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides; + m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; + m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; + + m_planePaddingTop = base_mapper.m_planePaddingTop; + m_rowPaddingTop = base_mapper.m_rowPaddingTop; + m_colPaddingLeft = base_mapper.m_colPaddingLeft; + + m_outputPlanesRows = base_mapper.m_outputPlanesRows; + + m_fastNumPatches = base_mapper.m_fastNumPatches; + m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride; + m_fastInputRowStride = base_mapper.m_fastInputRowStride; + m_fastInputColStride = base_mapper.m_fastInputColStride; + m_fastRowStride = base_mapper.m_fastRowStride; + m_fastColStride = base_mapper.m_fastColStride; + m_fastOutputPlanes = base_mapper.m_fastOutputPlanes; + m_fastOutputRows = base_mapper.m_fastOutputRows; + m_fastOutputCols = base_mapper.m_fastOutputCols; + m_fastDimZero = base_mapper.m_fastDimZero; + m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows; + } + + // If true, turns off some optimizations for loading packets since the image + // patches are "non-standard" such as there are non-trivial strides or + // inflations in the input. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_in_plane_strides != 1 || m_in_row_strides != 1 || + m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || + m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the coefficient at the patchIndex location instead of the usual + // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the + // gpu code. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the packet at the patchIndex location instead of the usual m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE const TensorEvaluator& impl() const { + return m_impl; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_impl.template packet(inputIndex); + } + + private: + friend class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>; + + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset * m_in_col_strides; + const Index origInputCol = + (m_patch_col_inflate_strides == 1) + ? inputCol + : ((inputCol >= 0) ? 
(inputCol / m_fastInputColStride) : 0); + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset * m_in_row_strides; + const Index origInputRow = + (m_patch_row_inflate_strides == 1) + ? inputRow + : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; + const Index origInputPlane = + (m_patch_plane_inflate_strides == 1) + ? inputPlane + : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0); + + if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || + origInputCol >= m_inputCols || origInputRow >= m_inputRows || + origInputPlane >= m_inputPlanes || + (inputCol != origInputCol * m_patch_col_inflate_strides) || + (inputRow != origInputRow * m_patch_row_inflate_strides) || + (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + origInputPlane * m_planeInputStride + + origInputRow * m_rowInputStride + + origInputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + eigen_assert(!nonStandardPatches()); + + // Find the offset of the element wrt the location of the first element. 
+ const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset; + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset; + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || + inputRow >= m_inputRows || inputPlane < 0 || + inputPlane >= m_inputPlanes) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + if (nonStandardPatches()) { + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } else { + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. 
+ + const Index patchOffsets[2] = { + patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; + + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + eigen_assert(colOffsets[0] <= colOffsets[1]); + + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputCols[0] == inputCols[1]) { + const Index rowOffsets[2] = { + (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, + (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; + eigen_assert(rowOffsets[0] <= rowOffsets[1]); + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputRows[0] == inputRows[1]) { + const Index planeOffsets[2] = { + patchOffsets[0] - colOffsets[0] * m_colStride - + rowOffsets[0] * m_rowStride, + patchOffsets[1] - colOffsets[1] * m_colStride - + rowOffsets[1] * m_rowStride}; + eigen_assert(planeOffsets[0] <= planeOffsets[1]); + const Index inputPlanes[2] = {planeIndex + planeOffsets[0], + planeIndex + planeOffsets[1]}; + + if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = + depth + inputPlanes[0] * m_planeInputStride + + inputRows[0] * m_rowInputStride + + inputCols[0] * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + } + } + } + + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index 
packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + eigen_assert(!nonStandardPatches()); + eigen_assert((patchDepth() % packetSize) == 0); + + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset; + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset; + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || + inputCol >= m_inputCols || inputRow >= m_inputRows || + inputPlane >= m_inputPlanes) { + return internal::pset1(Scalar(0)); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, + Index colIndex, Index otherIndex) const { + const int packetSize = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX + typename internal::remove_const::type values[packetSize]; + for (int i = 0; i < packetSize; ++i) { + values[i] = + loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); + } + Packet rslt = internal::pload(values); + return rslt; + } + + // Precompute the indices (plane, row, col, other) of the first element of + // the given patch index, within the output tensor of the TensorVolumePatchOp. 
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( + Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, + Index& otherIndex) const { + const int NumInputDims = array_size< + typename TensorEvaluator::Dimensions>::value; + + // Check if patchIndex might contain batch and other dimensions. + otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; + + // Compute index of the patch within the batch (and other dimensions). + const Index patch3DIndex = (NumInputDims == 4) + ? patchIndex + : (patchIndex - otherIndex * m_num_patches); + + otherIndex *= m_patchInputStride; + + colIndex = patch3DIndex / m_fastOutputPlanesRows; + rowIndex = + (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; + planeIndex = + patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; + + colIndex = colIndex * m_col_strides - m_colPaddingLeft; + rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; + planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; + } + + Index m_patch_depth; // number of channels in the patch + Index m_patch_planes; // number of planes in the patch + Index m_patch_rows; // number of rows in the patch + Index m_patch_cols; // number of columns in the patch + Index m_num_patches; // number of patches to extract + + // Strides for the output tensor. 
+ Index m_rowStride; + Index m_colStride; + Index m_patchStride; + Index m_otherStride; + + Index m_planeInputStride; // Plane stride in the input tensor + Index m_rowInputStride; // Row stride in the input tensor + Index m_colInputStride; // Col stride in the input tensor + Index m_patchInputStride; // Patch stride in the input tensor + Index m_otherInputStride; + + Index m_inputDepth; // Depth of the input tensor + Index m_inputPlanes; // Number of planes in the input tensor + Index m_inputRows; // Number of rows in the input tensor + Index m_inputCols; // Number of cols in the input tensor + + Index m_outputPlanes; // Number of output planes + Index m_outputRows; // Number of output rows + Index m_outputCols; // Number of output cols + Index m_outputPlanesRows; // Cached outputPlanes * outputRows. + + Index m_plane_strides; // User specified plane stride + Index m_row_strides; // User specified row stride + Index m_col_strides; // User specified col stride + + // User specified plane/row/col atrous convolution strides. + Index m_in_plane_strides; + Index m_in_row_strides; + Index m_in_col_strides; + + // User specified plane/row/col inflation strides in the image patch. + Index m_patch_plane_inflate_strides; + Index m_patch_row_inflate_strides; + Index m_patch_col_inflate_strides; + + Index m_planePaddingTop; // Plane padding + Index m_rowPaddingTop; // Row padding + Index m_colPaddingLeft; // Column padding + + // Fast representation of various divisors. 
+ internal::TensorIntDivisor m_fastNumPatches; + + internal::TensorIntDivisor m_fastInputPlaneStride; + internal::TensorIntDivisor m_fastInputRowStride; + internal::TensorIntDivisor m_fastInputColStride; + + internal::TensorIntDivisor m_fastRowStride; + internal::TensorIntDivisor m_fastColStride; + + internal::TensorIntDivisor m_fastDimZero; // aka output depth + internal::TensorIntDivisor m_fastOutputPlanes; + internal::TensorIntDivisor m_fastOutputRows; + internal::TensorIntDivisor m_fastOutputCols; + internal::TensorIntDivisor m_fastOutputPlanesRows; + + const TensorEvaluator m_impl; +}; + +template +class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + ParentMapper; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), + m_depth_offset(vert_offset), + m_col_offset(horiz_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const Self& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper.m_base_mapper), + m_depth_offset(vert_offset + base_mapper.m_depth_offset), + m_col_offset(horiz_offset + base_mapper.m_col_offset) { + 
m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, + Index j) const { + return m_base_mapper(i + m_depth_offset, j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, + Index j) const { + return m_base_mapper.template loadPacket(i + m_depth_offset, + j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar + loadCoeffStandard(Index i) const { + return m_base_mapper.loadCoeffStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { + return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + loadPacketStandard(Index i) const { + return m_base_mapper.loadPacketStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + template + EIGEN_DEVICE_FUNC bool aligned(Index) const { + return false; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_base_mapper.nonStandardPatches(); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { + return m_base_mapper.m_patch_depth; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { + return m_base_mapper.m_patch_planes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { + return m_base_mapper.m_patch_rows; + } + EIGEN_DEVICE_FUNC + 
EIGEN_ALWAYS_INLINE Index patchCols() const { + return m_base_mapper.m_patch_cols; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template packet(inputIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { + const Index p = m_planeIndex + plane; + return p < 0 || p >= m_base_mapper.m_inputPlanes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { + const Index r = m_rowIndex + row; + return r < 0 || r >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { + const Index c = m_colIndex + col; + return c < 0 || c >= m_base_mapper.m_inputCols; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, + const Index col) const { + const Index p = m_planeIndex + plane; + const Index r = m_rowIndex + row; + const Index c = m_colIndex + col; + return p * m_base_mapper.m_planeInputStride + + r * m_base_mapper.m_rowInputStride + + c * m_base_mapper.m_colInputStride + m_otherIndex; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index planeOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + const Index planeOffset = patchOffset - + colOffset * m_base_mapper.m_colStride - + rowOffset * m_base_mapper.m_rowStride; + return planeOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index rowOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + 
m_base_mapper.m_fastRowStride; + return rowOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return colOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index depthOffset() const { + const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth(); + return patchOffset; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper + getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); + } + + private: + const ParentMapper& m_base_mapper; + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix + + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see computeBaseIndices(...) for details). + Index m_planeIndex; + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; +}; + +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted volume patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// +// Packing yields row major output (A0 beside A1 in memory): +// A0 A1 A2 A3 A4 A5 A6 A7 +// B0 B1 B2 B3 B4 B5 B6 B7 +// C0 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) nr - number of registers along the 'n' dimension. +// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix +// Multiplication" paper. 
+template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + typedef typename packet_traits::type Packet; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if ((packet_size % 4) == 0 && !non_standard_patches) { + const Index patch_depth = rhs.patchDepth(); + + if ((patch_depth % packet_size) == 0) { + const Index patch_cols = rhs.patchCols(); + const Index patch_rows = rhs.patchRows(); + const Index patch_planes = rhs.patchPlanes(); + + const Index startCol = rhs.colOffset(); + const Index max_cols = std::min( + Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) + + startCol, + patch_cols); + + for (Index c = startCol; c < max_cols; ++c) { + eigen_assert(k < peeled_k); + + const Index startRow = (c == startCol) ? 
rhs.rowOffset() : 0; + const Index max_rows = std::min( + Eigen::divup( + peeled_k - c * patch_rows * patch_planes * patch_depth, + patch_planes * patch_depth) + + startRow, + patch_rows); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = startRow; r < max_rows; ++r) { + eigen_assert(k < peeled_k); + + const Index startPlane = + ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0; + const Index max_planes = std::min( + Eigen::divup( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth, // row + patch_depth) + + startPlane, + patch_planes); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = startPlane; p < max_planes; ++p) { + eigen_assert(k < peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index startDepth = + ((c == startCol) && (r == startRow) && (p == startPlane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = std::min( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth - // row + p * patch_depth + // plane + startDepth, + patch_depth); + eigen_assert((max_depth - startDepth) % packet_size == 0); + + for (Index d = startDepth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel; + kernel.packet[0] = pad0 ? 
pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel.packet[2] = pad2 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel.packet[3] = pad3 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel; + kernel.packet[0] = dm0.loadPacketFast(k); + kernel.packet[1] = dm1.loadPacketFast(k); + kernel.packet[2] = dm2.loadPacketFast(k); + kernel.packet[3] = dm3.loadPacketFast(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel; + kernel.packet[0] = dm0.loadPacketStandard(k); + kernel.packet[1] = dm1.loadPacketStandard(k); + kernel.packet[2] = dm2.loadPacketStandard(k); + kernel.packet[3] = dm3.loadPacketStandard(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } + } + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = 
dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Template specialization for packet_size = 2. We must special-case packet +// blocks with nr > packet_size, e.g. PacketBlock. +template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + typedef typename packet_traits::type Packet; + + const int packet_size = 2; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if (!non_standard_patches) { + const Index patch_depth = rhs.patchDepth(); + + if ((patch_depth % packet_size) == 0) { + const Index patch_cols = rhs.patchCols(); + const Index patch_rows = rhs.patchRows(); + const Index patch_planes = 
rhs.patchPlanes(); + + const Index startCol = rhs.colOffset(); + const Index max_cols = std::min( + Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) + + startCol, + patch_cols); + + for (Index c = startCol; c < max_cols; ++c) { + eigen_assert(k < peeled_k); + + const Index startRow = (c == startCol) ? rhs.rowOffset() : 0; + const Index max_rows = std::min( + Eigen::divup( + peeled_k - c * patch_rows * patch_planes * patch_depth, + patch_planes * patch_depth) + + startRow, + patch_rows); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = startRow; r < max_rows; ++r) { + eigen_assert(k < peeled_k); + + const Index startPlane = + ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0; + const Index max_planes = std::min( + Eigen::divup( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth, // row + patch_depth) + + startPlane, + patch_planes); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = startPlane; p < max_planes; ++p) { + eigen_assert(k < peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index startDepth = + ((c == startCol) && (r == startRow) && (p == startPlane)) + ? 
rhs.depthOffset() + : 0; + const Index max_depth = std::min( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth - // row + p * patch_depth + // plane + startDepth, + patch_depth); + eigen_assert((max_depth - startDepth) % packet_size == 0); + + for (Index d = startDepth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = pad0 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = pad2 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = pad3 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = dm0.loadPacketFast(k); + kernel0.packet[1] = dm1.loadPacketFast(k); + kernel1.packet[0] = dm2.loadPacketFast(k); + kernel1.packet[1] = dm3.loadPacketFast(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = dm0.loadPacketStandard(k); + kernel0.packet[1] = dm1.loadPacketStandard(k); + kernel1.packet[0] = dm2.loadPacketStandard(k); + kernel1.packet[1] = dm3.loadPacketStandard(k); + ptranspose(kernel0); + ptranspose(kernel1); + 
pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } + } + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Special case for non-vectorized types such as float16 (packet_size = 1). +template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + const Index packet_cols4 = (cols / 4) * 4; + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = 
rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + if (!rhs.nonStandardPatches()) { + for (Index k = 0; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (Index k = 0; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +} // namespace internal + /** CuboidConvolution * \ingroup CXX11_NeuralNetworks_Module * -- cgit v1.2.3 From 29c3c08f23e14eaff1dbd7a3c66139314c045574 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 11 Sep 2018 11:47:14 -0700 Subject: Convert NumPy arrays to Tensors when they're arguments to a defun Previously they were counted in the cache key as if they were Tensors, but were not fed as placeholders, leading to stale values when the trace was reused. There is an 8%ish performance impact from the tuple comprehension on the defun no-signature-call microbenchmarks. I don't see a much faster way to do this without rewriting it in C, but I'm open to ideas. I've avoided re-packing the input tuple unless there's actually a numpy array, so this CL will slow down NumPy defun calls more (in addition to the convert_to_tensor overhead). 
After: entry { name: "MicroBenchmarks.benchmark_defun_with_signature" iters: 30000 wall_time: 134.219272931 extras { key: "examples_per_sec" value { double_value: 7450.49483699 } } } entry { name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs" iters: 30000 wall_time: 142.88717111 extras { key: "examples_per_sec" value { double_value: 6998.52892485 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature" iters: 30000 wall_time: 76.2096961339 extras { key: "examples_per_sec" value { double_value: 13121.6898994 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs" iters: 30000 wall_time: 81.8309704463 extras { key: "examples_per_sec" value { double_value: 12220.3121208 } } } Before: entry { name: "MicroBenchmarks.benchmark_defun_with_signature" iters: 30000 wall_time: 129.392266273 extras { key: "examples_per_sec" value { double_value: 7728.43716862 } } } entry { name: "MicroBenchmarks.benchmark_defun_with_signature_and_kwargs" iters: 30000 wall_time: 141.65956974 extras { key: "examples_per_sec" value { double_value: 7059.1771656 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature" iters: 30000 wall_time: 70.6333637238 extras { key: "examples_per_sec" value { double_value: 14157.6154282 } } } entry { name: "MicroBenchmarks.benchmark_defun_without_signature_and_with_kwargs" iters: 30000 wall_time: 78.4090677897 extras { key: "examples_per_sec" value { double_value: 12753.6269489 } } } PiperOrigin-RevId: 212491803 --- tensorflow/python/eager/function.py | 21 +++++++++++++++++---- tensorflow/python/eager/function_test.py | 9 +++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 03f12139f6..8c30550708 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import execute from tensorflow.python.eager import tape 
from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops @@ -879,9 +880,6 @@ def _encode_arg(arg): _TensorType(arg.values.dtype, arg.values._shape_tuple()), _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), ]) - elif isinstance(arg, np.ndarray): - tensor = ops.convert_to_tensor(arg) - return _TensorType(tensor.dtype, tensor._shape_tuple()) # pylint: enable=protected-access elif isinstance(arg, (list, tuple)): return tuple([_encode_arg(elem) for elem in arg]) @@ -1089,6 +1087,17 @@ class PolymorphicFunction(object): # opposed to named arguments called in a keyword-like fashion. kwds.pop(arg) inputs = args + _deterministic_dict_values(arg_indices_to_values) + flat_inputs = nest.flatten(inputs) + + # Check for NumPy arrays in arguments and convert them to Tensors. 
+ need_packing = False + for index, value in enumerate(flat_inputs): + if isinstance(value, np.ndarray): + flat_inputs[index] = constant_op.constant(value) + need_packing = True + if need_packing: + inputs = nest.pack_sequence_as(structure=inputs, + flat_sequence=flat_inputs) if self._input_signature is None: return inputs, kwds else: @@ -1098,7 +1107,6 @@ class PolymorphicFunction(object): except (ValueError, TypeError): raise ValueError("Structure of Python function inputs does not match " "input_signature.") - flat_inputs = nest.flatten(inputs) if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs): raise ValueError("When input_signature is provided, all inputs to " "the Python function must be Tensors.") @@ -1271,6 +1279,11 @@ def defun(func=None, input_signature=None): tracing the execution of `f(*args, **kwargs)`; this graph is bound to an input signature inferred from `(*args, **kwargs)` and cached for future reuse. + NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects + before being passed to `f`, and are treated as Tensors for caching. This + allows a function to be called multiple times with NumPy arrays having + different values but the same shape and dtype without re-tracing each time. + `tf.contrib.eager.defun` caches graphs for your convenience, letting you define TensorFlow functions without explicitly specifying their signatures. 
However, this policy is conservative and potentially expensive; for example, diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 92254a2c00..6507bc6d71 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -22,6 +22,8 @@ import functools from multiprocessing.pool import ThreadPool import sys +import numpy + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop @@ -314,6 +316,7 @@ class FunctionTest(test.TestCase): def testDefunNumpyArraysConvertedToTensors(self): def f(x): + self.assertIsInstance(x, ops.Tensor) return x x = random_ops.random_uniform([2, 2]).numpy() @@ -327,6 +330,12 @@ class FunctionTest(test.TestCase): # shouldn't trigger another function definition. self.assertEqual(len(defined._function_cache), 1) + # Test that the numpy array is properly an argument to the graph function. + self.assertEqual(1., defined(numpy.ones([])).numpy()) + self.assertEqual(0., defined(numpy.zeros([])).numpy()) + self.assertEqual(1., defined(array_ops.ones([])).numpy()) + self.assertEqual(0., defined(array_ops.zeros([])).numpy()) + def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) -- cgit v1.2.3 From a9e73ddb3d40514af4144278f6450e5c1c806f8b Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 11 Sep 2018 12:03:48 -0700 Subject: Make exhaustive_f32_elementwise_op_test build again and mark it as broken It was not running as part of TAP and there have been some regressions. Mark it as broken while we figure out what's going on to unblock b/114790989. 
PiperOrigin-RevId: 212494775 --- tensorflow/compiler/xla/tests/BUILD | 1 + .../compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d0bda45cf8..30e3077edb 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -647,6 +647,7 @@ xla_test( ], shard_count = 48, tags = [ + "broken", "manual", "notap", ], diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 738f2600d4..51b50d456e 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -45,22 +45,22 @@ class ExhaustiveF32ElementwiseOpTest i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. 
- input_literal->Set({i - begin}, 0.0f); + input_literal.Set({i - begin}, 0.0f); } else { - input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, tensorflow::bit_cast(i)); } } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal->Get({i}))); + expected_result.push_back(evaluate_op(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, -- cgit v1.2.3 From 1025b0c68b819a7292b51e51bbf7badc8818f286 Mon Sep 17 00:00:00 2001 From: Olivia Nordquist Date: Tue, 11 Sep 2018 12:18:34 -0700 Subject: disable failing test PiperOrigin-RevId: 212497382 --- tensorflow/contrib/distributions/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 97c53ae2b9..9aadc634da 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -166,6 +166,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], + tags = ["notap"], ) cuda_py_test( -- cgit v1.2.3 From dad6912b530c92b2f362f1cc2a83006a22f604b6 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Tue, 11 Sep 2018 13:12:21 -0700 Subject: Handle model deserialization when output tensor shape is NULL. In flatbuffers, vectors default to NULL. Original change by alanchiao@. 
PiperOrigin-RevId: 212506392 --- tensorflow/contrib/lite/model.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 241865b3d8..6311d60b91 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -177,6 +177,11 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { namespace { template std::vector FlatBufferIntArrayToVector(T* flat_array) { + // Initialize shape of tensors with null shape. Empty vectors are converted + // to nullptr for models that are constructed via flatbuffers::Pack. + if (flat_array == nullptr) { + return {}; + } std::vector ret(flat_array->Length()); for (int i = 0; i < flat_array->Length(); i++) { ret[i] = flat_array->Get(i); -- cgit v1.2.3 From 418c7258687166fc79a04f5a8c903c782a8ad295 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Sep 2018 13:12:57 -0700 Subject: Optimize Spatial&Cuboid backward kernel convolutions. Without shuffle TensorExecutor uses optimized (specialized) gemm_pack_rhs to pack memory before contraction. Custom rhs packer is much faster than contracting by inner dimension with default packer. 1. CuboidConvolutionBwdKernel: ~10x-25x speedup 2. 
SpatialConvolutionBwdKernel: ~2x-10x speedup PiperOrigin-RevId: 212506483 --- .../kernels/eigen_backward_cuboid_convolutions.h | 44 ++++++++++------------ .../kernels/eigen_backward_spatial_convolutions.h | 41 +++++++++----------- 2 files changed, 38 insertions(+), 47 deletions(-) diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index 27918b410b..f12c8d943d 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -239,8 +239,8 @@ CuboidConvolutionBackwardInput( } } - // We will contract along the fused dimension that contains the kernelFilters, - // kernelPlanes, kernelRows and kernelCols. + // We will contract along the collapsed dimension that contains the + // kernelFilters, kernelPlanes, kernelRows and kernelCols. array, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) @@ -331,24 +331,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< const TensorReshapingOp< const DSizes::Index, 2>, const OutputBackward>, - const TensorShufflingOp< - const array::Index, - 2>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorVolumePatchOp > > > >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp > > >, TensorReshapingOp< const DSizes::Index, 5>, const TensorContractionOp< const array::Index>, 1>, - const TensorShufflingOp< - const array::Index, - 2>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorVolumePatchOp > >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp >, const TensorReshapingOp< const DSizes::Index, 2>, const OutputBackward> > > >::type @@ -458,12 +452,16 @@ CuboidConvolutionBackwardKernel( eigen_assert(output_dims[0] == pre_contract_dims[0]); } - array shuffle_dims; - shuffle_dims[0] = 1; - shuffle_dims[1] = 0; - + // We will contract 
along the collapsed dimension that contains the + // outputCols, outputRows, outputPlanes and OTHERS. array, 1> contract_dims; - contract_dims[0] = IndexPair(1, 0); + if (isColMajor) { + // col-major: output_backward.contract(input.patches) + contract_dims[0] = IndexPair(1, 1); + } else { + // row-major: input.patches.contract(output_backward) + contract_dims[0] = IndexPair(0, 0); + } DSizes kernel_dims; if (isColMajor) { @@ -489,8 +487,7 @@ CuboidConvolutionBackwardKernel( strideRows, strideCols, 1, 1, 1, padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) - .reshape(pre_contract_dims) - .shuffle(shuffle_dims), + .reshape(pre_contract_dims), contract_dims) .reshape(kernel_dims), input @@ -499,7 +496,6 @@ CuboidConvolutionBackwardKernel( padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .shuffle(shuffle_dims) .contract(output_backward.reshape(output_dims), contract_dims) .reshape(kernel_dims)); } diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index 8d06107553..960920c55b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput( } } - // We will contract along the fused dimension that contains the kernelFilters, - // the kernelRows and the kernelCols. + // We will contract along the collapsed dimension that contains the + // kernelFilters, the kernelRows and the kernelCols. 
array, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) @@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< const TensorReshapingOp< const DSizes::Index, 2>, const OutputBackward>, - const TensorShufflingOp< - const array::Index, - 2>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorImagePatchOp > > > >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp > > >, TensorReshapingOp< const DSizes::Index, 4>, const TensorContractionOp< const array::Index>, 1>, - const TensorShufflingOp< - const array::Index, - 2>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorImagePatchOp > >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, const TensorReshapingOp< const DSizes::Index, 2>, const OutputBackward> > > >::type @@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel( eigen_assert(output_dims[0] == pre_contract_dims[0]); } - array shuffle_dims; - shuffle_dims[0] = 1; - shuffle_dims[1] = 0; - + // We will contract along the collapsed dimension that contains the + // outputCols, outputRows and OTHERS. 
array, 1> contract_dims; - contract_dims[0] = IndexPair(1, 0); + if (isColMajor) { + // col-major: output_backward.contract(input.patches) + contract_dims[0] = IndexPair(1, 1); + } else { + // row-major: input.patches.contract(output_backward) + contract_dims[0] = IndexPair(0, 0); + } // After the contraction, the kernel will have the desired shape // out_depth X in_shape X kernel_rows X kernel_cols @@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel( kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride, 1, 1, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) - .reshape(pre_contract_dims) - .shuffle(shuffle_dims), + .reshape(pre_contract_dims), contract_dims) .reshape(kernel_dims), input @@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel( padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) .reshape(pre_contract_dims) - .shuffle(shuffle_dims) .contract(output_backward.reshape(output_dims), contract_dims) .reshape(kernel_dims)); } -- cgit v1.2.3 From da99f7ca018d4916447d7b984d9d65be1a9615a8 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 11 Sep 2018 13:46:29 -0700 Subject: Make control_flow_ops._ENABLE_COND_V2 public. Note this is not part of the official public API, but we do allow other modules to modify this value (e.g. in tests). 
PiperOrigin-RevId: 212512883 --- tensorflow/python/framework/test_util.py | 10 ++- .../kernel_tests/control_flow_ops_py_test.py | 72 +++++++++++----------- tensorflow/python/ops/control_flow_ops.py | 4 +- 3 files changed, 42 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b33cc8f544..6a2c897f3f 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -413,15 +413,13 @@ def enable_cond_v2(fn): The wrapped function """ - # pylint: disable=protected-access def wrapper(*args, **kwargs): - prev_value = control_flow_ops._ENABLE_COND_V2 - control_flow_ops._ENABLE_COND_V2 = True + prev_value = control_flow_ops.ENABLE_COND_V2 + control_flow_ops.ENABLE_COND_V2 = True try: fn(*args, **kwargs) finally: - control_flow_ops._ENABLE_COND_V2 = prev_value - # pylint: enable=protected-access + control_flow_ops.ENABLE_COND_V2 = prev_value return wrapper @@ -438,7 +436,7 @@ def with_cond_v2(cls): Returns: cls with new test methods added """ - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return cls for name, value in cls.__dict__.copy().items(): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index eac97af4ed..bdf7e0e4a0 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -333,7 +333,7 @@ class ControlFlowTest(test.TestCase): res.eval(feed_dict={data: 1.0}) def testCondBool(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296297") values = constant_op.constant(10) @@ -384,7 +384,7 @@ class ControlFlowTest(test.TestCase): sess.run(r, feed_dict={t: 3}) def testCondIndexedSlices(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296180") 
with self.test_session(): @@ -402,7 +402,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(0, ind) def testCondSparseTensor(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296161 (SparseTensors)") with self.test_session(): @@ -422,7 +422,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r.values.get_shape(), (2,)) def testCondResource(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -438,7 +438,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval()) def testCondIndexedSlicesDifferentTypes(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113293074") with self.test_session(): @@ -484,14 +484,14 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(11, result) def testCond_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") self._testCond_1(use_gpu=False) self._testCond_1(use_gpu=True) def testCond_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -503,7 +503,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(9, result) def testCond_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -518,7 +518,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(12, result) def testCond_4(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): @@ -556,7 +556,7 @@ class ControlFlowTest(test.TestCase): 
self.assertAllEqual(4, count.eval()) def testCond_6(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -583,7 +583,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([11, 12], sess.run(r)) def testCondRef(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -599,7 +599,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([2.0], r.eval()) def testCondWithControl(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/79881896") with self.test_session() as sess: @@ -641,7 +641,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([1.0], sess.run(merged_op.output)) def testCondSwitchIdentity(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") # Make sure the recv identity is not removed by optimization. @@ -658,7 +658,7 @@ class ControlFlowTest(test.TestCase): sess.run(r) def testCondRecvIdentity(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") # Make sure the switch identity is not removed by optimization. 
@@ -677,7 +677,7 @@ class ControlFlowTest(test.TestCase): sess.run(r) def testCondGrad_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113346829 (gpu failure)") graph = ops.Graph() @@ -706,7 +706,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) def testCondGrad_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/110550782 (gradient w.r.t external variable)") with self.test_session(): @@ -741,7 +741,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(1.0, result.eval()) def testCondGrad_Gather(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113327884") with self.test_session() as sess: @@ -916,7 +916,7 @@ class ControlFlowTest(test.TestCase): _ = gradients_impl.gradients(loop_with_maxiter, v) def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") v = constant_op.constant(1.0) @@ -1375,7 +1375,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(10, sess.run(r, {b: True})) def testWhileCondWithControl(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") # Ensure that no control edges by an outer control dependency context are @@ -1392,7 +1392,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(0, sess.run(loop)) def testWhileCondWithControl_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): @@ -1417,7 +1417,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(65536.0, v.eval()) def testWhileCondExitControl(self): - if control_flow_ops._ENABLE_COND_V2: + if 
control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") with self.test_session(): @@ -1443,7 +1443,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(99, v.eval()) def testCondWhile_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -1456,7 +1456,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testCondWhile_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -1469,7 +1469,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def _testCondWhile_3(self, use_gpu): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") with self.test_session(use_gpu=use_gpu) as sess: @@ -1498,7 +1498,7 @@ class ControlFlowTest(test.TestCase): self._testCondWhile_3(use_gpu=True) def testWhileCond_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1516,7 +1516,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testWhileCond_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1527,7 +1527,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testWhileCond_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1872,7 +1872,7 @@ class ControlFlowTest(test.TestCase): self._testWhileGrad_Mul(use_gpu=True, p_iters=10) def _testNestedWhileCondWhileGrad(self, use_gpu): - if 
control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(use_gpu=use_gpu): @@ -1913,7 +1913,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(216.0, r[0].eval()) def testWhileGradInCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/110550782 (gradient w.r.t external variable)") with self.test_session(): @@ -1964,7 +1964,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) def testCondGradInNestedWhiles(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113346829 (gpu failure)") def outer_body(i, x): @@ -2280,7 +2280,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(1024.0, r.eval()) def testWhileCondGrad_Simple(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") self._testWhileCondGrad_Simple(use_gpu=False) @@ -2633,7 +2633,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(5.0, result.eval()) def testOneValueCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -2651,7 +2651,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual([2], i.eval(feed_dict={c: 0})) def testExampleCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -2669,7 +2669,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) def testCase(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") with self.test_session(): @@ -2724,7 +2724,7 @@ 
class ControlFlowTest(test.TestCase): self.assertAllEqual(r6.eval(), 0) def testCaseSideEffects(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") with self.test_session() as sess: @@ -2762,7 +2762,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1]) def testOneOpCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index e3c1aa3d5a..3c915b055a 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -61,7 +61,7 @@ from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export -_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" +ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" # We override the 'tuple' for a control flow op, so we keep python's @@ -2026,7 +2026,7 @@ def cond(pred, ``` """ - if _ENABLE_COND_V2: + if ENABLE_COND_V2: return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name) # We needed to make true_fn/false_fn keyword arguments for -- cgit v1.2.3 From 2832a4f9e125c00b64614880fb08376ee03fa2da Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 11 Sep 2018 14:04:27 -0700 Subject: Use Eigen::CuboidConvolutionBackwardInput in Conv3DBackpropInput. Instead of multiple primitive Eigen ops in Conv3DBackpropInput, call directly into the ex-NeuralNetworks module's function CuboidConvolutionBackwardInput. Modest ~10% latency improvement and ~15-20% peak memory reduction. 
PiperOrigin-RevId: 212516586 --- tensorflow/core/kernels/conv_3d.h | 22 ++++++++++++ tensorflow/core/kernels/conv_grad_ops_3d.cc | 53 +++++++---------------------- 2 files changed, 35 insertions(+), 40 deletions(-) diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h index 02e3655ad1..e5054e062e 100644 --- a/tensorflow/core/kernels/conv_3d.h +++ b/tensorflow/core/kernels/conv_3d.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_CONV_3D_H_ #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" #include "tensorflow/core/kernels/eigen_cuboid_convolution.h" namespace tensorflow { @@ -28,6 +29,10 @@ namespace functor { template struct CuboidConvolution; +// Backward input pass for the cuboid convolution. +template +struct CuboidConvolutionBackwardInput; + typedef Eigen::ThreadPoolDevice CPUDevice; template @@ -42,6 +47,23 @@ struct CuboidConvolution { } }; +template +struct CuboidConvolutionBackwardInput { + void operator()(const CPUDevice& d, + typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. 
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, + input_backward.dimension(3), // input_planes + input_backward.dimension(2), // input_rows + input_backward.dimension(1), // input_cols + stride_cols, stride_rows, stride_planes); + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 15f1bf9aba..ec7c02ac2b 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -201,50 +201,23 @@ class Conv3DBackpropInputOp : public OpKernel { input_shape = context->input(0).shape(); } EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); - Eigen::array, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - // Fill out a padded out_backprop. - TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows, - padded_out_cols, out_depth}); - Tensor padded_output; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes no_op_shuffle{0, 1, 2, 3, 4}; - Eigen::DSizes eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle()( - context->eigen_device(), out_backprop.tensor(), - eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor()); - const Tensor& padded_output_cref = padded_output; - - // Fill a new "reverted" filter. We need to transpose the in_depth and - // out_depth for the filter and reverse the planes, rows and cols. 
- TensorShape r_filter_shape( - {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth}); - Tensor r_filter; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), - r_filter_shape, &r_filter)); - Eigen::DSizes filter_order{0, 1, 2, 4, 3}; - Eigen::array filter_rev_dims{true, true, true, false, false}; - functor::ShuffleAndReverse()( - context->eigen_device(), filter.tensor(), filter_order, - filter_rev_dims, r_filter.tensor()); - const Tensor& r_filter_cref = r_filter; - - // Now we can call conv_3d directly. - functor::CuboidConvolution()( - context->eigen_device(), in_backprop->tensor(), - padded_output_cref.tensor(), r_filter_cref.tensor(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); + // There is no need to explicitly compute padding values (and pad + // out_backprop), because Eigen uses the same padding inference mechanism as + // Tensorflow. + functor::CuboidConvolutionBackwardInput()( + context->eigen_device(), + in_backprop->tensor(), // input_backward + filter.tensor(), // filter + out_backprop.tensor(), // output_backward + // Order of strides will be reversed before passing to Eigen. + static_cast(strides[0]), // stride_planes + static_cast(strides[1]), // stride_rows + static_cast(strides[2])); // stride_cols } private: -- cgit v1.2.3 From b40ab8d8a024bb934f25ebc3f5260b64c5816ef5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 14:05:59 -0700 Subject: Adds generator support directly to Keras's fit, evaluate, and predict. 
PiperOrigin-RevId: 212516939 --- tensorflow/python/keras/engine/training.py | 146 +++++++++++++++++---- tensorflow/python/keras/engine/training_test.py | 51 +++++++ tensorflow/python/keras/engine/training_utils.py | 12 ++ tensorflow/python/keras/utils/data_utils.py | 8 +- tensorflow/python/util/tf_inspect.py | 5 + .../api/golden/v1/tensorflow.keras.-model.pbtxt | 6 +- .../golden/v1/tensorflow.keras.-sequential.pbtxt | 6 +- .../golden/v1/tensorflow.keras.models.-model.pbtxt | 6 +- .../v1/tensorflow.keras.models.-sequential.pbtxt | 6 +- .../api/golden/v2/tensorflow.keras.-model.pbtxt | 6 +- .../golden/v2/tensorflow.keras.-sequential.pbtxt | 6 +- .../golden/v2/tensorflow.keras.models.-model.pbtxt | 6 +- .../v2/tensorflow.keras.models.-sequential.pbtxt | 6 +- 13 files changed, 223 insertions(+), 47 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 49b25e307e..c6749468c8 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -41,6 +41,7 @@ from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine.network import Network +from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import math_ops from tensorflow.python.ops import weights_broadcast_ops @@ -1338,6 +1339,9 @@ class Model(Network): initial_epoch=0, steps_per_epoch=None, validation_steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, **kwargs): """Trains the model for a fixed number of epochs (iterations on a dataset). @@ -1350,19 +1354,23 @@ class Model(Network): - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - A `tf.data` dataset or a dataset iterator. 
Should return a tuple - of either (inputs, targets) or (inputs, targets, sample_weights). + of either `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A generator or `keras.utils.Sequence` returning `(inputs, targets)` + or `(inputs, targets, sample weights)`. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset or dataset - iterator, `y` should not be specified - (since targets will be obtained from the iterator). + tensor targets, or inversely). If `x` is a dataset, dataset + iterator, generator, or `keras.utils.Sequence` instance, `y` should + not be specified (since targets will be obtained from `x`). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors, datasets, or dataset iterators - (since they generate batches). + form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. @@ -1384,7 +1392,8 @@ class Model(Network): on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling. This argument is - not supported when `x` is a dataset or a dataset iterator. + not supported when `x` is a dataset, dataset iterator, generator or + `keras.utils.Sequence` instance. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. @@ -1415,8 +1424,9 @@ class Model(Network): to apply a different weight to every timestep of every sample. 
In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator, instead - provide the sample_weights as the third element of `x`. + supported when `x` is a dataset, dataset iterator, generator, or + `keras.utils.Sequence` instance, instead provide the sample_weights + as the third element of `x`. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). @@ -1430,6 +1440,20 @@ class Model(Network): validation_steps: Only relevant if `steps_per_epoch` is specified. Total number of steps (batches of samples) to validate before stopping. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up + when using process-based threading. If unspecified, `workers` + will default to 1. If 0, will execute the generator on the main + thread. + use_multiprocessing: Boolean. Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. **kwargs: Used for backwards compatibility. Returns: @@ -1446,6 +1470,23 @@ class Model(Network): # TODO(fchollet): this method may be creating reference cycles, which would # lead to accumulating garbage in memory when called in a loop. Investigate. 
+ if data_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments(y, sample_weight) + return self.fit_generator( + x, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch) + # Backwards compatibility if batch_size is None and steps_per_epoch is None: batch_size = 32 @@ -1588,7 +1629,10 @@ class Model(Network): batch_size=None, verbose=1, sample_weight=None, - steps=None): + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): """Returns the loss value & metrics values for the model in test mode. Computation is done in batches. @@ -1602,18 +1646,21 @@ class Model(Network): - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - A `tf.data` dataset or a dataset iterator. + - A generator or `keras.utils.Sequence` instance. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely). - If `x` is a dataset or a dataset iterator, `y` should not be specified - (since targets will be obtained from the iterator/dataset). + If `x` is a dataset, dataset iterator, generator or + `keras.utils.Sequence` instance, `y` should not be specified (since + targets will be obtained from the iterator/dataset). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, datasets, or dataset iterators - (since they generate batches). 
+ form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. sample_weight: Optional Numpy array of weights for @@ -1627,11 +1674,25 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset or a dataset iterator, instead pass + sample weights as the third element of `x`. steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. Ignored with the default value of `None`. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up when using + process-based threading. If unspecified, `workers` will default + to 1. If 0, will execute the generator on the main thread. + use_multiprocessing: Boolean. Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1642,6 +1703,16 @@ class Model(Network): Raises: ValueError: in case of invalid arguments. 
""" + if data_utils.is_generator_or_sequence(x): + training_utils.check_generator_arguments(y, sample_weight) + return self.evaluate_generator( + x, + steps=steps, + verbose=verbose, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 @@ -1688,7 +1759,14 @@ class Model(Network): verbose=verbose, steps=steps) - def predict(self, x, batch_size=None, verbose=0, steps=None): + def predict(self, + x, + batch_size=None, + verbose=0, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): """Generates output predictions for the input samples. Computation is done in batches. @@ -1700,16 +1778,32 @@ class Model(Network): - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - A `tf.data` dataset or a dataset iterator. + - A generator or `keras.utils.Sequence` instance. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, dataset, or dataset iterators - (since they generate batches). + form of symbolic tensors, dataset, dataset iterators, + generators, or `keras.utils.Sequence` instances (since they generate + batches). verbose: Verbosity mode, 0 or 1. steps: Total number of steps (batches of samples) before declaring the prediction round finished. Ignored with the default value of `None`. + max_queue_size: Integer. Used for generator or `keras.utils.Sequence` + input only. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. + workers: Integer. Used for generator or `keras.utils.Sequence` input + only. Maximum number of processes to spin up when using + process-based threading. If unspecified, `workers` will default + to 1. If 0, will execute the generator on the main thread. + use_multiprocessing: Boolean. 
Used for generator or + `keras.utils.Sequence` input only. If `True`, use process-based + threading. If unspecified, `use_multiprocessing` will default to + `False`. Note that because this implementation relies on + multiprocessing, you should not pass non-picklable arguments to + the generator as they can't be passed easily to children processes. + Returns: Numpy array(s) of predictions. @@ -1720,6 +1814,15 @@ class Model(Network): or in case a stateful model receives a number of samples that is not a multiple of the batch size. """ + if data_utils.is_generator_or_sequence(x): + return self.predict_generator( + x, + steps=steps, + verbose=verbose, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 @@ -2071,7 +2174,7 @@ class Model(Network): Arguments: generator: Generator yielding tuples (inputs, targets) or (inputs, targets, sample_weights) - or an instance of Sequence (keras.utils.Sequence) + or an instance of `keras.utils.Sequence` object in order to avoid duplicate data when using multiprocessing. steps: Total number of steps (batches of samples) @@ -2135,9 +2238,8 @@ class Model(Network): Arguments: generator: Generator yielding batches of input samples - or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + or an instance of `keras.utils.Sequence` object in order to + avoid duplicate data when using multiprocessing. steps: Total number of steps (batches of samples) to yield from `generator` before stopping. 
Optional for `Sequence`: if unspecified, will use diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 8938333b1a..380130095b 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1322,6 +1322,57 @@ class TestGeneratorMethods(test.TestCase): workers=0, use_multiprocessing=False) + @tf_test_util.run_in_graph_and_eager_modes + def test_generator_input_to_fit_eval_predict(self): + val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + def custom_generator(): + while True: + yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + inputs = keras.layers.Input(shape=(10,)) + x = keras.layers.Dense(10, activation='relu')(inputs) + outputs = keras.layers.Dense(1, activation='sigmoid')(x) + model = keras.Model(inputs, outputs) + + model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') + model.fit( + custom_generator(), + steps_per_epoch=2, + validation_data=val_data, + epochs=2) + model.evaluate(custom_generator(), steps=2) + model.predict(custom_generator(), steps=2) + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequence_input_to_fit_eval_predict(self): + val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + class CustomSequence(keras.utils.Sequence): + + def __getitem__(self, idx): + return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) + + def __len__(self): + return 2 + + inputs = keras.layers.Input(shape=(10,)) + x = keras.layers.Dense(10, activation='relu')(inputs) + outputs = keras.layers.Dense(1, activation='sigmoid')(x) + model = keras.Model(inputs, outputs) + + model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy') + model.fit(CustomSequence(), validation_data=val_data, epochs=2) + model.evaluate(CustomSequence()) + model.predict(CustomSequence()) + + with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'): + 
model.fit(CustomSequence(), y=np.ones([10, 1])) + + with self.assertRaisesRegexp(ValueError, + '`sample_weight` argument is not supported'): + model.fit(CustomSequence(), sample_weight=np.ones([10, 1])) + class TestTrainingUtils(test.TestCase): diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 898e9223cb..8e9fab81d6 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -797,6 +797,18 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None): 'Received: x=%s, validation_split=%f' % (x, validation_split)) +def check_generator_arguments(y=None, sample_weight=None): + """Validates arguments passed when using a generator.""" + if y is not None: + raise ValueError('`y` argument is not supported when data is' + 'a generator or Sequence instance. Instead pass targets' + ' as the second element of the generator.') + if sample_weight is not None: + raise ValueError('`sample_weight` argument is not supported when data is' + 'a generator or Sequence instance. Instead pass sample' + ' weights as the third element of the generator.') + + def check_steps_argument(input_data, steps, steps_name): """Validates `steps` argument based on input data's type. 
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index d93a7b6afc..b736daa46d 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -40,6 +40,7 @@ from six.moves.urllib.error import URLError from six.moves.urllib.request import urlopen from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -93,6 +94,11 @@ else: from six.moves.urllib.request import urlretrieve +def is_generator_or_sequence(x): + """Check if `x` is a Keras generator type.""" + return tf_inspect.isgenerator(x) or isinstance(x, Sequence) + + def _extract_archive(file_path, path='.', archive_format='auto'): """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. @@ -551,7 +557,7 @@ class OrderedEnqueuer(SequenceEnqueuer): self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda workers, initializer=init_pool, initargs=(seqs,)) else: - # We do not need the init since it's threads. + # We do not need the init since it's threads. 
self.executor_fn = lambda _: ThreadPool(workers) self.workers = workers self.queue = queue.Queue(max_queue_size) diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index 778121e15b..967c872c2a 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -325,6 +325,11 @@ def isfunction(object): # pylint: disable=redefined-builtin return _inspect.isfunction(tf_decorator.unwrap(object)[1]) +def isgenerator(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isgenerator.""" + return _inspect.isgenerator(tf_decorator.unwrap(object)[1]) + + def ismethod(object): # pylint: disable=redefined-builtin """TFDecorator-aware replacement for inspect.ismethod.""" return _inspect.ismethod(tf_decorator.unwrap(object)[1]) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt index d843194ef0..0869de0243 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -151,7 +151,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -159,7 +159,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', 
\'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -219,7 +219,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt index b8e9baca71..20f39fae1e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -156,7 +156,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', 
\'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -164,7 +164,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -228,7 +228,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_classes" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt index 472b9818df..4011719317 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -151,7 +151,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -159,7 +159,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -219,7 +219,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', 
\'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_generator" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt index 937516eff1..8a12ac1ad8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -156,7 +156,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -164,7 +164,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, 
defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -228,7 +228,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_classes" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt index d843194ef0..0869de0243 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -151,7 +151,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -159,7 +159,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', 
\'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -219,7 +219,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt index b8e9baca71..20f39fae1e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -156,7 +156,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', 
\'False\'], " } member_method { name: "evaluate_generator" @@ -164,7 +164,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -228,7 +228,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_classes" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt index 472b9818df..4011719317 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -151,7 +151,7 @@ tf_class { } member_method { name: "evaluate" - argspec: 
"args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -159,7 +159,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit_generator" @@ -219,7 +219,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " 
} member_method { name: "predict_generator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt index 937516eff1..8a12ac1ad8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -156,7 +156,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "evaluate_generator" @@ -164,7 +164,7 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], " } 
member_method { name: "fit_generator" @@ -228,7 +228,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], " } member_method { name: "predict_classes" -- cgit v1.2.3 From 72410969ca8dd7f1be48672c6cb943940edb9f31 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 11 Sep 2018 14:10:31 -0700 Subject: Update defun to support extra params as function attributes. PiperOrigin-RevId: 212517784 --- tensorflow/python/eager/function.py | 79 ++++++++++++++++++++++++++++++-- tensorflow/python/eager/function_test.py | 61 ++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 8c30550708..348bf4650f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -27,6 +27,7 @@ import threading import numpy as np import six +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -60,6 +61,10 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access +# TODO(scottzhu): Update this to allow arbitrary attribute names in future. 
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" + + def _create_substitute_placeholder(value, name, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" # Note: setting ops.control_dependencies(None) ensures we always put @@ -100,6 +105,44 @@ def _get_device_functions(ctx, graph): return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access +def _parse_func_attrs(attributes): + """Convert the keyword arguments into function_def attributes. + + Currently only support primitive types: bool, int, float and string. + + Args: + attributes: the dictionary of attributes. + Returns: + A dict of attributes where the key is the name of attribute and the value + is the AttrValue proto. + Raises: + ValueError: If the kwargs contains unwhitelisted name or unsupported value + types. + """ + attrs = {} + for key, value in attributes.items(): + if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX): + raise ValueError("Attribute name is not whitelisted. " + "Whitelisted: prefix %s, got: %s" % + (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key)) + + if isinstance(value, attr_value_pb2.AttrValue): + attrs[key] = value + # bool type check has to happen before int since bool is a subclass of int. + elif isinstance(value, bool): + attrs[key] = attr_value_pb2.AttrValue(b=value) + elif isinstance(value, int): + attrs[key] = attr_value_pb2.AttrValue(i=value) + elif isinstance(value, float): + attrs[key] = attr_value_pb2.AttrValue(f=value) + elif isinstance(value, str): + attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) + else: + raise ValueError("Unsupported attribute type for %s with type %s" % + (key, type(value))) + return attrs + + class FuncGraph(ops.Graph): """Graph representing a function body. 
@@ -486,7 +529,7 @@ class Function(object): self._num_outputs = len(self._func_graph.outputs) self._output_shapes = tuple( output.shape for output in self._func_graph.outputs) - self._attrs = attrs or {} + self._attrs = _parse_func_attrs(attrs) self._device_functions = tuple( self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access @@ -909,7 +952,8 @@ class PolymorphicFunction(object): def __init__(self, python_function, name, - input_signature=None): + input_signature=None, + attributes=None): """Initializes a polymorphic function. Args: @@ -918,6 +962,8 @@ class PolymorphicFunction(object): input_signature: a possibly nested sequence of `TensorSpec` objects specifying the input signature of this function. If `None`, a separate function is instantiated for each inferred input signature. + attributes: dict, extra keyword arguments that will be added as attribute + of the function. Raises: ValueError: if `input_signature` is not None and the `python_function`'s @@ -935,6 +981,7 @@ class PolymorphicFunction(object): self._name = name self._function_cache = collections.OrderedDict() self._variables = [] + self._function_attributes = attributes or {} self._lock = threading.Lock() @@ -1149,7 +1196,8 @@ class PolymorphicFunction(object): if graph_function is None: graph_function = Function( func_graph_from_py_func(self._name, self._python_function, args, - kwds, self._input_signature)) + kwds, self._input_signature), + self._function_attributes) self._variables.extend( [v for v in graph_function.variables if v not in self._variables]) self._function_cache[cache_key] = graph_function @@ -1483,7 +1531,29 @@ def defun(func=None, input_signature=None): TypeError: If `input_signature` is neither `None` nor a sequence of `tf.contrib.eager.TensorSpec` objects. 
""" + return defun_with_attributes(func=func, input_signature=input_signature) + + +def defun_with_attributes(func=None, input_signature=None, attributes=None): + """Compiles a Python function into a callable TensorFlow graph. + + This function supports adding extra function attributes. See detailed + documentation in defun(). Currently this is not exposed in public API since we + don't expect user to directly use attributes, and attribute won't work by + itself. This assumption might change in future. + Args: + func: function to be compiled. + input_signature: same as defun()'s input_signature. + attributes: A dictionary of arguments which will be added to function def as + attributes. Currently only support primitive types as value, and only + whitelisted attribute name is allowed. Unwhitelisted attribute name or + unsupported value will result into ValueError. + + Returns: + Same as the return value of defun, with attributes added to the function in + graph. + """ if input_signature is not None: _validate_signature(input_signature) @@ -1495,7 +1565,8 @@ def defun(func=None, input_signature=None): name = "function" return tf_decorator.make_decorator( function, - PolymorphicFunction(function, name, input_signature=input_signature)) + PolymorphicFunction(function, name, input_signature=input_signature, + attributes=attributes)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 6507bc6d71..e6a49b66cf 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1501,6 +1501,67 @@ class FunctionTest(test.TestCase): side_effecting_function.python_function() self.assertAllEqual(state, [0, 0]) + def testFunctionWithExtraAttributes(self): + @function.defun_with_attributes(attributes={'experimental_1': 'value1', + 'experimental_2': 2}) + def matmul(x, y): + return math_ops.matmul(x, y) + + def 
add(x, y): + return math_ops.add(x, y) + defun_add = function.defun_with_attributes( + add, attributes={'experimental_3': True, 'experimental_4': 1.0}) + + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t) + double = defun_add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + self.assertRegexpMatches( + functions[0].definition.signature.name, '.*matmul.*') + attrs = functions[0].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_1'].s, b'value1') + self.assertEqual(attrs['experimental_2'].i, 2) + + self.assertRegexpMatches( + functions[1].definition.signature.name, '.*add.*') + attrs = functions[1].definition.attr + self.assertEqual(len(attrs), 2) + self.assertEqual(attrs['experimental_3'].b, True) + self.assertEqual(attrs['experimental_4'].f, 1.0) + # pylint: enable=protected-access + + def testFunctionWithInvalidAttribute(self): + @function.defun_with_attributes(attributes={'attr1': 'value1'}) + def matmul(x, y): + return math_ops.matmul(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Attribute name is not whitelisted.*'): + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + matmul(t, t) + + @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) + def add(x, y): + return math_ops.add(x, y) + + with self.assertRaisesRegexp(ValueError, + '.*Unsupported attribute type.*'): + with context.graph_mode(), self.test_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + add(t, t) + 
@test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): -- cgit v1.2.3 From 6ebe0abcc6bb3c3b50975cd2550bec2012389673 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Tue, 11 Sep 2018 14:17:07 -0700 Subject: Construct placer after the first optimization pass is run. PiperOrigin-RevId: 212518982 --- tensorflow/core/kernels/partitioned_function_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 7bb403290d..3ab7404ea9 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -127,12 +127,12 @@ class PartitionedCallOp : public AsyncOpKernel { optimization_options.graph = &graph; optimization_options.flib_def = overlay_lib; optimization_options.device_set = &device_set; - Placer placer(graph.get(), &device_set); OP_REQUIRES_OK_ASYNC( ctx, OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::PRE_PLACEMENT, optimization_options), done); + Placer placer(graph.get(), &device_set); OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); OP_REQUIRES_OK_ASYNC( ctx, -- cgit v1.2.3 From 328aeaeec83795c7de2589ca97a0b6d8b9a873e0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 14:31:09 -0700 Subject: Fixing broadcast pow. 
PiperOrigin-RevId: 212521825 --- .../lite/kernels/internal/reference/reference_ops.h | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 0abacf85e1..977367026d 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4877,16 +4877,22 @@ inline void Pow(const RuntimeShape& input1_shape, const T* input1_data, } template <typename T> -inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape, +inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, + const RuntimeShape& unextended_input2_shape, const T* input2_data, - const RuntimeShape& output_shape, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, - &desc2); + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); for (int b = 0; b < output_shape.Dims(0); ++b) { for (int y = 0; y < output_shape.Dims(1); ++y) { -- cgit v1.2.3 From ba650a5c989106330519dbde0de368f580435a8b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 14:45:36 -0700 Subject: Fix typos in the comment for the class Categorical.
PiperOrigin-RevId: 212524769 --- tensorflow/python/ops/distributions/categorical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index dd25fce2ec..fbbacf2521 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -69,7 +69,7 @@ class Categorical(distribution.Distribution): The Categorical distribution is closely related to the `OneHotCategorical` and `Multinomial` distributions. The Categorical distribution can be intuited as generating samples according to `argmax{ OneHotCategorical(probs) }` itself - being identical to `argmax{ Multinomial(probs, total_count=1) }. + being identical to `argmax{ Multinomial(probs, total_count=1) }`. #### Mathematical Details @@ -83,7 +83,7 @@ class Categorical(distribution.Distribution): The number of classes, `K`, must not exceed: - the largest integer representable by `self.dtype`, i.e., - `2**(mantissa_bits+1)` (IEE754), + `2**(mantissa_bits+1)` (IEEE 754), - the maximum `Tensor` index, i.e., `2**31-1`. In other words, -- cgit v1.2.3 From f3242baaf10842ff4753b5974f426cf963fa8eef Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 Sep 2018 15:02:21 -0700 Subject: Add support for populating a feature columns to output tensors dictionary in input_layer. 
PiperOrigin-RevId: 212528172 --- tensorflow/python/feature_column/feature_column.py | 25 ++++++++++++---- .../python/feature_column/feature_column_test.py | 34 ++++++++++++++++++++++ .../api/golden/v1/tensorflow.feature_column.pbtxt | 2 +- .../api/golden/v2/tensorflow.feature_column.pbtxt | 2 +- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 2246d2f3e9..9984379e9d 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -169,7 +169,8 @@ def _internal_input_layer(features, weight_collections=None, trainable=True, cols_to_vars=None, - scope=None): + scope=None, + cols_to_output_tensors=None): """See input_layer. `scope` is a name or variable scope to use.""" feature_columns = _normalize_feature_columns(feature_columns) @@ -202,14 +203,17 @@ def _internal_input_layer(features, trainable=trainable) num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access batch_size = array_ops.shape(tensor)[0] - output_tensors.append( - array_ops.reshape(tensor, shape=(batch_size, num_elements))) + output_tensor = array_ops.reshape( + tensor, shape=(batch_size, num_elements)) + output_tensors.append(output_tensor) if cols_to_vars is not None: # Retrieve any variables created (some _DenseColumn's don't create # variables, in which case an empty list is returned). 
cols_to_vars[column] = ops.get_collection( ops.GraphKeys.GLOBAL_VARIABLES, scope=variable_scope.get_variable_scope().name) + if cols_to_output_tensors is not None: + cols_to_output_tensors[column] = output_tensor _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) @@ -219,7 +223,8 @@ def input_layer(features, feature_columns, weight_collections=None, trainable=True, - cols_to_vars=None): + cols_to_vars=None, + cols_to_output_tensors=None): """Returns a dense `Tensor` as input layer based on given `feature_columns`. Generally a single example in training data is described with FeatureColumns. @@ -264,6 +269,9 @@ def input_layer(features, dimension=10): [ Date: Tue, 11 Sep 2018 15:17:57 -0700 Subject: Add a printout at the start of MetaOptimizer::Optimize() to make it easier to see the total cost of running Grappler in logs. Also add a couple of VLOG(1) statements to see breakdown between main graph and function optimization. PiperOrigin-RevId: 212531430 --- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index a5fd33d28b..8c99598748 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -331,10 +331,12 @@ Status MetaOptimizer::RunOptimizer( Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + LOG(INFO) << "Starting optimization for grappler item: " << item.id; optimization_results_.clear(); // 1. Optimize main graph TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); + VLOG(1) << "Optimized main graph."; // 2. 
Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), @@ -398,7 +400,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } - VLOG(3) << "Optimized " << optimized_funcs.size() + VLOG(1) << "Optimized " << optimized_funcs.size() << " functions: " << str_util::Join(optimized_funcs, ", "); return Status::OK(); -- cgit v1.2.3 From 7f9f25a008369ac90e7b96c4f58a3dd1c662d89c Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Tue, 11 Sep 2018 15:28:10 -0700 Subject: Move Quantile Stream Resource to core. Allow each Resource to manage multiple streams that share the same quantile config -- number of quantiles and epsilon. Previously each resource managed only one stream, so we had to create as many resources as there are features, which is cumbersome when the input is high dimensional. If 1000 features use 100 quantiles (which is hardcoded today), then 1000 resources are required. This CL makes the number of resources linear in the number of parameter servers: if 2 parameter servers are present, then only 2 resources are required, one for each PS. Remove the time stamp token, as the ops are called once.
PiperOrigin-RevId: 212533735 --- .../base_api/api_def_BoostedTreesBucketize.pbtxt | 34 ++ ..._BoostedTreesCreateQuantileStreamResource.pbtxt | 29 ++ ...api_def_BoostedTreesMakeQuantileSummaries.pbtxt | 40 ++ ...edTreesQuantileStreamResourceAddSummaries.pbtxt | 22 + ...f_BoostedTreesQuantileStreamResourceFlush.pbtxt | 31 ++ ...QuantileStreamResourceGetBucketBoundaries.pbtxt | 27 ++ ...oostedTreesQuantileStreamResourceHandleOp.pbtxt | 5 + ...tedTreesQuantileStreamResourceInitialized.pbtxt | 20 + tensorflow/core/kernels/boosted_trees/BUILD | 16 +- .../core/kernels/boosted_trees/quantile_ops.cc | 453 +++++++++++++++++++++ .../core/kernels/boosted_trees/quantiles/BUILD | 4 +- .../quantiles/quantile_stream_resource.h | 96 +++++ tensorflow/core/ops/boosted_trees_ops.cc | 125 ++++++ tensorflow/python/kernel_tests/boosted_trees/BUILD | 13 + .../boosted_trees/quantile_ops_test.py | 140 +++++++ tensorflow/python/ops/boosted_trees_ops.py | 6 + 16 files changed, 1059 insertions(+), 2 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt create mode 100644 tensorflow/core/kernels/boosted_trees/quantile_ops.cc create mode 100644 
tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h create mode 100644 tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt new file mode 100644 index 0000000000..cdaeb5091c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "BoostedTreesBucketize" + visibility: HIDDEN + in_arg { + name: "float_values" + description: <