aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-12-07 15:18:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 15:22:18 -0800
commitf37380b064948fb6dd45feef0e8d93130c2f9884 (patch)
tree1481e8ee34040533d5927501a4c8ba5b3f8a8f45
parent2d4c29cd6a0627fdd71a752e6bd919204c7cb8bf (diff)
Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.
In particular, an EagerPyFunc op is added that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs. Because functions wrapped in an EagerPyFunc are executed eagerly, they can make use of TensorFlow operations. EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops. tfe.py_func will also make it easier to translate python functions with side-effects into defun-able code. PiperOrigin-RevId: 178303818
-rw-r--r--tensorflow/contrib/eager/python/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/tfe.py3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_EagerPyFunc.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EagerPyFunc.pbtxt4
-rw-r--r--tensorflow/core/ops/script_ops.cc14
-rw-r--r--tensorflow/python/BUILD4
-rw-r--r--tensorflow/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py221
-rw-r--r--tensorflow/python/lib/core/py_func.cc72
-rw-r--r--tensorflow/python/lib/core/py_func.h20
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/script_ops.py176
12 files changed, 399 insertions, 127 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 6e9bb87d58..fb667cd91b 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:numerics",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:script_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:backprop",
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 1697c879de..770a7e3e7a 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@list_devices
@@num_gpus
+@@py_func
@@defun
@@implicit_gradients
@@implicit_value_and_gradients
@@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.ops.variable_scope import EagerVariableStore
+from tensorflow.python.ops import script_ops
from tensorflow.python.util.all_util import remove_undocumented
+py_func = script_ops.eager_py_func
defun = function.defun
implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad
diff --git a/tensorflow/core/api_def/base_api/api_def_EagerPyFunc.pbtxt b/tensorflow/core/api_def/base_api/api_def_EagerPyFunc.pbtxt
new file mode 100644
index 0000000000..9231368e16
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EagerPyFunc.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "EagerPyFunc"
+ summary: "Eagerly executes a python function to compute func(input)->output. The"
+ description: <<END
+semantics of the input, output, and attributes are the same as those for
+PyFunc.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EagerPyFunc.pbtxt b/tensorflow/core/api_def/python_api/api_def_EagerPyFunc.pbtxt
new file mode 100644
index 0000000000..ee0f95dacb
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EagerPyFunc.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EagerPyFunc"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/ops/script_ops.cc b/tensorflow/core/ops/script_ops.cc
index 8197327b56..c7c594a999 100644
--- a/tensorflow/core/ops/script_ops.cc
+++ b/tensorflow/core/ops/script_ops.cc
@@ -51,4 +51,18 @@ REGISTER_OP("PyFuncStateless")
A stateless version of PyFunc.
)doc");
+REGISTER_OP("EagerPyFunc")
+ .Input("input: Tin")
+ .Output("output: Tout")
+ .Attr("token: string")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >=0")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Eagerly executes a python function to compute func(input)->output. The
+semantics of the input, output, and attributes are the same as those for
+PyFunc.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 868ffcb473..e5c4347833 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -280,10 +280,14 @@ cc_library(
":ndarray_tensor_bridge",
":numpy_lib",
":py_util",
+ ":safe_ptr",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:script_ops_op_lib",
+ "//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/py/numpy:headers",
"//util/python:python_headers",
],
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0660f40300..f017004e1a 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1645,6 +1645,8 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:function",
],
tags = ["no_windows"],
)
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 7ed99c1be9..92fb68820e 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -23,82 +23,93 @@ from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class PyOpTest(test.TestCase):
+def np_func(x, y):
+ return np.sinh(x) + np.cosh(y)
- def testBasic(self):
- def my_func(x, y):
- return np.sinh(x) + np.cosh(y)
+def matmul(x, y):
+ return math_ops.matmul(x, y)
- # single type
+
+class PyFuncTest(test.TestCase):
+ """Encapsulates tests for py_func and eager_py_func."""
+
+ # ----- Tests for py_func -----
+ def testSingleType(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
- z = script_ops.py_func(my_func, [x, y], dtypes.float32)
- self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
+ z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
+ self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
- # scalar
+ def testScalar(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
- z = script_ops.py_func(my_func, [x, y], [dtypes.float32])
- self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
+ z = self.evaluate(
+ script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
+ self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
- # array
+ def testArray(self):
with self.test_session():
x = constant_op.constant([1.0, 2.0], dtypes.float64)
y = constant_op.constant([2.0, 3.0], dtypes.float64)
- z = script_ops.py_func(my_func, [x, y], [dtypes.float64])
- self.assertAllEqual(z[0].eval(),
- my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
+ z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
+ self.assertAllEqual(z[0],
+ np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
- # a bit exotic type (complex64)
+ def testComplexType(self):
with self.test_session():
x = constant_op.constant(1 + 2j, dtypes.complex64)
y = constant_op.constant(3 + 4j, dtypes.complex64)
- z, = script_ops.py_func(my_func, [x, y], [dtypes.complex64])
- self.assertAllClose(z.eval(), my_func(1 + 2j, 3 + 4j))
+ z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
+ self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
- # a bit excotic function (rfft)
+ def testRFFT(self):
with self.test_session():
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
def rfft(x):
return np.fft.rfft(x).astype(np.complex64)
- y, = script_ops.py_func(rfft, [x], [dtypes.complex64])
- self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
+ y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
+ self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
- # returns a python literal.
+ def testPythonLiteral(self):
with self.test_session():
def literal(x):
- return 1.0 if x == 0.0 else 0.0
+ return 1.0 if float(x) == 0.0 else 0.0
x = constant_op.constant(0.0, dtypes.float64)
- y, = script_ops.py_func(literal, [x], [dtypes.float64])
- self.assertAllClose(y.eval(), 1.0)
+ y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
+ self.assertAllClose(y, 1.0)
- # returns a list
+ def testList(self):
with self.test_session():
def list_func(x):
return [x, x + 1]
x = constant_op.constant(0.0, dtypes.float64)
- y, z = script_ops.py_func(list_func, [x], [dtypes.float64] * 2)
- self.assertAllClose(y.eval(), 0.0)
- self.assertAllClose(z.eval(), 1.0)
+ y = self.evaluate(
+ script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
+ self.assertAllClose(y, [0.0, 1.0])
+ def testTuple(self):
# returns a tuple
with self.test_session():
@@ -106,17 +117,17 @@ class PyOpTest(test.TestCase):
return x, x + 1
x = constant_op.constant(0.0, dtypes.float64)
- y, z = script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2)
- self.assertAllClose(y.eval(), 0.0)
- self.assertAllClose(z.eval(), 1.0)
+ y = self.evaluate(
+ script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
+ self.assertAllClose(y, [0.0, 1.0])
# returns a tuple, Tout and inp a tuple
with self.test_session():
x = constant_op.constant(0.0, dtypes.float64)
- y, z = script_ops.py_func(tuple_func, (x,), (dtypes.float64,
- dtypes.float64))
- self.assertAllClose(y.eval(), 0.0)
- self.assertAllClose(z.eval(), 1.0)
+ y = self.evaluate(
+ script_ops.py_func(tuple_func, (x,),
+ (dtypes.float64, dtypes.float64)))
+ self.assertAllClose(y, [0.0, 1.0])
def testStrings(self):
@@ -128,10 +139,12 @@ class PyOpTest(test.TestCase):
with self.test_session():
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
- y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
- [dtypes.string])
- z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
- self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
+ y = self.evaluate(
+ script_ops.py_func(read_fixed_length_numpy_strings, [],
+ dtypes.string))
+ z = self.evaluate(
+ script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
+ self.assertAllEqual(z, [b"hello there", b"hi there"])
def testStringsAreConvertedToBytes(self):
@@ -143,10 +156,12 @@ class PyOpTest(test.TestCase):
with self.test_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
- y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
- [dtypes.string])
- z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
- self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
+ y = self.evaluate(
+ script_ops.py_func(read_fixed_length_numpy_strings, [],
+ dtypes.string))
+ z = self.evaluate(
+ script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
+ self.assertAllEqual(z, [b"hello there", b"hi there"])
def testObjectArraysAreConvertedToBytes(self):
@@ -186,16 +201,8 @@ class PyOpTest(test.TestCase):
def testNoInput(self):
with self.test_session():
- x, = script_ops.py_func(lambda: 42.0, [], [dtypes.float64])
- self.assertAllClose(x.eval(), 42.0)
-
- def testCleanup(self):
- for _ in xrange(1000):
- g = ops.Graph()
- with g.as_default():
- c = constant_op.constant([1.], dtypes.float32)
- _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
- self.assertTrue(script_ops._py_funcs.size() < 100)
+ x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
+ self.assertAllClose(x, 42.0)
def testAlias(self):
with self.test_session():
@@ -242,8 +249,8 @@ class PyOpTest(test.TestCase):
# Create a numpy array aliasing a tensor and a tensor aliasing this array
z, = script_ops.py_func(ident, [p], [dtypes.float32])
z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0]
- # above instead of using its memory as the return value of
- # session.run
+ # above instead of using its memory as the return value of
+ # session.run
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
def testStateful(self):
@@ -319,10 +326,10 @@ class PyOpTest(test.TestCase):
def value(self):
return self._value
- with self.test_session() as sess:
+ with self.test_session():
s = State()
op = s.increment(constant_op.constant(2, dtypes.int64))
- ret = sess.run(op)
+ ret = self.evaluate(op)
self.assertIsNone(ret)
self.assertAllEqual([3], s.value)
@@ -336,15 +343,24 @@ class PyOpTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(sess.run(f), [])
- def _testExceptionHandling(self, py_exp, tf_exp):
+ def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
def raise_exception():
raise py_exp("blah") # pylint: disable=not-callable
- f = script_ops.py_func(raise_exception, [], [])
- with self.test_session() as sess:
+ if eager:
+ if context.in_eager_mode():
+ with self.assertRaisesRegexp(tf_exp, "blah"):
+ f = script_ops.eager_py_func(raise_exception, [], [])
+ return
+ else:
+ f = script_ops.eager_py_func(raise_exception, [], [])
+ else:
+ f = script_ops.py_func(raise_exception, [], [])
+
+ with self.test_session():
with self.assertRaisesRegexp(tf_exp, "blah"):
- sess.run(f)
+ self.evaluate(f)
def testExceptionHandling(self):
self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
@@ -358,6 +374,89 @@ class PyOpTest(test.TestCase):
self._testExceptionHandling(WeirdError, errors.UnknownError)
+ # ----- Tests shared by py_func and eager_py_func -----
+ def testCleanup(self):
+ for _ in xrange(1000):
+ g = ops.Graph()
+ with g.as_default():
+ c = constant_op.constant([1.], dtypes.float32)
+ _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
+ _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
+ self.assertTrue(script_ops._py_funcs.size() < 100)
+
+ # ----- Tests for eager_py_func -----
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerSingleOutputInt32(self):
+ a = array_ops.ones((3, 3), dtype=dtypes.int32)
+ x = array_ops.ones((3, 1), dtype=dtypes.int32)
+ output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
+ with self.test_session():
+ ret = self.evaluate(output)
+ self.assertAllEqual(ret, [[3], [3], [3]])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerSingleOutputFloat32(self):
+ a = array_ops.ones((3, 3), dtype=dtypes.float32)
+ x = array_ops.ones((3, 1), dtype=dtypes.float32)
+ output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
+ with self.test_session():
+ ret = self.evaluate(output)
+ self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerArrayOutput(self):
+ a = array_ops.ones((3, 3), dtype=dtypes.int32)
+ x = array_ops.ones((3, 1), dtype=dtypes.int32)
+ output = script_ops.eager_py_func(
+ lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
+
+ with self.test_session():
+ ret = self.evaluate(output)
+ self.assertAllEqual(ret, [[[3], [3], [3]]])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerReturnNone(self):
+
+ def no_return_value():
+ return
+
+ output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
+ ret = self.evaluate(output)
+ if context.in_eager_mode():
+ self.assertEquals(len(ret), 0)
+ else:
+ self.assertIsNone(ret)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerPyFuncInDefun(self):
+
+ def wrapper():
+ a = array_ops.ones((3, 3), dtype=dtypes.int32)
+ x = array_ops.ones((3, 1), dtype=dtypes.int32)
+ return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
+
+ wrapped = function.defun(wrapper)
+ ret = self.evaluate(wrapped())
+ self.assertAllEqual(ret, [[3], [3], [3]])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerExceptionHandling(self):
+ self._testExceptionHandling(
+ ValueError, errors.InvalidArgumentError, eager=True)
+ self._testExceptionHandling(
+ TypeError, errors.InvalidArgumentError, eager=True)
+ self._testExceptionHandling(
+ StopIteration, errors.OutOfRangeError, eager=True)
+ self._testExceptionHandling(
+ MemoryError, errors.ResourceExhaustedError, eager=True)
+ self._testExceptionHandling(
+ NotImplementedError, errors.UnimplementedError, eager=True)
+
+ class WeirdError(Exception):
+ pass
+
+ self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index a42282b055..eae1c2eea6 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <array>
#include "numpy/arrayobject.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -25,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
#include <Python.h>
namespace tensorflow {
@@ -48,6 +52,9 @@ struct PyCall {
// with this "token".
string token;
+ // True if the call is associated with an EagerPyFunc.
+ bool eager;
+
// Inputs and outputs of this function invocation.
std::vector<Tensor> ins;
std::vector<Tensor> out;
@@ -55,19 +62,26 @@ struct PyCall {
// Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
-Status MakeArgTuple(PyCall* call, PyObject** tuple) {
+Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
int64 n = call->ins.size();
PyObject* lst = PyList_New(n);
CHECK(lst);
for (int64 i = 0; i < n; ++i) {
+ PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
- PyObject* a = nullptr;
- Status s = ConvertTensorToNdarray(t, &a);
- if (!s.ok()) {
- Py_DECREF(lst);
- return s;
+ if (call->eager) {
+ arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
+ if (arg == nullptr) {
+ return errors::Internal("Unable to procure EagerTensor from Tensor.");
+ }
+ } else {
+ Status s = ConvertTensorToNdarray(t, &arg);
+ if (!s.ok()) {
+ Py_DECREF(lst);
+ return s;
+ }
}
- PyList_SetItem(lst, i, a);
+ PyList_SetItem(lst, i, arg);
}
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
CHECK(*tuple);
@@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None;
}
+// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
+Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+ Tensor* output_tensor,
+ TF_Status* tf_status) {
+ // TODO(akshayka): Lift the restriction requiring output tensors to
+ // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
+ // tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
+ *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
+ EagerTensor_Handle(eager_tensor), tf_status);
+ return StatusFromTF_Status(tf_status);
+}
+
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true;
@@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
}
}
- // Process the return values and converts them to tf Tensors.
+ // Process the return values and convert them to TF Tensors.
Status s;
if (PyList_Check(result)) {
- // 'result' is a list.
call->out.clear();
for (int i = 0; i < PyList_Size(result); ++i) {
Tensor t;
- s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
+ if (call->eager) {
+ auto tf_status = tensorflow::make_safe(TF_NewStatus());
+ s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
+ tf_status.get());
+ } else {
+ s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
+ }
+
if (!s.ok()) {
break;
}
call->out.push_back(t);
}
+ } else if (EagerTensor_CheckExact(result) || result == Py_None) {
+ DCHECK(call->eager);
+ Tensor t;
+ if (result != Py_None) {
+ auto tf_status = tensorflow::make_safe(TF_NewStatus());
+ s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
+ if (s.ok()) {
+ call->out.push_back(t);
+ }
+ }
} else if (PyArray_Check(result)) {
- // 'result' is a single ndarray.
+ DCHECK(!call->eager);
if (!IsSingleNone(result)) {
Tensor t;
s = ConvertNdarrayToTensor(result, &t);
@@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
public:
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
+ eager_ = type_string() == "EagerPyFunc";
}
void Compute(OpKernelContext* ctx) override {
PyCall call;
call.token = token_;
+ call.eager = eager_;
for (int i = 0; i < ctx->num_inputs(); ++i) {
call.ins.push_back(ctx->input(i));
}
@@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
private:
string token_;
+ // True if and only if this op should execute the python function eagerly,
+ // i.e., if and only if the eager attribute is set.
+ bool eager_;
+
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
};
+
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
+REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
} // end namespace tensorflow
diff --git a/tensorflow/python/lib/core/py_func.h b/tensorflow/python/lib/core/py_func.h
index 5a451d5f43..3197a7ddfa 100644
--- a/tensorflow/python/lib/core/py_func.h
+++ b/tensorflow/python/lib/core/py_func.h
@@ -24,21 +24,27 @@ limitations under the License.
namespace tensorflow {
-// Called by py code on initialization.
+// Called by python code on initialization.
//
// "trampoline" must represent a python function which has the
// following signature:
-// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
+// (string, list(ndarray)) | (string, list(EagerTensor)) ->
+// ndarray | list(ndarray) | python scalar |
+// EagerTensor | list(EagerTensor) | None
//
// The trampoline takes two arguments, the first is a string token
// used by the python frontend's dispatching logic; the second is a
-// list of numpy ndarrays.
+// list of numpy ndarrays or EagerTensor objects. It can return a
+// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
+// EagerTensor, a list of EagerTensors, or None.
//
-// The trampoline can return a single numpy ndarray, a list of numpy
-// ndarrays, or a simply python scalar. The C++ runtime converts them,
-// if supported, back to Tensor objects.
+// PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
+// inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
+// of EagerTensors, or None.
//
-// This is called by script_ops.py during its module initialization.
+// The C++ runtime converts outputs back to Tensor objects.
+//
+// This function is called by script_ops.py during its module initialization.
//
// TODO(zhifengc): Support distributed runtime.
void InitializePyTrampoline(PyObject* trampoline);
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index af014a7e39..b3f7c26695 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -341,6 +341,7 @@ TruncatedNormal
# script_ops
PyFunc
PyFuncStateless
+EagerPyFunc
# sdca_ops
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 2c3667dffe..c0c1ade495 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -29,11 +29,41 @@ import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
+class EagerFunc(object):
+ """A wrapper for a function owned by an EagerPyFunc."""
+
+ def __init__(self, func, Tout):
+ """Constructs an EagerFunc.
+
+ Args:
+ func: The function to wrap.
+ Tout: A list of datatypes for the output; an empty list if the output is
+ None.
+ """
+ self._func = func
+ self._out_dtypes = Tout
+
+ def __call__(self, *args, **kwargs):
+ """Passes args, kwargs to `self._func`, which is executed eagerly."""
+ with context.eager_mode():
+ ret = self._func(*args, **kwargs)
+ if isinstance(ret, (tuple, list)):
+ return [
+ ops.convert_to_tensor(x, dtype=dtype)
+ for (x, dtype) in zip(ret, self._out_dtypes)
+ ]
+ elif ret is None:
+ return ret
+ else:
+ return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
+
+
class FuncRegistry(object):
"""A helper class to keep track of registered py functions.
@@ -91,16 +121,20 @@ class FuncRegistry(object):
if func is None:
raise ValueError("callback %s is not found" % token)
ret = func(*args)
- # Strings seem to lead to a memory leak here if they're not wrapped in a
- # list.
- if isinstance(ret, six.binary_type):
- ret = [ret]
- # Ensures that we return either a single numpy array or a list of numpy
- # arrays.
- if isinstance(ret, (tuple, list)):
- return [self._convert(x) for x in ret]
+
+ if isinstance(func, EagerFunc):
+ return ret
else:
- return self._convert(ret)
+ # Strings seem to lead to a memory leak here if they're not wrapped in a
+ # list.
+ if isinstance(ret, six.binary_type):
+ ret = [ret]
+ # Ensures that we return either a single numpy array or a list of numpy
+ # arrays.
+ if isinstance(ret, (tuple, list)):
+ return [self._convert(x) for x in ret]
+ else:
+ return self._convert(ret)
def size(self):
"""Returns how many functions are currently registered."""
@@ -129,6 +163,86 @@ class CleanupFunc(object):
_py_funcs.remove(self._token)
+def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
+ """See documentation for py_func and eager_py_func."""
+
+ is_list_or_tuple = False
+ if isinstance(Tout, (list, tuple)):
+ is_list_or_tuple = True
+ else:
+ Tout = [Tout]
+
+ if eager:
+ func = EagerFunc(func, Tout)
+
+ token = _py_funcs.insert(func)
+ # We tie the registered function's lifetime with the current default graph,
+ # i.e., when the current graph is destroyed, we remove its py funcs.
+ graph = ops.get_default_graph()
+
+ # pylint: disable=protected-access
+ while isinstance(graph, function._FuncGraph):
+ # If the py_func was declared inside a _FuncGraph, its lifetime should be
+ # bound to that of the outer graph instead.
+ graph = graph._outer_graph
+
+ cleanup = CleanupFunc(token)
+
+ # TODO(zhifengc): Consider adding a Graph method to collect
+ # `cleanup` objects in one of its member.
+ if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
+ graph._cleanup_py_funcs_used_in_graph = []
+
+ # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
+ # will be destroyed and their __del__ will remove the 'token' from
+ # the funcs registry.
+ graph._cleanup_py_funcs_used_in_graph.append(cleanup)
+ # pylint: enable=protected-access
+
+ # pylint: disable=protected-access
+ if eager:
+ result = gen_script_ops._eager_py_func(
+ input=inp, token=token, Tout=Tout, name=name)
+ else:
+ if stateful:
+ result = gen_script_ops._py_func(
+ input=inp, token=token, Tout=Tout, name=name)
+ else:
+ result = gen_script_ops._py_func_stateless(
+ input=inp, token=token, Tout=Tout, name=name)
+ # pylint: enable=protected-access
+ return result if is_list_or_tuple else result[0]
+
+
+def eager_py_func(func, inp, Tout, name=None):
+ """Wraps a python function into a TensorFlow op.
+
+ When the returned op is executed, `func` is invoked with eager execution
+ enabled. Inputs are Tensor objects and func must return None or objects
+ that may be converted to Tensor objects.
+
+ This function has the same limitations as `py_func` with respect to
+ serialization and distribution.
+
+ Args:
+ func: A Python function which accepts a list of `Tensor` objects
+ having element types that match the corresponding `tf.Tensor` objects
+ in `inp` and returns a list of `Tensor` objects (or a single
+ `Tensor`, or `None`) having element types that match the
+ corresponding values in `Tout`.
+ inp: A list of `Tensor` objects.
+ Tout: A list or tuple of tensorflow data types or a single tensorflow data
+ type if there is only one, indicating what `func` returns; an empty list
+ if no value is returned (i.e., if the return value is `None`).
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
+ if `func` returns None.
+ """
+ return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
+
+
def py_func(func, inp, Tout, stateful=True, name=None):
"""Wraps a python function and uses it as a TensorFlow op.
@@ -182,46 +296,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
- token = _py_funcs.insert(func)
- # We tie the registered function's life-time with the current
- # default graph. I.e., when the current graph is destroyed, we
- # should remove its py funcs.
- g = ops.get_default_graph()
-
- # pylint: disable=protected-access
- while isinstance(g, function._FuncGraph):
- # If the py_func was declared inside a _FuncGraph, its lifetime should be
- # bound to that of the outer graph instead.
- g = g._outer_graph
-
- cleanup = CleanupFunc(token)
-
- # TODO(zhifengc): Consider adding a Graph method to collect
- # `cleanup` objects in one of its member.
- if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
- g._cleanup_py_funcs_used_in_graph = []
-
- # When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
- # will be destroyed and their __del__ will remove the 'token' from
- # the funcs registry.
- g._cleanup_py_funcs_used_in_graph.append(cleanup)
- # pylint: enable=protected-access
-
- if isinstance(Tout, (list, tuple)):
- is_list_or_tuple = True
- else:
- Tout = [Tout]
- is_list_or_tuple = False
- # pylint: disable=protected-access
- if stateful:
- result = gen_script_ops._py_func(
- input=inp, token=token, Tout=Tout, name=name)
- else:
- result = gen_script_ops._py_func_stateless(
- input=inp, token=token, Tout=Tout, name=name)
- # pylint: enable=protected-access
- return result if is_list_or_tuple else result[0]
+ return _internal_py_func(
+ func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
+# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
+# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
+# eager attribute is not set, and otherwise, it should return the gradient.
ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")