diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-09-07 18:47:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 18:53:18 -0700 |
commit | 4fd48f57cd1dcd960bea1757e1c59032db66b3d0 (patch) | |
tree | 49ef65ef479a2904ff08d5ed9bfb65bac08a64f4 /tensorflow/compiler/tests | |
parent | 3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff) |
Decluster some must-be-constant ops to reduce XLA recompilations
The CL is organized as follows:
- The main change is in jit/partially_decluster_pass.
- tf2xla/const_analysis now takes an "edge_filter" to facilitate use by
jit/partially_decluster_pass.
- tests/dense_layer_test.py was using the execution of ListDiff as what I
assume is a sanity check to see that the XLA cluster ran. With this CL the
ListDiff op gets declustered so we now check for "MatMult" for the sanity
check.
- Some tests were dropping TF_XLA_FLAGS; fixed them to not do so.
PiperOrigin-RevId: 212071118
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/dense_layer_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tests/jit_test.py | 5 |
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 04f3b3ef49..0af74c2d8f 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase): Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. """ - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) @@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. @@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a..0839fb123e 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) |