aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-07 18:47:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 18:53:18 -0700
commit4fd48f57cd1dcd960bea1757e1c59032db66b3d0 (patch)
tree49ef65ef479a2904ff08d5ed9bfb65bac08a64f4 /tensorflow/compiler/tests
parent3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff)
Decluster some must-be-constant ops to reduce XLA recompilations
The CL is organized as follows: - The main change is in jit/partially_decluster_pass. - tf2xla/const_analysis now takes an "edge_filter" to facilitate use by jit/partially_decluster_pass. - tests/dense_layer_test.py was using the execution of ListDiff as what I assume is a sanity check to see that the XLA cluster ran. With this CL the ListDiff op gets declustered so we now check for "MatMult" for the sanity check. - Some tests were dropping TF_XLA_FLAGS; fixed them to not do so. PiperOrigin-RevId: 212071118
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/dense_layer_test.py7
-rw-r--r--tensorflow/compiler/tests/jit_test.py5
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 04f3b3ef49..0af74c2d8f 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase):
Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
config_pb2.OptimizerOptions.ON_1)
@@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
@@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMult"))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6e0db54b7a..0839fb123e 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase):
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
- "--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_fusion_only=true "
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)