diff options
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/dense_layer_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tests/jit_test.py | 5 |
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 04f3b3ef49..0af74c2d8f 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase): Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. """ - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) @@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. @@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a..0839fb123e 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) |