aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 70e70abc06..f954b9d6c7 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -464,6 +464,30 @@ def with_c_api(cls):
return cls
+def with_c_shapes(cls):
+ """Adds methods that call original methods but with C API shapes enabled.
+
+ Note this enables C shapes in new methods after running the test class's
+ setup method.
+
+ Args:
+ cls: class to decorate
+
+ Returns:
+ cls with new test methods added
+ """
+ # If C shapes are already enabled, don't do anything. Some tests break if the
+ # same test is run twice, so this allows us to turn on the C shapes by default
+ # without breaking these tests.
+ if ops._USE_C_SHAPES:
+ return cls
+
+ for name, value in cls.__dict__.copy().items():
+ if callable(value) and name.startswith("test"):
+ setattr(cls, name + "WithCShapes", enable_c_shapes(value))
+ return cls
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.