aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
author    Akshay Agrawal <akshayka@google.com>    2018-01-29 18:30:59 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-01-29 18:34:39 -0800
commit  87a3c967973641d3b0d2a16d17add184ed967392 (patch)
tree    4cd133e7acd12a9b173ce172c516fcff912ee98e /tensorflow/python/lib
parent  e3d99c92975efc2010d0e1e2dd4c3eb787a8d67c (diff)
TFE: Register a GPU kernel for tfe.py_func.
PiperOrigin-RevId: 183765122
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--  tensorflow/python/lib/core/py_func.cc  72
1 file changed, 51 insertions(+), 21 deletions(-)
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index d3bfa0ee33..e0422ef80a 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -53,6 +54,12 @@ struct PyCall {
// with this "token".
string token;
+ // The device on which Tensors are stored; only used for EagerPyFunc.
+ Device* device;
+
+ // True if and only if the op has been placed on a GPU.
+ bool gpu;
+
// True if the call is associated with an EagerPyFunc.
bool eager;
@@ -71,7 +78,12 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
if (call->eager) {
- arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
+ if (call->gpu) {
+ arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device));
+ } else {
+ // TFE_TensorHandle assumes that CPU is identified by `nullptr`.
+ arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr));
+ }
if (arg == nullptr) {
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
@@ -84,7 +96,8 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
}
PyList_SetItem(lst, i, arg);
}
- *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
+ *tuple = Py_BuildValue("(sON)", call->token.c_str(),
+ call->gpu ? Py_True : Py_False, lst);
CHECK(*tuple);
return Status::OK();
}
@@ -150,15 +163,9 @@ bool IsSingleNone(PyObject* obj) {
}
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
-Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
- Tensor* output_tensor,
- TF_Status* tf_status) {
- // TODO(akshayka): Lift the restriction requiring output tensors to
- // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
- // tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
- *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
- EagerTensor_Handle(eager_tensor), tf_status);
- return StatusFromTF_Status(tf_status);
+void ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+ Tensor* output_tensor) {
+ *output_tensor = EagerTensor_Handle(eager_tensor)->t;
}
// Calls the registered py function through the trampoline.
@@ -201,15 +208,23 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
}
// Process the return values and convert them to TF Tensors.
- Status s;
+ Status s = Status::OK();
if (PyList_Check(result)) {
+ // `result` is a Python list; if this operation is an `EagerPyFunc`, then
+ // every item in the list must be an `EagerTensor`; otherwise, every element
+ // must be a NumPy array.
call->out.clear();
for (int i = 0; i < PyList_Size(result); ++i) {
Tensor t;
if (call->eager) {
- auto tf_status = tensorflow::make_safe(TF_NewStatus());
- s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
- tf_status.get());
+ const PyObject* item = PyList_GetItem(result, i);
+ if (EagerTensor_CheckExact(item)) {
+ ExtractTensorFromEagerTensor(item, &t);
+ } else {
+ s = errors::FailedPrecondition(
+ "Expected EagerTensor, found PyObject of type: ",
+ Py_TYPE(item)->tp_name);
+ }
} else {
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
}
@@ -220,16 +235,15 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
call->out.push_back(t);
}
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
+ // result is an `EagerTensor` or `None`.
DCHECK(call->eager);
Tensor t;
if (result != Py_None) {
- auto tf_status = tensorflow::make_safe(TF_NewStatus());
- s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
- if (s.ok()) {
- call->out.push_back(t);
- }
+ ExtractTensorFromEagerTensor(result, &t);
+ call->out.push_back(t);
}
} else if (PyArray_Check(result)) {
+ // `result` is a NumPy array.
DCHECK(!call->eager);
if (!IsSingleNone(result)) {
Tensor t;
@@ -239,7 +253,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
}
}
} else {
- s = errors::Internal("Unexpected pyobject is returned: ",
+ s = errors::Internal("Unexpected PyObject was returned: ",
Py_TYPE(result)->tp_name);
}
Py_DECREF(result);
@@ -429,12 +443,24 @@ class PyFuncOp : public OpKernel {
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc";
+ gpu_ = ctx->device_type().type_string() == DEVICE_GPU;
}
void Compute(OpKernelContext* ctx) override {
PyCall call;
call.token = token_;
+ call.gpu = gpu_;
call.eager = eager_;
+ if (call.eager) {
+ // Eager's C API uses `Device`, whereas `OpKernelContext` stores a
+ // `DeviceBase`; attempt to downcast.
+ call.device = dynamic_cast<Device*>(ctx->device());
+ if (call.device == nullptr) {
+ ctx->CtxFailureWithWarning(
+ errors::Internal("Unrecognized device class"));
+ }
+ }
+
for (int i = 0; i < ctx->num_inputs(); ++i) {
call.ins.push_back(ctx->input(i));
}
@@ -476,6 +502,9 @@ class PyFuncOp : public OpKernel {
private:
string token_;
+ // True if and only if this op has been placed on a GPU.
+ bool gpu_;
+
// True if and only if this op should execute the python function eagerly,
// i.e., if and only if the eager attribute is set.
bool eager_;
@@ -486,5 +515,6 @@ class PyFuncOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
+REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_GPU), PyFuncOp);
} // end namespace tensorflow