aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-18 15:02:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 15:05:23 -0700
commitd964834a922e77198fd387aac6c6cc5970a31e7d (patch)
treee6e5e914abf941b161d46ad4c2e940422643eccf /tensorflow/python/framework/test_util.py
parent325ba9ece698d04082b173ba300a10623d27de96 (diff)
Merged commit includes the following changes:
193422827 by yifeif: Fix buildifier error. -- 193421691 by skyewm: Make GraphModeFunctions work with _USE_C_SHAPES=True. Tensor._handle_data is going away. This change adds special hooks for propagating the resource handle shape information through EagerTensors. -- 193421473 by A. Unique TensorFlower: Register dynamic_stitch for DT_VARIANT type. -- 193421175 by nolivia: disabling flaky tsan test -- 193420117 by nolivia: disabling flaky test in tensorflow that has no apparent culprit -- PiperOrigin-RevId: 193422827
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 70e70abc06..f954b9d6c7 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -464,6 +464,30 @@ def with_c_api(cls):
return cls
+def with_c_shapes(cls):
+ """Adds methods that call original methods but with C API shapes enabled.
+
+ Note this enables C shapes in new methods after running the test class's
+ setup method.
+
+ Args:
+ cls: class to decorate
+
+ Returns:
+ cls with new test methods added
+ """
+ # If C shapes are already enabled, don't do anything. Some tests break if the
+ # same test is run twice, so this allows us to turn on the C shapes by default
+ # without breaking these tests.
+ if ops._USE_C_SHAPES:
+ return cls
+
+ for name, value in cls.__dict__.copy().items():
+ if callable(value) and name.startswith("test"):
+ setattr(cls, name + "WithCShapes", enable_c_shapes(value))
+ return cls
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.