diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 70e70abc06..f954b9d6c7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -464,6 +464,30 @@ def with_c_api(cls): return cls +def with_c_shapes(cls): + """Adds methods that call original methods but with C API shapes enabled. + + Note this enables C shapes in new methods after running the test class's + setup method. + + Args: + cls: class to decorate + + Returns: + cls with new test methods added + """ + # If C shapes are already enabled, don't do anything. Some tests break if the + # same test is run twice, so this allows us to turn on the C shapes by default + # without breaking these tests. + if ops._USE_C_SHAPES: + return cls + + for name, value in cls.__dict__.copy().items(): + if callable(value) and name.startswith("test"): + setattr(cls, name + "WithCShapes", enable_c_shapes(value)) + return cls + + def assert_no_new_pyobjects_executing_eagerly(f): """Decorator for asserting that no new Python objects persist after a test. |