aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-09-18 18:18:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 18:22:20 -0700
commitff2e46cd768b9161235f10f6f8bbb23cb27314dc (patch)
treed20cc6c25820966c9d8f8cbba68adf1f5a8c3f63 /tensorflow/python/eager
parentc2dc702159cfccb623b99daf2f9df875a1f3dbfd (diff)
Update the grappler plugin to support the @defun generated function and ops.
PiperOrigin-RevId: 213554813
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function_test.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 6326a5b45f..4a1bde3f5e 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -26,6 +26,7 @@ import weakref
import numpy
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
@@ -1729,6 +1730,51 @@ class FunctionTest(test.TestCase):
'be Tensors;.*'):
graph_function('Not a Tensor.')
+ def testSwapImplementationWithGrapplerPlugin(self):
+ rewrites = rewriter_config_pb2.RewriterConfig()
+ # function_optimizer has to be turn off, otherwise it will delete the
+ # registered function if it does not get called.
+ # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called
+ # before function_optimizer in future.
+ rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ customer_optimizer = rewrites.custom_optimizers.add()
+ customer_optimizer.name = 'ExperimentalImplementationSelector'
+ rewrites.min_graph_nodes = -1
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrites, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+
+ with context.graph_mode(), self.cached_session(
+ config=config, graph=ops.Graph(), use_gpu=True) as sess:
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'CPU'
+ })
+ def cpu_boost(x):
+ return math_ops.add(x, 2.0)
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'GPU'
+ })
+ def gpu_boost(x):
+ return math_ops.add(x, 4.0)
+
+ x = constant_op.constant(1.0)
+
+ function.register(cpu_boost, x)
+ y = gpu_boost(x)
+ y_value = sess.run(y)
+
+ if test.is_gpu_available():
+ self.assertEquals(y_value, 5.0)
+ else:
+ # Grappler fallback to use the CPU impl even called with GPU function.
+ self.assertEquals(y_value, 3.0)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):