aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-30 23:39:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-30 23:47:56 -0700
commitff18944249f723cf6e2825a3165f1efbb64c4880 (patch)
treeb2b16931d228162042eb5cf0da614dc4fba83ead /tensorflow/python/eager
parentda8349412fe03c9f55307c7f2674f072073d1b40 (diff)
Move EagerTensor from python to C.
PiperOrigin-RevId: 170617321
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/BUILD7
-rw-r--r--tensorflow/python/eager/benchmarks_test.py47
-rw-r--r--tensorflow/python/eager/context.py10
-rw-r--r--tensorflow/python/eager/core_test.py5
-rw-r--r--tensorflow/python/eager/execute.py15
-rw-r--r--tensorflow/python/eager/execution_callbacks.py2
-rw-r--r--tensorflow/python/eager/ops_test.py2
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc646
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h61
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc122
-rw-r--r--tensorflow/python/eager/tape.py8
-rw-r--r--tensorflow/python/eager/tensor_test.py127
12 files changed, 885 insertions, 167 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index dee967d18d..da62229959 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -6,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
cc_library(
name = "pywrap_tfe_lib",
- srcs = ["pywrap_tfe_src.cc"],
+ srcs = [
+ "pywrap_tensor.cc",
+ "pywrap_tfe_src.cc",
+ ],
hdrs = ["pywrap_tfe.h"],
visibility = ["//tensorflow:internal"],
deps = [
@@ -14,8 +17,10 @@ cc_library(
"//tensorflow/c/eager:c_api",
"//tensorflow/core:lib",
"//tensorflow/python:ndarray_tensor",
+ "//tensorflow/python:ndarray_tensor_bridge",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_seq_tensor",
+ "//tensorflow/python:safe_ptr",
"//util/python:python_headers",
],
)
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 52aff5c8d6..407d1e979c 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.eager import backprop # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
@@ -61,18 +62,41 @@ def benchmark_create_tensor(n):
def label(s):
return "{:20s}".format(s)
- with timer(label("np.array([[3]])"), iters=n) as iters:
+ with timer(label("np.array([[3.0]])"), iters=n) as iters:
for _ in iters:
- np.array([[3]])
+ np.array([[3.0]])
- with timer(label("Tensor([[3]])"), iters=n) as iters:
+ ctx = context.context()
+ handle = ctx._handle
+ device = ctx.device_name
+ # May be warmup GPU.
+ ops.EagerTensor([[3.0]], context=handle, device=device)
+
+ # float32
+ dtype = dtypes.float32.as_datatype_enum
+ three = [[3.0]]
+ with timer(label("EagerTensor([[3.0]])"), iters=n) as iters:
for _ in iters:
- ops.EagerTensor([[3]], context.context())
+ ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
- ctx = context.context()
- with timer(label("Tensor([[3]], ctx)"), iters=n) as iters:
+ np_3 = np.array([[3.0]], dtype=np.float32)
+ with timer(label("EagerTensor(np.array([[3.0]]))"), iters=n) as iters:
+ for _ in iters:
+ ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
+
+ # int32.
+ # This is interesting since int32 will be kept on host memory for the GPU
+ # case.
+ dtype = dtypes.int32.as_datatype_enum
+ three = [[3]]
+ with timer(label("EagerTensor([[3]])"), iters=n) as iters:
+ for _ in iters:
+ ops.EagerTensor(three, context=handle, device=device, dtype=dtype)
+
+ np_3 = np.array([[3]], dtype=np.int32)
+ with timer(label("EagerTensor(np.array([[3]]))"), iters=n) as iters:
for _ in iters:
- ops.EagerTensor([[3]], ctx)
+ ops.EagerTensor(np_3, context=handle, device=device, dtype=dtype)
def benchmark_matmul(shape, n, use_gpu=False):
@@ -103,17 +127,16 @@ def benchmark_matmul(shape, n, use_gpu=False):
for _ in iters:
gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)
+ inputs = [m, m]
# pylint: disable=protected-access
- input_handles = [m._handle, m._handle]
ctx_handle = context.context()._handle
# pylint: enable=protected-access
attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
m.dtype.as_datatype_enum)
with timer(label("TFE_Py_Execute"), iters=n) as iters:
for _ in iters:
- pywrap_tensorflow.TFE_DeleteTensorHandle(
- pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
- input_handles, attrs, 1)[0])
+ pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul",
+ inputs, attrs, 1)
f = function.defun(math_ops.matmul)
with timer(label("defun(tf.matmul)"), iters=n) as iters:
@@ -133,6 +156,8 @@ class BenchmarksTest(test_util.TensorFlowTestCase):
if context.context().num_gpus() > 0:
print("---- RUNNING ON GPU NOW ----")
+ with context.device("/device:GPU:0"):
+ benchmark_create_tensor(FLAGS.iters or 30000)
benchmark_matmul([2, 2], FLAGS.iters or 30000, use_gpu=True)
benchmark_matmul([100, 28 * 28], FLAGS.iters or 1000, use_gpu=True)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 9acd14d4b4..02ff567e9e 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -121,16 +121,6 @@ class Context(object):
else:
return devices
- def __del__(self):
- try:
- if self._context_handle is not None:
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_DeleteContext(self._context_handle, status)
- except (AttributeError, TypeError):
- # Sometimes deletion during program shutdown throws exception as other
- # modules are no longer available.
- pass
-
def __str__(self):
if self._context_handle is None:
return "Eager TensorFlow Context. Devices currently uninitialized."
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 653d92d7c5..041d388fad 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import threading
+
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -138,7 +139,7 @@ class TFETest(test_util.TensorFlowTestCase):
x = x.as_cpu_tensor()
# Invalid device
- with self.assertRaises(errors.InvalidArgumentError):
+ with self.assertRaises(RuntimeError):
x.as_gpu_tensor(context.context().num_gpus() + 1)
def testNumpyForceCPU(self):
@@ -153,7 +154,7 @@ class TFETest(test_util.TensorFlowTestCase):
ta = constant_op.constant([[1, 2], [3, 4]])
tb = ta.as_cpu_tensor()
- self.assertNotEqual(ta._handle, tb._handle)
+ self.assertNotEqual(id(ta), id(tb))
self.assertAllEqual(ta.numpy(), tb.numpy())
def testRegisterExceptionClass(self):
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 312fc97c80..808955560f 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -53,32 +53,27 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
Raises:
An exception on error.
"""
- # TODO(apassos) move this to convert_to_tensor
- # pylint: disable=protected-access
- input_handles = [c._handle for c in inputs]
device_name = ctx.device_name
+ # pylint: disable=protected-access
try:
- outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
- op_name, input_handles, attrs,
- num_outputs)
+ tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
+ op_name, inputs, attrs,
+ num_outputs)
except core._NotOkStatusException as e:
if name is not None:
message = e.message + " name: " + name
else:
message = e.message
six.raise_from(core._status_to_exception(e.code, message), None)
- # pylint: enable=protected-access
- tensors = [ops._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
# TODO(alive, cais): Use the execution callback mechanism.
if core.active_trace() is not None:
for t in tensors:
- # pylint: disable=protected-access
core.active_trace().record_tensor(op_name,
ops.tensor_id(t),
t.device,
t.shape.num_elements())
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
# TODO(cais): Optimize this, perhaps by replacing this execute function with
# a different one when there are execution callback(s).
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 1903704a3f..6b0e7f5c3f 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -162,7 +162,7 @@ def inf_nan_callback(op_type,
# TODO(cais): Consider moving this into execute.py.
# pylint: disable=protected-access
pywrap_tensorflow.TFE_Py_Execute(
- ctx._handle, output.device, "CheckNumerics", [output._handle],
+ ctx._handle, output.device, "CheckNumerics", [output],
check_numerics_op_attrs, 1)
# pylint: enable=protected-access
except core._NotOkStatusException: # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index 734369a729..e61e96aa96 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
-class TargetTest(test_util.TensorFlowTestCase):
+class OpsTest(test_util.TensorFlowTestCase):
def testExecuteBasic(self):
three = constant_op.constant(3)
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
new file mode 100644
index 0000000000..18337bdd45
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -0,0 +1,646 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+
+#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
+#include "tensorflow/python/lib/core/numpy.h"
+#include "tensorflow/python/lib/core/py_seq_tensor.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+
+#include "tensorflow/python/eager/pywrap_tfe.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/python/lib/core/ndarray_tensor.h"
+
+namespace {
+
+// Unwraps the TFE_Context pointer stored in the unnamed PyCapsule `ctx`.
+// Returns nullptr with a Python TypeError set when no pointer can be extracted.
+TFE_Context* GetContext(PyObject* ctx) {
+  void* raw = PyCapsule_GetPointer(ctx, nullptr);
+  if (raw != nullptr) {
+    return reinterpret_cast<TFE_Context*>(raw);
+  }
+  const auto message = tensorflow::strings::StrCat(
+      "Expecting a PyCapsule encoded context handle. Got ",
+      Py_TYPE(ctx)->tp_name);
+  PyErr_SetString(PyExc_TypeError, message.c_str());
+  return nullptr;
+}
+
+// Builds a TFE_TensorHandle from a Python numpy.ndarray object.
+// The handle and the array may share underlying storage, so changes to one may
+// be reflected in the other.
+// On conversion failure a Python ValueError is set and nullptr is returned.
+TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
+  tensorflow::Tensor tensor;
+  const auto conversion = tensorflow::NdarrayToTensor(obj, &tensor);
+  if (!conversion.ok()) {
+    const auto message = tensorflow::strings::StrCat(
+        "Failed to convert numpy ndarray to a Tensor (",
+        conversion.error_message(), ").");
+    PyErr_SetString(PyExc_ValueError, message.c_str());
+    return nullptr;
+  }
+  return TFE_NewTensorHandle(tensor);
+}
+
+// Casts data referred to by `handle` from type `src_type_enum` to type
+// `dst_type_enum` by executing the "Cast" op on the local CPU device.
+//
+// Returns the output handle on success. On any failure returns nullptr and
+// leaves a non-OK code in `out_status`, so callers can detect every error
+// through the status alone.
+TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
+                            TF_DataType src_type_enum,
+                            TF_DataType dst_type_enum, TF_Status* out_status) {
+  if (ctx == nullptr) {
+    // Report the failure through `out_status`: a null return paired with an OK
+    // status would go unnoticed by callers that only inspect the status.
+    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+                 "Invalid eager context passed to EagerCast");
+    return nullptr;
+  }
+  const char* op_name = "Cast";
+  const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
+#define RETURN_ERROR  \
+  {                   \
+    TFE_DeleteOp(op); \
+    return nullptr;   \
+  }
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpSetDevice(op, device_name, out_status);
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpAddInput(op, handle, out_status);
+  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
+  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
+  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
+  TFE_TensorHandle* output = nullptr;
+  int num_outputs = 1;
+  TFE_Execute(op, &output, &num_outputs, out_status);
+  if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
+      output == nullptr) {
+    if (output != nullptr) {
+      TFE_DeleteTensorHandle(output);
+    }
+    // Execution reported success but did not yield exactly one usable output;
+    // make sure the status reflects the failure.
+    if (TF_GetCode(out_status) == TF_OK) {
+      TF_SetStatus(out_status, TF_INTERNAL,
+                   "Cast op did not produce exactly one output");
+    }
+    RETURN_ERROR
+  }
+  TFE_DeleteOp(op);
+  return output;
+#undef RETURN_ERROR
+}
+
+// Copies `handle` to the device named by `dev` using the context wrapped in
+// `ctx`. A null or None `dev` is passed on as the empty device name.
+// Returns the handle produced by the copy, or nullptr with a Python exception
+// set (TypeError for bad arguments, RuntimeError if the copy itself fails).
+TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
+                               PyObject* dev) {
+  const char* device_name = "";
+  const bool device_given = (dev != nullptr) && (dev != Py_None);
+  if (device_given) {
+    device_name = PyBytes_AsString(dev);
+#if PY_MAJOR_VERSION >= 3
+    if (device_name == nullptr) {
+      // Not a bytes object: drop that error and retry as a unicode string.
+      PyErr_Clear();
+      device_name = PyUnicode_AsUTF8(dev);
+    }
+#endif
+    if (device_name == nullptr) {
+      PyErr_SetString(PyExc_TypeError,
+                      "Error parsing device argument to CopyToDevice");
+      return nullptr;
+    }
+  }
+  TFE_Context* tfe_context = GetContext(ctx);
+  if (tfe_context == nullptr) {
+    return nullptr;  // PyErr already set by GetContext
+  }
+  auto copy_status = tensorflow::make_safe(TF_NewStatus());
+  TFE_TensorHandle* copied = TFE_TensorHandleCopyToDevice(
+      handle, tfe_context, device_name, copy_status.get());
+  if (TF_GetCode(copy_status.get()) == TF_OK) {
+    return copied;
+  }
+  const auto message = tensorflow::strings::StrCat(
+      "Error copying tensor to device: ", device_name, ". ",
+      TF_Message(copy_status.get()));
+  PyErr_SetString(PyExc_RuntimeError, message.c_str());
+  return nullptr;
+}
+
+// Helper function to convert `v` to an int and store it in `*out`. Returns true
+// on success, false otherwise (in which case `*out` is left untouched).
+// `v` must be a python int (python2: int, not long) representing a
+// TF_DataType value. Values that do not fit in a C int are rejected instead of
+// being silently truncated.
+bool PyIntToDataType(PyObject* v, int* out) {
+  long value = 0;
+#if PY_MAJOR_VERSION < 3
+  if (!PyInt_Check(v)) return false;
+  value = PyInt_AS_LONG(v);
+#else
+  if (!PyLong_Check(v)) return false;
+  value = PyLong_AsLong(v);
+  if (value == -1 && PyErr_Occurred()) {
+    // The python integer does not fit in a C long, so it cannot be a valid
+    // enum value. Drop the OverflowError; the caller reports its own error.
+    PyErr_Clear();
+    return false;
+  }
+#endif
+  // Reject values that would change when narrowed from long to int.
+  if (static_cast<long>(static_cast<int>(value)) != value) return false;
+  *out = static_cast<int>(value);
+  return true;
+}
+
+// Wraps a TF_DataType enum value in a python integer object (a python2 `int`
+// or a python3 `int`, depending on the interpreter being built against).
+PyObject* PyIntFromDataType(TF_DataType l) {
+  const long enum_value = static_cast<long>(l);
+#if PY_MAJOR_VERSION >= 3
+  return PyLong_FromLong(enum_value);
+#else
+  return PyInt_FromLong(enum_value);
+#endif
+}
+
+} // namespace
+
+extern "C" {
+
+static const int kMaxEagerTensorParentSize = 32;
+
+// Instance layout of the python-visible EagerTensor object.
+// TODO(agarwal): store context handle in EagerTensor.
+typedef struct EagerTensor {
+ PyObject_HEAD;
+ // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
+ // parent class. The parent class is set at runtime, so we don't know the
+ // exact size at compile time.
+ char unused[kMaxEagerTensorParentSize];
+ // Underlying eager tensor handle; released in EagerTensor_dealloc.
+ TFE_TensorHandle* handle;
+ // Unique id assigned from get_uid() when the object is created.
+ int64_t id;
+ // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
+ // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
+ // tensors, this will contain a serialized HandleData proto with shape
+ // inference metadata about shapes and dtypes of resources accessible from
+ // this handle.
+ // Note that we assume that handle_data cannot participate in reference
+ // cycles, and hence don't provide GC support for it.
+ PyObject* handle_data;
+
+ // This stores the `_keras_mask` object and is set by TensorFlow layers.
+ PyObject* keras_mask;
+} EagerTensor;
+
+// tp_init for EagerTensor.
+// Python signature: EagerTensor(value, context, device, dtype=None).
+// `value` is either a numpy array or an object accepted by PySeqToTensor,
+// `context` is the PyCapsule-wrapped TFE_Context, `device` is the target
+// device name (or None) and `dtype` is an optional TF_DataType enum value.
+// Returns 0 on success and -1 with a Python exception set on failure.
+int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
+ self->id = get_uid();
+ self->handle = nullptr;
+ Py_INCREF(Py_None);
+ self->handle_data = Py_None;
+ Py_INCREF(Py_None);
+ self->keras_mask = Py_None;
+ PyObject* value;
+ PyObject* context = nullptr;
+ PyObject* device = nullptr;
+ PyObject* dtype = Py_None;
+ const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
+ const_cast<char**>(kwlist), &value, &context,
+ &device, &dtype)) {
+ return -1;
+ }
+ // Extract dtype
+ // desired_dtype stays -1 when no dtype was requested.
+ int desired_dtype = -1;
+ if (dtype != Py_None) {
+ if (!PyIntToDataType(dtype, &desired_dtype)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting a DataType value for dtype. Got ",
+ Py_TYPE(dtype)->tp_name)
+ .c_str());
+ return -1;
+ }
+ }
+ tensorflow::Safe_TFE_TensorHandlePtr handle =
+ tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr));
+ // Start from a clean error state: the code below uses PyErr_Occurred() to
+ // detect failures.
+ PyErr_Clear();
+ if (PyArray_Check(value)) {
+ // numpy input: convert/cast on the numpy side when needed, then wrap.
+ int desired_np_dtype = -1;
+ if (desired_dtype >= 0) {
+ if (!tensorflow::TF_DataType_to_PyArray_TYPE(
+ static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
+ .ok()) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Invalid dtype argument value ", desired_dtype)
+ .c_str());
+ return -1;
+ }
+ }
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
+ int current_np_dtype = PyArray_TYPE(array);
+ // Keeps the converted array (if one is created) alive until it is wrapped.
+ auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
+ if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
+ !PyArray_ISCARRAY(array)) {
+ int new_dtype =
+ desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
+ safe_value = tensorflow::make_safe(
+ PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
+ NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
+ if (PyErr_Occurred()) return -1;
+ if (safe_value == nullptr) {
+ PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
+ return -1;
+ }
+ value = safe_value.get();
+ }
+ handle = tensorflow::make_safe(NumpyToTensorHandle(value));
+ } else {
+ // Any other input is handed to PySeqToTensor together with the dtype.
+ tensorflow::Tensor t;
+ // TODO(josh11b): Have PySeqToTensor set python errors instead of
+ // returning Status.
+ auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
+ if (!cppstatus.ok()) {
+ PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
+ return -1;
+ }
+ handle = tensorflow::make_safe(TFE_NewTensorHandle(t));
+ }
+ if (PyErr_Occurred()) return -1;
+ if (handle == nullptr) {
+ PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor");
+ return -1;
+ }
+ // If the created handle still has a different dtype than requested, cast it
+ // with the Cast op.
+ TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
+ if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
+ auto out_status = tensorflow::make_safe(TF_NewStatus());
+ // NOTE(review): only out_status is checked below; this relies on EagerCast
+ // reporting every failure (including an invalid context) via the status.
+ handle = tensorflow::make_safe(
+ EagerCast(GetContext(context), handle.get(), handle_dtype,
+ static_cast<TF_DataType>(desired_dtype), out_status.get()));
+ if (TF_GetCode(out_status.get()) != TF_OK) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ tensorflow::strings::StrCat("Error while casting from DataType ",
+ handle_dtype, " to ", desired_dtype, ". ",
+ TF_Message(out_status.get()))
+ .c_str());
+ return -1;
+ }
+ handle_dtype = TFE_TensorHandleDataType(handle.get());
+ }
+
+ // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
+ // memory. We approximate the same behavior for eager execution - keeping
+ // int32 tensors in host memory.
+ //
+ // We do so to preclude the need for callers into such kernels from having to
+ // explicitly place the int32 tensors in host memory. For example, without
+ // this, one needed:
+ //
+ // with tf.device('/gpu:0'):
+ // ...// code here
+ // with tf.device('/cpu:0'):
+ // shape = tf.constant(...)
+ // y = tf.random_uniform(shape)
+ //
+ // Without the CPU device block, tfe.ops.random_uniform would fail since the
+ // kernel expects the shape in host memory.
+ //
+ // With this support, we simplify the code:
+ //
+ // with tf.device('/gpu:0'):
+ // y = tf.random_uniform(...)
+ //
+ // The approximation is not exact: there are GPU kernels which do not require
+ // host memory for int32 tensors. This will lead to a discrepancy between
+ // eager and graph execution.
+ // TODO(ashankar): Fix this.
+ if (handle_dtype != TF_INT32) {
+ // Note that this is a shallow copy and will share the underlying buffer
+ // if copying to the same device.
+ handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
+ if (handle == nullptr) return -1;
+ }
+ // Ownership of the handle moves to the python object.
+ self->handle = handle.release();
+ return 0;
+}
+
+// tp_dealloc for EagerTensor.
+// Releases the members and the tensor handle, frees the object and finally
+// notifies the `_delete_trace` staticmethod with the id of the deleted tensor.
+void EagerTensor_dealloc(EagerTensor* self) {
+  // Use the null-tolerant variants: an object whose tp_init never ran has
+  // null members.
+  Py_XDECREF(self->handle_data);
+  Py_XDECREF(self->keras_mask);
+  if (self->handle != nullptr) {
+    TFE_DeleteTensorHandle(self->handle);
+    self->handle = nullptr;
+  }
+  // We clear (and later restore) any errors that have already been set. Else
+  // these errors may appear randomly as part of the function execution. This
+  // is done before any Python API call below so that failures raised while
+  // deallocating never leak out of the destructor.
+  PyObject *exc_type, *exc_value, *exc_traceback;
+  PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+  PyObject* id = PyLong_FromLongLong(self->id);
+  PyObject* func = PyObject_GetAttrString(reinterpret_cast<PyObject*>(self),
+                                          "_delete_trace");
+  Py_TYPE(self)->tp_free(self);
+  self = nullptr;
+  // Note that we run `func` after calling `tp_free`. Otherwise calling that
+  // function can potentially trigger garbage collection that observes `self`
+  // in this half deleted state and crashes.
+  // Note that `func` is a staticmethod and does not need `self` to be around
+  // for running.
+  if (func != nullptr && id != nullptr) {
+    // The call returns a new reference (or nullptr on error) that we own.
+    PyObject* result = PyObject_CallFunctionObjArgs(func, id, nullptr);
+    Py_XDECREF(result);
+  }
+  // Restoring also discards any error raised by the calls above.
+  PyErr_Restore(exc_type, exc_value, exc_traceback);
+  Py_XDECREF(func);
+  Py_XDECREF(id);
+}
+
+// Getter for `_id`: returns the tensor's unique id as a python integer.
+static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
+  const long long tensor_id = self->id;
+  return PyLong_FromLongLong(tensor_id);
+}
+
+// Method `_datatype_enum`: returns the TF_DataType enum value of the
+// underlying handle as a python integer.
+static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
+  const TF_DataType dtype = TFE_TensorHandleDataType(self->handle);
+  return PyIntFromDataType(dtype);
+}
+
+// Method `_shape_tuple`: returns the shape of the underlying handle as a tuple
+// of python integers, one entry per dimension.
+// Returns nullptr with a Python exception set on failure.
+static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
+  auto handle = self->handle;
+  int n = TFE_TensorHandleNumDims(handle);
+  PyObject* shape = PyTuple_New(n);
+  if (shape == nullptr) return nullptr;  // Exception set by PyTuple_New.
+  for (int i = 0; i < n; ++i) {
+    PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
+    // PyTuple_SetItem steals the reference to `dim` even when it fails, so
+    // `dim` must not be released again here.
+    if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
+      Py_DECREF(shape);
+      PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
+      return nullptr;
+    }
+  }
+  return shape;
+}
+
+// Getter for `_handle_data`: hands out a new reference to the stored object.
+static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
+  PyObject* data = self->handle_data;
+  Py_INCREF(data);
+  return data;
+}
+
+// Setter for `_handle_data`.
+// The interpreter passes a null `value` when the attribute is deleted
+// (`del tensor._handle_data`); that is rejected with a TypeError because the
+// getter always needs an object to return.
+// Returns 0 on success and -1 with a Python exception set on failure.
+static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
+                                        void* unused) {
+  if (value == nullptr) {
+    PyErr_SetString(PyExc_TypeError,
+                    "Cannot delete the _handle_data attribute");
+    return -1;
+  }
+  // Hold the new value before releasing the old one, so the member never
+  // points at a released object while arbitrary destructor code runs.
+  PyObject* previous = self->handle_data;
+  Py_INCREF(value);
+  self->handle_data = value;
+  Py_XDECREF(previous);
+  return 0;
+}
+
+// Getter for `_keras_mask`: hands out a new reference to the stored object.
+static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
+  PyObject* mask = self->keras_mask;
+  Py_INCREF(mask);
+  return mask;
+}
+
+// Setter for `_keras_mask`.
+// The interpreter passes a null `value` when the attribute is deleted
+// (`del tensor._keras_mask`); that is rejected with a TypeError because the
+// getter always needs an object to return.
+// Returns 0 on success and -1 with a Python exception set on failure.
+static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
+                                     void* unused) {
+  if (value == nullptr) {
+    PyErr_SetString(PyExc_TypeError,
+                    "Cannot delete the _keras_mask attribute");
+    return -1;
+  }
+  // Hold the new value before releasing the old one, so the member never
+  // points at a released object while arbitrary destructor code runs.
+  PyObject* previous = self->keras_mask;
+  Py_INCREF(value);
+  self->keras_mask = value;
+  Py_XDECREF(previous);
+  return 0;
+}
+// Method `_copy_to_device(context, device)`.
+// Copies the tensor to `device` and wraps the resulting handle in a new
+// EagerTensor. Returns nullptr with a Python exception set on failure.
+static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
+                                            PyObject* kwds) {
+  PyObject* ctx = nullptr;
+  PyObject* dev = nullptr;
+  const char* kwlist[] = {"context", "device", nullptr};
+  const bool parsed =
+      PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
+                                  &ctx, &dev) != 0;
+  if (!parsed || ctx == nullptr || dev == nullptr) {
+    return nullptr;
+  }
+  // EagerTensorFromHandle passes a null handle through as a null result.
+  return EagerTensorFromHandle(CopyToDevice(self->handle, ctx, dev));
+}
+
+// Method `_numpy`.
+// Converts an EagerTensor to a Python numpy.ndarray object.
+// The two may share underlying storage so changes to one may reflect in the
+// other.
+// A RuntimeError is raised when the underlying tensor cannot be obtained in
+// host memory (e.g. `self` is not on CPU) or the conversion fails.
+static PyObject* EagerTensor_numpy(EagerTensor* self) {
+  auto host_status = tensorflow::make_safe(TF_NewStatus());
+  const tensorflow::Tensor* host_tensor =
+      TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle,
+                                                   host_status.get());
+  if (TF_GetCode(host_status.get()) != TF_OK) {
+    PyErr_SetString(PyExc_RuntimeError, TF_Message(host_status.get()));
+    return nullptr;
+  }
+  PyObject* ndarray = nullptr;
+  auto conversion = tensorflow::TensorToNdarray(*host_tensor, &ndarray);
+  if (!MaybeRaiseExceptionFromStatus(conversion, PyExc_RuntimeError)) {
+    return ndarray;
+  }
+  // Conversion failed and the exception is set; drop any partial result.
+  Py_XDECREF(ndarray);
+  return nullptr;
+}
+
+// Getter `device`: returns the device name of the underlying handle as a
+// native python string (unicode on python3, bytes on python2).
+static PyObject* EagerTensor_device(EagerTensor* self) {
+  const char* device_name = TFE_TensorHandleDeviceName(self->handle);
+#if PY_MAJOR_VERSION >= 3
+  return PyUnicode_FromString(device_name);
+#else
+  return PyBytes_FromString(device_name);
+#endif
+}
+
+// Attribute table (tp_getset) for EagerTensor. Each entry is
+// {name, getter, setter, doc, closure}; entries with a null setter are
+// read-only. The doc string of every entry repeats the attribute name.
+static PyGetSetDef EagerTensor_getseters[] = {
+    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
+     const_cast<char*>("_id"), nullptr},
+    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
+     const_cast<char*>("device"), nullptr},
+    {const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
+     (setter)EagerTensor_settensor_handle, const_cast<char*>("_handle_data"),
+     nullptr},
+    {const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
+     (setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
+     nullptr},
+    {nullptr} /* Sentinel */
+};
+
+// Method table (tp_methods) for EagerTensor. Each entry is
+// {name, C function, calling convention flags, doc string}.
+static PyMethodDef EagerTensor_methods[] = {
+ {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
+ PyDoc_STR("_numpy")},
+ {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
+ PyDoc_STR("_datatype_enum")},
+ {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
+ PyDoc_STR("_shape_tuple")},
+ // The only method that takes arguments: (context, device).
+ {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
+ METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+ // Sentinel terminating the table.
+ {nullptr, nullptr},
+};
+
+// Note that here we are trying to dynamically create a new class as a subclass
+// of a "HEAPTYPE" class that is itself created in python code and passed in at
+// runtime. This is fairly atypical and undocumented.
+//
+// We use the following strategy for this. Unfortunately, we have to use
+// different approaches for python2.x vs python3.x
+// For python2.x, we create the class as a static type and set its tp_base to
+// the passed in type. Unfortunately setting tp_flags to include
+// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
+// initialization of the underlying PyHeapTypeObject and not doing that leads to
+// some random crashes especially during garbage collection.
+// python3.x explicitly disables a static subclass of a HEAPTYPE base class.
+// However it provides a new function, PyType_FromSpecWithBases, to create
+// types dynamically.
+
+// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
+PyTypeObject* EagerTensorType = nullptr;
+
+#if PY_MAJOR_VERSION >= 3
+// python3: slot table consumed by PyType_FromSpecWithBases. The initializer is
+// a flat list of (slot id, pointer) pairs terminated by a (0, nullptr) pair.
+static PyType_Slot EagerTensor_Type_slots[] = {
+ Py_tp_dealloc,
+ reinterpret_cast<void*>(EagerTensor_dealloc),
+ Py_tp_methods,
+ reinterpret_cast<void*>(EagerTensor_methods),
+ Py_tp_getset,
+ reinterpret_cast<void*>(EagerTensor_getseters),
+ Py_tp_init,
+ reinterpret_cast<void*>(EagerTensor_init),
+ 0,
+ nullptr,
+};
+
+// python3: type spec, in field order {name, basicsize, itemsize, flags, slots}.
+PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
+ EagerTensor_Type_slots};
+#else
+// python2: statically allocated type object for EagerTensor. tp_base is left
+// null here and filled in at runtime by TFE_Py_InitEagerTensor before
+// PyType_Ready is called.
+// TODO(agarwal): support active_trace.
+static PyTypeObject _EagerTensorType = {
+ // clang-format off
+ PyVarObject_HEAD_INIT(nullptr, 0)
+ // clang-format on
+ "EagerTensor", /* tp_name */
+ sizeof(EagerTensor), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ (destructor)EagerTensor_dealloc, /* tp_dealloc */
+ nullptr, /* tp_print */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_compare */
+ nullptr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ nullptr, /* tp_doc */
+ nullptr, /* tp_traverse */
+ nullptr, /* tp_clear */
+ nullptr, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ nullptr, /* tp_iter */
+ nullptr, /* tp_iternext */
+ EagerTensor_methods, /* tp_methods */
+ nullptr, /* tp_members */
+ EagerTensor_getseters, /* tp_getset */
+ nullptr, /* tp_base */
+ nullptr, /* tp_dict */
+ nullptr, /* tp_descr_get */
+ nullptr, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ (initproc)EagerTensor_init, /* tp_init */
+ nullptr, /* tp_alloc */
+ nullptr, /* tp_new */
+};
+
+#endif
+
+} // extern "C"
+
+// Returns true iff the type of `o` is exactly the registered EagerTensor type
+// (subclasses do not match).
+bool EagerTensor_CheckExact(const PyObject* o) {
+  const PyTypeObject* actual_type = Py_TYPE(o);
+  return actual_type == EagerTensorType;
+}
+
+// Returns the handle stored in `o`, which is reinterpreted as an EagerTensor
+// without any type check. The handle remains owned by `o`.
+TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
+  const EagerTensor* tensor = reinterpret_cast<const EagerTensor*>(o);
+  return tensor->handle;
+}
+
+// Wraps `handle` in a new EagerTensor python object without running tp_init.
+// On success the new object stores `handle` and releases it when deallocated.
+// Returns nullptr if `handle` is null or if the object cannot be allocated (in
+// which case a Python exception is set).
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
+  if (handle == nullptr) {
+    return nullptr;
+  }
+  // tp_new expects a real argument tuple; passing a non-tuple such as Py_None
+  // makes the default object allocator inspect it as if it were a tuple.
+  PyObject* empty_args = PyTuple_New(0);
+  if (empty_args == nullptr) {
+    return nullptr;
+  }
+  EagerTensor* t = reinterpret_cast<EagerTensor*>(
+      EagerTensorType->tp_new(EagerTensorType, empty_args, nullptr));
+  Py_DECREF(empty_args);
+  if (t != nullptr) {
+    // Mirror the member initialization done by EagerTensor_init.
+    t->id = get_uid();
+    Py_INCREF(Py_None);
+    t->handle_data = Py_None;
+    Py_INCREF(Py_None);
+    t->keras_mask = Py_None;
+    t->handle = handle;
+  }
+  return reinterpret_cast<PyObject*>(t);
+}
+
+// Creates the EagerTensor type as a subclass of `base_class`, stores it in
+// `EagerTensorType` and returns it.
+// `base_class` must be a class whose instances are fixed-size and no larger
+// than kMaxEagerTensorParentSize bytes. Returns nullptr with a Python
+// exception set on failure.
+PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
+  if (!PyType_Check(base_class)) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        tensorflow::strings::StrCat(
+            "Expecting a class definition for `base_class` passed to ",
+            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
+            .c_str());
+    return nullptr;
+  }
+  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
+  // EagerTensor to allow for the space usage of the base class.
+  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
+  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
+    // Report the name of the class itself, not the name of its metaclass.
+    PyErr_SetString(
+        PyExc_TypeError,
+        tensorflow::strings::StrCat(
+            "Unable to create subclass EagerTensor from base class ",
+            base_class_type->tp_name,
+            ". Need its size to be <= ", kMaxEagerTensorParentSize)
+            .c_str());
+    return nullptr;
+  }
+  if (base_class_type->tp_itemsize != 0) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        tensorflow::strings::StrCat(
+            "Unable to create subclass EagerTensor from base class ",
+            base_class_type->tp_name,
+            " which supports variable length instances.")
+            .c_str());
+    return nullptr;
+  }
+#if PY_MAJOR_VERSION >= 3
+  // PyTuple_Pack takes its own reference to `base_class`.
+  PyObject* bases = PyTuple_Pack(1, base_class);
+  if (bases == nullptr) {
+    return nullptr;
+  }
+  EagerTensorType = reinterpret_cast<PyTypeObject*>(
+      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
+  // PyType_FromSpecWithBases does not steal `bases`; release our reference so
+  // the tuple is not leaked.
+  Py_DECREF(bases);
+  if (PyErr_Occurred()) {
+    return nullptr;
+  }
+  if (EagerTensorType == nullptr) {
+    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
+    return nullptr;
+  }
+#else
+  // The static type keeps a strong reference to its base class via tp_base.
+  Py_INCREF(base_class);
+  _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+
+  if (PyType_Ready(&_EagerTensorType) < 0) {
+    if (PyErr_Occurred()) return nullptr;
+    PyErr_SetString(PyExc_RuntimeError,
+                    "Error while creating EagerTensor type.");
+    return nullptr;
+  }
+  EagerTensorType = &_EagerTensorType;
+  Py_INCREF(EagerTensorType);
+#endif
+  // We disable instance based attribute lookup. It's not clear if these
+  // dictionaries are correctly initialized in the first place.
+  EagerTensorType->tp_dictoffset = 0;
+  return reinterpret_cast<PyObject*>(EagerTensorType);
+}
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 3b887954d0..5a72f422cf 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include <Python.h>
@@ -44,38 +45,46 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status);
-// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
-//
-// The two may share underlying storage so changes to one may reflect in the
-// other.
-PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);
-
-// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
-//
-// The two may share underlying storage so changes to one may reflect in the
-// other.
-TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);
-
-// Convert a Python sequence value to a TFE_TensorHandle.
-//
-// The dtype of the result is determined by the type of values found
-// in *obj, *dtype is the desired type but it is only considered a
-// hint. *dtype should be an integer representing the desired DataType
-// enum value, or Py_None. Unlike TFE_Py_NumpyToTensorHandle, this
-// always makes a copy. Returns nullptr and raises an exception on
-// error.
-// TODO(josh11b): Cast to dtype automatically.
-TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj, PyObject* dtype);
-
// Registers e as the Exception class for handling not ok Status. Returns
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
-// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
-// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
-int TFE_Py_MaybeRaiseException(TF_Status* status);
+// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
+// `exception` if not nullptr, else using the class registered via
+// TFE_Py_RegisterExceptionClass), and returns -1.
+int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);
+
+// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
+// `exception` if not nullptr, else using the class registered via
+// TFE_Py_RegisterExceptionClass), and returns -1.
+int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
+ PyObject* exception);
// Returns the string associated with the passed-in python object.
char* TFE_GetPythonString(PyObject* o);
+// Returns a unique id on each call.
+int64_t get_uid();
+
+// Wraps the output of get_uid as a Python Long object. Ownership is passed to
+// the caller.
+PyObject* TFE_Py_UID();
+
+// Deleter for Context objects, called from the Capsule that owns it.
+void TFE_DeleteContextCapsule(PyObject* context);
+
+// Returns true if o is an instance of EagerTensor, but not a subclass. Else
+// returns false.
+bool EagerTensor_CheckExact(const PyObject* o);
+
+// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
+
+// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
+TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
+
+// Creates the `EagerTensor` class by subclassing `base_class` and returns the
+// newly created type, or nullptr on error.
+PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index b6fd9d6b44..a2079d009f 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -13,16 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Must be included first.
-#include "tensorflow/python/lib/core/numpy.h"
-
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/python/lib/core/ndarray_tensor.h"
-#include "tensorflow/python/lib/core/py_seq_tensor.h"
+#include "tensorflow/core/platform/types.h"
using tensorflow::string;
@@ -320,6 +316,14 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs,
}
}
}
+
+// Python subclass of Exception that is created on not ok Status.
+// Starts out as nullptr; the MaybeRaiseExceptionFrom* helpers raise it only
+// when the caller passes no explicit exception type, and fall back to
+// RuntimeError while it is still nullptr. Guarded by exception_class_mutex.
+tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
+PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
+
+// Counter handed out and incremented by get_uid(); every access happens while
+// holding _uid_mutex.
+static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
+static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
+
} // namespace
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
@@ -352,65 +356,6 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
TFE_DeleteOp(op);
}
-PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
- const tensorflow::Tensor* t =
- TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
- if (TF_GetCode(status) != TF_OK) {
- Py_RETURN_NONE;
- }
- PyObject* ret = nullptr;
- auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
- if (!cppstatus.ok()) {
- TF_SetStatus(status, TF_Code(cppstatus.code()),
- cppstatus.error_message().c_str());
- }
- if (ret != nullptr) return ret;
- Py_RETURN_NONE;
-}
-
-namespace {
-// Python subclass of Exception that is created on not ok Status.
-tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
-PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
-
-void PyRaiseException(TF_Code error_code, const char* msg) {
- tensorflow::mutex_lock l(exception_class_mutex);
- if (exception_class != nullptr) {
- PyErr_SetObject(exception_class, Py_BuildValue("si", msg, error_code));
- } else {
- PyErr_SetString(PyExc_RuntimeError, msg);
- }
-}
-
-} // namespace
-
-TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
- tensorflow::Tensor t;
- auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
- if (cppstatus.ok()) {
- return TFE_NewTensorHandle(t);
- } else {
- PyRaiseException(TF_INVALID_ARGUMENT,
- tensorflow::strings::StrCat(
- "failed to convert numpy ndarray to a Tensor (",
- cppstatus.error_message(), ")")
- .c_str());
- }
- return nullptr;
-}
-
-TFE_TensorHandle* TFE_Py_SequenceToTensorHandle(PyObject* obj,
- PyObject* dtype) {
- tensorflow::Tensor t;
- auto cppstatus = tensorflow::PySeqToTensor(obj, dtype, &t);
- if (cppstatus.ok()) {
- return TFE_NewTensorHandle(t);
- } else {
- PyRaiseException(TF_INVALID_ARGUMENT, cppstatus.error_message().c_str());
- }
- return nullptr;
-}
-
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
@@ -429,9 +374,39 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
}
}
-int TFE_Py_MaybeRaiseException(TF_Status* status) {
+// Sets a Python exception describing `status` when it is not TF_OK.
+//
+// Returns 0 without touching the Python error indicator when `status` is
+// TF_OK, and -1 after setting the error indicator otherwise. If `exception`
+// is nullptr, the class stored in `exception_class` is raised with a
+// (message, code) value; if no class has been stored, RuntimeError is raised
+// with the message only. A non-null `exception` is raised with the message.
+int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
if (TF_GetCode(status) == TF_OK) return 0;
- PyRaiseException(TF_GetCode(status), TF_Message(status));
+  const char* msg = TF_Message(status);
+  if (exception == nullptr) {
+    tensorflow::mutex_lock l(exception_class_mutex);
+    if (exception_class != nullptr) {
+      // PyErr_SetObject takes its own reference to the value, so the new
+      // reference returned by Py_BuildValue must be released here; otherwise
+      // every raised error leaks one tuple. Py_XDECREF tolerates the nullptr
+      // that Py_BuildValue returns on failure.
+      PyObject* val = Py_BuildValue("si", msg, TF_GetCode(status));
+      PyErr_SetObject(exception_class, val);
+      Py_XDECREF(val);
+      return -1;
+    } else {
+      exception = PyExc_RuntimeError;
+    }
+  }
+  // This may replace an exception that is already set.
+  PyErr_SetString(exception, msg);
+  return -1;
+}
+
+// Sets a Python exception describing `status` when it is not ok.
+//
+// Returns 0 without touching the Python error indicator when `status` is ok,
+// and -1 after setting the error indicator otherwise. If `exception` is
+// nullptr, the class stored in `exception_class` is raised with a
+// (message, code) value; if no class has been stored, RuntimeError is raised
+// with the message only. A non-null `exception` is raised with the message.
+int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
+                                  PyObject* exception) {
+  if (status.ok()) return 0;
+  const char* msg = status.error_message().c_str();
+  if (exception == nullptr) {
+    tensorflow::mutex_lock l(exception_class_mutex);
+    if (exception_class != nullptr) {
+      // PyErr_SetObject takes its own reference to the value, so the new
+      // reference returned by Py_BuildValue must be released here; otherwise
+      // every raised error leaks one tuple. Py_XDECREF tolerates the nullptr
+      // that Py_BuildValue returns on failure.
+      PyObject* val = Py_BuildValue("si", msg, status.code());
+      PyErr_SetObject(exception_class, val);
+      Py_XDECREF(val);
+      return -1;
+    } else {
+      exception = PyExc_RuntimeError;
+    }
+  }
+  // This may replace an exception that is already set.
+  PyErr_SetString(exception, msg);
return -1;
}
@@ -446,3 +421,18 @@ char* TFE_GetPythonString(PyObject* o) {
#endif
return nullptr;
}
+
+// Returns the current value of the _uid counter and then advances the counter
+// by one, all while holding _uid_mutex.
+int64_t get_uid() {
+  tensorflow::mutex_lock guard(_uid_mutex);
+  const tensorflow::int64 current = _uid;
+  _uid = current + 1;
+  return current;
+}
+
+// Fetches a fresh id from get_uid() and returns it as the result of
+// PyLong_FromLongLong.
+PyObject* TFE_Py_UID() {
+  const int64_t uid = get_uid();
+  return PyLong_FromLongLong(uid);
+}
+
+// Reads the pointer stored in the capsule `context` (looked up with a null
+// capsule name), treats it as a TFE_Context and deletes it.
+void TFE_DeleteContextCapsule(PyObject* context) {
+  void* raw_ctx = PyCapsule_GetPointer(context, nullptr);
+  TF_Status* delete_status = TF_NewStatus();
+  TFE_DeleteContext(static_cast<TFE_Context*>(raw_ctx), delete_status);
+  // The status filled in by TFE_DeleteContext is released without being
+  // inspected.
+  TF_DeleteStatus(delete_status);
+}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index e4fdaa111a..84814d48fd 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -135,9 +135,9 @@ class Tape(object):
# adding an explicit stack if this ever gets out of hand
self._delete_tensor_id(tensor_id)
- def delete_trace(self, tensor):
+ def delete_trace(self, tensor_id):
"""Deletes any trace we have for this tensor."""
- self._delete_tensor_id(tid(tensor))
+ self._delete_tensor_id(tensor_id)
def export(self):
"""Exports the internal state of this tape.
@@ -237,10 +237,10 @@ def record_operation(op_type, output_tensors, input_tensors, side_outputs,
backward_function)
-def delete_trace(tensor):
+def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack."""
for t in _tape_stack.stack:
- t.delete_trace(tensor)
+ t.delete_trace(tensor_id)
def top_tape_watched_tensors():
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 8a8cf0e2c3..953807fc2a 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -21,26 +21,90 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+def _create_tensor(value, device=None, dtype=None):
+ """Constructs an `ops.EagerTensor` from `value` using the current context.
+
+ Args:
+ value: Passed through unchanged as the first argument of `ops.EagerTensor`.
+ device: Device argument for the tensor. When None, the `device_name` of
+ `context.context()` is used.
+ dtype: Optional object exposing `as_datatype_enum`; when not None, that
+ enum value is what gets passed to `ops.EagerTensor`.
+
+ Returns:
+ The newly constructed `ops.EagerTensor`.
+
+ Raises:
+ The exception produced by `core._status_to_exception` when construction
+ raises `core._NotOkStatusException`.
+ """
+ ctx = context.context()
+ if device is None:
+ device = ctx.device_name
+ if dtype is not None:
+ dtype = dtype.as_datatype_enum
+ try:
+ return ops.EagerTensor(
+ value, context=ctx._handle, device=device, dtype=dtype)
+ except core._NotOkStatusException as e: # pylint: disable=protected-access
+ raise core._status_to_exception(e.code, e.message)
+
+
class TFETensorTest(test_util.TensorFlowTestCase):
def testScalarTensor(self):
- t = constant_op.constant(3)
- self.assertEqual(t.numpy(), constant_op.constant(np.array(3)).numpy())
+ t = _create_tensor(3, dtype=dtypes.int32)
+ self.assertEqual(t.numpy(), _create_tensor(np.array(3)).numpy())
self.assertEqual(dtypes.int32, t.dtype)
self.assertEqual(0, t.shape.ndims)
self.assertAllEqual([], t.shape.as_list())
+ self.assertIn("tf.Tensor", str(t))
+ self.assertIn("tf.Tensor", repr(t))
+
+ def testBadConstructorArgs(self):
+ ctx = context.context()
+ handle = ctx._handle
+ device = ctx.device_name
+ # Missing context.
+ with self.assertRaisesRegexp(
+ TypeError, r"Required argument 'context' \(pos 2\) not found"):
+ ops.EagerTensor(1, device=device)
+ # Missing device.
+ with self.assertRaisesRegexp(
+ TypeError, r"Required argument 'device' \(pos 3\) not found"):
+ ops.EagerTensor(1, context=handle)
+ # Bad dtype type.
+ with self.assertRaisesRegexp(TypeError,
+ "Expecting a DataType value for dtype. Got"):
+ ops.EagerTensor(1, context=handle, device=device, dtype="1")
+ # Following errors happen when trying to copy to GPU.
+ if not context.context().num_gpus():
+ self.skipTest("No GPUs found")
+ with ops.device("/device:GPU:0"):
+ device = ctx.device_name
+ # Bad context.
+ with self.assertRaisesRegexp(
+ TypeError, "Expecting a PyCapsule encoded context handle. Got"):
+ ops.EagerTensor(1.0, context=1, device=device)
+ # Bad device.
+ with self.assertRaisesRegexp(
+ TypeError, "Error parsing device argument to CopyToDevice"):
+ ops.EagerTensor(1.0, context=handle, device=1)
+
+ def testNumpyValue(self):
+ values = np.array([3.0])
+ t = _create_tensor(values)
+ self.assertAllEqual(values, t.numpy())
+
+ def testNumpyValueWithCast(self):
+ values = np.array([3.0], dtype=np.float32)
+ t = _create_tensor(values, dtype=dtypes.float64)
+ self.assertAllEqual(values, t.numpy())
+ ctx = context.context()
+ # Bad dtype value.
+ with self.assertRaisesRegexp(TypeError, "Invalid dtype argument value"):
+ ops.EagerTensor(
+ values, context=ctx._handle, device=ctx.device_name, dtype=12345)
+
+ def testNumpyOrderHandling(self):
+ n = np.array([[1, 2], [3, 4]], order="F")
+ t = _create_tensor(n)
+ self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
def testTensorAndNumpyMatrix(self):
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
- actual = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ actual = _create_tensor([[1.0, 2.0], [3.0, 4.0]])
self.assertAllEqual(expected, actual.numpy())
self.assertEqual(np.float32, actual.numpy().dtype)
self.assertEqual(dtypes.float32, actual.dtype)
@@ -48,56 +112,50 @@ class TFETensorTest(test_util.TensorFlowTestCase):
def testFloatDowncast(self):
# Unless explicitly specified, float64->float32
- t = constant_op.constant(3.0)
+ t = _create_tensor(3.0)
self.assertEqual(dtypes.float32, t.dtype)
- t = constant_op.constant(3.0, dtype=dtypes.float64)
+ t = _create_tensor(3.0, dtype=dtypes.float64)
self.assertEqual(dtypes.float64, t.dtype)
def testBool(self):
- t = constant_op.constant(False)
+ t = _create_tensor(False)
if t:
self.assertFalse(True)
def testIntDowncast(self):
- t = constant_op.constant(3)
+ t = _create_tensor(3)
self.assertEqual(dtypes.int32, t.dtype)
- t = constant_op.constant(3, dtype=dtypes.int64)
+ t = _create_tensor(3, dtype=dtypes.int64)
self.assertEqual(dtypes.int64, t.dtype)
- t = constant_op.constant(2**33)
+ t = _create_tensor(2**33)
self.assertEqual(dtypes.int64, t.dtype)
def testTensorCreationFailure(self):
- with self.assertRaises(Exception):
+ with self.assertRaises(ValueError):
# Should fail because the each row of the Python object has a different
# number of columns.
- self.assertEqual(None, constant_op.constant([[1], [1, 2]]))
-
- def testNumpyOrderHandling(self):
- n = np.array([[1, 2], [3, 4]], order="F")
- t = constant_op.constant(n)
- self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
+ self.assertEqual(None, _create_tensor([[1], [1, 2]]))
def testMultiLineTensorStr(self):
- t = constant_op.constant(np.eye(3))
+ t = _create_tensor(np.eye(3))
tensor_str = str(t)
self.assertIn("shape=%s, dtype=%s" % (t.shape, t.dtype.name), tensor_str)
self.assertIn(str(t.numpy()), tensor_str)
def testMultiLineTensorRepr(self):
- t = constant_op.constant(np.eye(3))
+ t = _create_tensor(np.eye(3))
tensor_repr = repr(t)
self.assertTrue(tensor_repr.startswith("<"))
self.assertTrue(tensor_repr.endswith(">"))
- self.assertIn(
- "id=%d, shape=%s, dtype=%s, numpy=\n%r" % (
- t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
+ self.assertIn("id=%d, shape=%s, dtype=%s, numpy=\n%r" %
+ (t._id, t.shape, t.dtype.name, t.numpy()), tensor_repr)
def testTensorStrReprObeyNumpyPrintOptions(self):
orig_threshold = np.get_printoptions()["threshold"]
orig_edgeitems = np.get_printoptions()["edgeitems"]
np.set_printoptions(threshold=2, edgeitems=1)
- t = constant_op.constant(np.arange(10, dtype=np.int32))
+ t = _create_tensor(np.arange(10, dtype=np.int32))
self.assertIn("[0 ..., 9]", str(t))
self.assertIn("[0, ..., 9]", repr(t))
@@ -105,30 +163,30 @@ class TFETensorTest(test_util.TensorFlowTestCase):
np.set_printoptions(threshold=orig_threshold, edgeitems=orig_edgeitems)
def testZeroDimTensorStr(self):
- t = constant_op.constant(42)
+ t = _create_tensor(42)
self.assertIn("42, shape=(), dtype=int32", str(t))
def testZeroDimTensorRepr(self):
- t = constant_op.constant(42)
+ t = _create_tensor(42)
self.assertTrue(repr(t).startswith("<"))
self.assertTrue(repr(t).endswith(">"))
self.assertIn("id=%d, shape=(), dtype=int32, numpy=42" % t._id, repr(t))
def testZeroSizeTensorStr(self):
- t = constant_op.constant(np.zeros(0, dtype=np.float32))
+ t = _create_tensor(np.zeros(0, dtype=np.float32))
self.assertIn("[], shape=(0,), dtype=float32", str(t))
def testZeroSizeTensorRepr(self):
- t = constant_op.constant(np.zeros(0, dtype=np.float32))
+ t = _create_tensor(np.zeros(0, dtype=np.float32))
self.assertTrue(repr(t).startswith("<"))
self.assertTrue(repr(t).endswith(">"))
- self.assertIn(
- "id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id, t.numpy()),
- repr(t))
+ self.assertIn("id=%d, shape=(0,), dtype=float32, numpy=%r" % (t._id,
+ t.numpy()),
+ repr(t))
def testStringTensor(self):
t_np_orig = np.array([[b"a", b"ab"], [b"abc", b"abcd"]])
- t = constant_op.constant(t_np_orig)
+ t = _create_tensor(t_np_orig)
t_np = t.numpy()
self.assertTrue(np.all(t_np == t_np_orig), "%s vs %s" % (t_np, t_np_orig))
@@ -137,9 +195,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
self.skipTest("No GPUs found")
with ops.device("/device:GPU:0"):
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Can't copy Tensor with type string to device"):
- constant_op.constant("test string")
+ RuntimeError, "Can't copy Tensor with type string to device"):
+ _create_tensor("test string")
if __name__ == "__main__":