aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-13 16:31:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 17:16:37 -0700
commit39dc589ad5ffccef4caeb099588d8e33d181ca6c (patch)
tree13b079f16cfce8bf110de94312e93d9078a16317 /tensorflow/python/framework/test_util.py
parenta20c9fb63ef1eeabbe439b5a13aa29deaca44861 (diff)
Introduce _USE_C_SHAPES toggle along with _USE_C_API toggle.
This is a second level of staging before fully enabling the C API. With _USE_C_API enabled but _USE_C_SHAPES disabled, the C API is used for everything but retrieving the shape of Tensors (i.e. we continue using the existing Python shape inference to implement Tensor.shape). This is useful because many tests fail with the C API fully enabled. This will allow us to enable everything but the full shape inference, fix the remaining broken tests, and then enable the full shape shape inference. This change also introduces the test_util.enable_c_shapes test method decorator. This can be used to enable C shapes for a specific test method. This is useful for tests that have already been modified to work with full C shape inference. PiperOrigin-RevId: 188949619
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 284e264acd..e9e86e452b 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -407,6 +407,31 @@ def enable_c_api(fn):
return wrapper
+def enable_c_shapes(fn):
+ """Decorator for enabling C shapes on a test.
+
+ Note this enables the C shapes after running the test class's setup/teardown
+ methods.
+
+ Args:
+ fn: the function to be wrapped
+
+ Returns:
+ The wrapped function
+ """
+
+ def wrapper(*args, **kwargs):
+ prev_value = ops._USE_C_SHAPES
+ # Only use C shapes if the C API is already enabled.
+ ops._USE_C_SHAPES = ops._USE_C_API
+ try:
+ fn(*args, **kwargs)
+ finally:
+ ops._USE_C_SHAPES = prev_value
+
+ return wrapper
+
+
# This decorator is a hacky way to run all the test methods in a decorated
# class with and without C API enabled.
# TODO(iga): Remove this and its uses once we switch to using C API by default.