aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib/core
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2017-12-07 15:18:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 15:22:18 -0800
commitf37380b064948fb6dd45feef0e8d93130c2f9884 (patch)
tree1481e8ee34040533d5927501a4c8ba5b3f8a8f45 /tensorflow/python/lib/core
parent2d4c29cd6a0627fdd71a752e6bd919204c7cb8bf (diff)
Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.
In particular, an EagerPyFunc op is added that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs. Because functions wrapped in an EagerPyFunc are executed eagerly, they can make use of TensorFlow operations. EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops. tfe.py_func will also make it easier to translate python functions with side-effects into defun-able code. PiperOrigin-RevId: 178303818
Diffstat (limited to 'tensorflow/python/lib/core')
-rw-r--r--tensorflow/python/lib/core/py_func.cc72
-rw-r--r--tensorflow/python/lib/core/py_func.h20
2 files changed, 74 insertions, 18 deletions
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index a42282b055..eae1c2eea6 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <array>
#include "numpy/arrayobject.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -25,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
#include <Python.h>
namespace tensorflow {
@@ -48,6 +52,9 @@ struct PyCall {
// with this "token".
string token;
+ // True if the call is associated with an EagerPyFunc.
+ bool eager;
+
// Inputs and outputs of this function invocation.
std::vector<Tensor> ins;
std::vector<Tensor> out;
@@ -55,19 +62,26 @@ struct PyCall {
// Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
-Status MakeArgTuple(PyCall* call, PyObject** tuple) {
+Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
int64 n = call->ins.size();
PyObject* lst = PyList_New(n);
CHECK(lst);
for (int64 i = 0; i < n; ++i) {
+ PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
- PyObject* a = nullptr;
- Status s = ConvertTensorToNdarray(t, &a);
- if (!s.ok()) {
- Py_DECREF(lst);
- return s;
+ if (call->eager) {
+ arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
+ if (arg == nullptr) {
+ return errors::Internal("Unable to procure EagerTensor from Tensor.");
+ }
+ } else {
+ Status s = ConvertTensorToNdarray(t, &arg);
+ if (!s.ok()) {
+ Py_DECREF(lst);
+ return s;
+ }
}
- PyList_SetItem(lst, i, a);
+ PyList_SetItem(lst, i, arg);
}
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
CHECK(*tuple);
@@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None;
}
+// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
+Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+ Tensor* output_tensor,
+ TF_Status* tf_status) {
+ // TODO(akshayka): Lift the restriction requiring output tensors to
+ // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
+ // tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
+ *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
+ EagerTensor_Handle(eager_tensor), tf_status);
+ return StatusFromTF_Status(tf_status);
+}
+
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true;
@@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
}
}
- // Process the return values and converts them to tf Tensors.
+ // Process the return values and convert them to TF Tensors.
Status s;
if (PyList_Check(result)) {
- // 'result' is a list.
call->out.clear();
for (int i = 0; i < PyList_Size(result); ++i) {
Tensor t;
- s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
+ if (call->eager) {
+ auto tf_status = tensorflow::make_safe(TF_NewStatus());
+ s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
+ tf_status.get());
+ } else {
+ s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
+ }
+
if (!s.ok()) {
break;
}
call->out.push_back(t);
}
+ } else if (EagerTensor_CheckExact(result) || result == Py_None) {
+ DCHECK(call->eager);
+ Tensor t;
+ if (result != Py_None) {
+ auto tf_status = tensorflow::make_safe(TF_NewStatus());
+ s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
+ if (s.ok()) {
+ call->out.push_back(t);
+ }
+ }
} else if (PyArray_Check(result)) {
- // 'result' is a single ndarray.
+ DCHECK(!call->eager);
if (!IsSingleNone(result)) {
Tensor t;
s = ConvertNdarrayToTensor(result, &t);
@@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
public:
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
+ eager_ = type_string() == "EagerPyFunc";
}
void Compute(OpKernelContext* ctx) override {
PyCall call;
call.token = token_;
+ call.eager = eager_;
for (int i = 0; i < ctx->num_inputs(); ++i) {
call.ins.push_back(ctx->input(i));
}
@@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
private:
string token_;
+ // True if and only if this op should execute the python function eagerly,
+ // i.e., if and only if the eager attribute is set.
+ bool eager_;
+
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
};
+
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
+REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
} // end namespace tensorflow
diff --git a/tensorflow/python/lib/core/py_func.h b/tensorflow/python/lib/core/py_func.h
index 5a451d5f43..3197a7ddfa 100644
--- a/tensorflow/python/lib/core/py_func.h
+++ b/tensorflow/python/lib/core/py_func.h
@@ -24,21 +24,27 @@ limitations under the License.
namespace tensorflow {
-// Called by py code on initialization.
+// Called by python code on initialization.
//
// "trampoline" must represent a python function which has the
// following signature:
-// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
+// (string, list(ndarray)) | (string, list(EagerTensor)) ->
+// ndarray | list(ndarray) | python scalar |
+// EagerTensor | list(EagerTensor) | None
//
// The trampoline takes two arguments, the first is a string token
// used by the python frontend's dispatching logic; the second is a
-// list of numpy ndarrays.
+// list of numpy ndarrays or EagerTensor objects. It can return a
+// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
+// EagerTensor, a list of EagerTensors, or None.
//
-// The trampoline can return a single numpy ndarray, a list of numpy
-// ndarrays, or a simply python scalar. The C++ runtime converts them,
-// if supported, back to Tensor objects.
+// PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
+// inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
+// of EagerTensors, or None.
//
-// This is called by script_ops.py during its module initialization.
+// The C++ runtime converts outputs back to Tensor objects.
+//
+// This function is called by script_ops.py during its module initialization.
//
// TODO(zhifengc): Support distributed runtime.
void InitializePyTrampoline(PyObject* trampoline);