diff options
author | Scott Zhu <scottzhu@google.com> | 2018-09-18 18:18:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 18:22:20 -0700 |
commit | ff2e46cd768b9161235f10f6f8bbb23cb27314dc (patch) | |
tree | d20cc6c25820966c9d8f8cbba68adf1f5a8c3f63 /tensorflow/python/eager | |
parent | c2dc702159cfccb623b99daf2f9df875a1f3dbfd (diff) |
Update the grappler plugin to support the @defun generated function and ops.
PiperOrigin-RevId: 213554813
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 6326a5b45f..4a1bde3f5e 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -26,6 +26,7 @@ import weakref import numpy from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop @@ -1729,6 +1730,51 @@ class FunctionTest(test.TestCase): 'be Tensors;.*'): graph_function('Not a Tensor.') + def testSwapImplementationWithGrapplerPlugin(self): + rewrites = rewriter_config_pb2.RewriterConfig() + # function_optimizer has to be turn off, otherwise it will delete the + # registered function if it does not get called. + # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called + # before function_optimizer in future. + rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF + customer_optimizer = rewrites.custom_optimizers.add() + customer_optimizer.name = 'ExperimentalImplementationSelector' + rewrites.min_graph_nodes = -1 + graph_options = config_pb2.GraphOptions( + rewrite_options=rewrites, build_cost_model=1) + config = config_pb2.ConfigProto(graph_options=graph_options) + + with context.graph_mode(), self.cached_session( + config=config, graph=ops.Graph(), use_gpu=True) as sess: + + @function.defun_with_attributes( + attributes={ + 'experimental_api_implements': 'random_boost', + 'experimental_api_preferred_device': 'CPU' + }) + def cpu_boost(x): + return math_ops.add(x, 2.0) + + @function.defun_with_attributes( + attributes={ + 'experimental_api_implements': 'random_boost', + 'experimental_api_preferred_device': 'GPU' + }) + def gpu_boost(x): + return math_ops.add(x, 4.0) + + x = constant_op.constant(1.0) + + function.register(cpu_boost, x) + y = gpu_boost(x) + y_value = sess.run(y) + + if test.is_gpu_available(): + self.assertEquals(y_value, 5.0) + else: + # Grappler fallback to use the CPU impl even called with GPU function. + self.assertEquals(y_value, 3.0) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |