author    Igor Ganichev <iga@google.com>                  2018-05-25 13:58:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-05-25 14:01:43 -0700
commit    ea72dd6c0c8a5572d2ca34be4ca0ab43efc4570d (patch)
tree      5a14a3128201bc9ee0ce2203d0191245f10b0ece /tensorflow
parent    a6369634688dafa4c061f499a6adbb4e78666832 (diff)
Add EagerTensor profiler and device shape utilities
This change includes the following steps to make the EagerTensor profiler work:

- Add a PaddedShapeFn to XlaDevice::Metadata. We need a backend-independent
  way to get a fully-padded shape and its layout on the device. This function
  is set during device construction. CPU and GPU devices effectively get an
  identity function since they neither change the layout nor pad. TPU gets the
  appropriate function.
- Add a TFE_TensorDebugInfo struct and C API methods for it. These methods are
  necessary to fetch the shape and layout from under the C API to the Python
  level. This can be a home for more debug information later.
- Make EagerTensor weakly referenceable. This involves adding a pointer to the
  list of current weak references. This addition should have negligible
  overhead when the profiler is not used. The only operations on this field
  are setting it to null on construction and checking if it is null on
  destruction.
- Add C++ functions callable from Python to register an instance of
  EagerTensorProfiler and retrieve debug information for a given EagerTensor.
  These functions are used in the new "inspect" module.
- Finally, write the actual profiler.

PiperOrigin-RevId: 198098380
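For illustration, a minimal sketch (not part of this commit) of how the new C API
declared in c_api.h below might be exercised; `PrintOnDeviceShape` is a
hypothetical helper, and `h` is assumed to be a valid TFE_TensorHandle obtained
elsewhere:

    #include <stdio.h>
    #include "tensorflow/c/eager/c_api.h"

    // Prints the fully-padded on-device dimensions of `h`, major-to-minor.
    void PrintOnDeviceShape(TFE_TensorHandle* h) {
      TF_Status* status = TF_NewStatus();
      TFE_TensorDebugInfo* info = TFE_TensorHandleTensorDebugInfo(h, status);
      if (TF_GetCode(status) != TF_OK) {
        fprintf(stderr, "%s\n", TF_Message(status));
        TF_DeleteStatus(status);
        return;
      }
      int rank = TFE_TensorDebugInfoOnDeviceNumDims(info);
      for (int i = 0; i < rank; ++i) {
        printf("dim %d: %lld\n", i,
               (long long)TFE_TensorDebugInfoOnDeviceDim(info, i));
      }
      TFE_DeleteTensorDebugInfo(info);
      TF_DeleteStatus(status);
    }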
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/c/eager/BUILD                           |  27
-rw-r--r--  tensorflow/c/eager/c_api.cc                        |   4
-rw-r--r--  tensorflow/c/eager/c_api.h                         |  39
-rw-r--r--  tensorflow/c/eager/c_api_debug.cc                  | 167
-rw-r--r--  tensorflow/c/eager/c_api_debug_test.cc             |  50
-rw-r--r--  tensorflow/c/eager/c_api_internal.h                |   8
-rw-r--r--  tensorflow/c/eager/c_api_test.cc                   | 125
-rw-r--r--  tensorflow/c/eager/c_api_test_util.cc              | 163
-rw-r--r--  tensorflow/c/eager/c_api_test_util.h               |  53
-rw-r--r--  tensorflow/compiler/jit/xla_cpu_device.cc          |  11
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc              |  36
-rw-r--r--  tensorflow/compiler/jit/xla_device.h               |  21
-rw-r--r--  tensorflow/compiler/jit/xla_gpu_device.cc          |   3
-rw-r--r--  tensorflow/compiler/jit/xla_interpreter_device.cc  |  11
-rw-r--r--  tensorflow/contrib/cmake/tf_c.cmake                |   1
-rw-r--r--  tensorflow/python/eager/pywrap_tensor.cc           | 181
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h               |  13
-rw-r--r--  tensorflow/python/pywrap_tfe.i                     |   2
18 files changed, 733 insertions, 182 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 9ce781fab0..f265da2c2c 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -14,6 +14,7 @@ tf_cuda_library(
name = "c_api",
srcs = [
"c_api.cc",
+ "c_api_debug.cc",
"c_api_internal.h",
],
hdrs = ["c_api.h"],
@@ -45,6 +46,7 @@ tf_cuda_library(
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
+ "//tensorflow/compiler/jit:xla_device",
],
"//conditions:default": [],
}) + [
@@ -99,9 +101,31 @@ tf_cuda_library(
],
)
+tf_cuda_library(
+ name = "c_api_test_util",
+ testonly = 1,
+ srcs = ["c_api_test_util.cc"],
+ hdrs = ["c_api_test_util.h"],
+ visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow:__subpackages__",
+ ],
+ deps = [
+ ":c_api",
+ "//tensorflow/c:c_test_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cuda_cc_test(
name = "c_api_test",
- srcs = ["c_api_test.cc"],
+ srcs = [
+ "c_api_debug_test.cc",
+ "c_api_test.cc",
+ ],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
@@ -109,6 +133,7 @@ tf_cuda_cc_test(
],
deps = [
":c_api",
+ ":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 216210c88c..81221c4078 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -73,10 +73,6 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
-#ifdef TENSORFLOW_EAGER_USE_XLA
-std::atomic_int_fast64_t func_id_generator(0);
-#endif // TENSORFLOW_EAGER_USE_XLA
-
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 574a097e0d..1862af3ce2 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -191,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
TF_Status* status);
+// Debugging/Profiling information for TFE_TensorHandle
+//
+// TFE_TensorDebugInfo contains information useful for debugging and
+// profiling tensors.
+typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
+
+// Retrieves TFE_TensorDebugInfo for `handle`.
+// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and the
+// caller is responsible for deleting the returned TFE_TensorDebugInfo.
+// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to an appropriate
+// error and nullptr is returned. This function can block until the operation
+// that produces `handle` has completed.
+TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
+ TFE_TensorHandle* handle, TF_Status* status);
+
+// Deletes `debug_info`.
+TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
+ TFE_TensorDebugInfo* debug_info);
+
+// Returns the number of dimensions used to represent the tensor on its device.
+// The number of dimensions used to represent the tensor on device can be
+// different from the number returned by TFE_TensorHandleNumDims.
+// The return value was current at the time of TFE_TensorDebugInfo creation.
+TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
+ TFE_TensorDebugInfo* debug_info);
+
+// Returns the number of elements in dimension `dim_index`.
+// Tensor representation on device can be transposed from its representation
+// on host. The data contained in dimension `dim_index` on device
+// can correspond to the data contained in another dimension in the on-host
+// representation. The dimensions are indexed using the standard TensorFlow
+// major-to-minor order (slowest varying dimension first),
+// not XLA's minor-to-major order.
+// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns
+// the number of elements in a dimension after padding.
+// The return value was current at the time of TFE_TensorDebugInfo creation.
+TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
+ TFE_TensorDebugInfo* debug_info, int dim_index);
+
// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
new file mode 100644
index 0000000000..5006b76f19
--- /dev/null
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <vector>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#ifdef TENSORFLOW_EAGER_USE_XLA
+#include "tensorflow/compiler/jit/xla_device.h"
+#endif // TENSORFLOW_EAGER_USE_XLA
+
+using tensorflow::int64;
+using tensorflow::string;
+
+namespace {
+
+std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
+ TF_Status* status) {
+ std::vector<int64> shape;
+ int rank = TFE_TensorHandleNumDims(handle, status);
+ if (!status->status.ok()) {
+ return shape;
+ }
+ shape.reserve(rank);
+ for (int i = 0; i < rank; ++i) {
+ shape.push_back(TFE_TensorHandleDim(handle, i, status));
+ if (!status->status.ok()) {
+ return shape;
+ }
+ }
+ return shape;
+}
+
+} // namespace
+
+extern "C" {
+
+TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
+ TFE_TensorHandle* handle, TF_Status* status) {
+ const tensorflow::Tensor* tensor;
+ status->status = handle->handle->Tensor(&tensor);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+
+ tensorflow::Device* device;
+ status->status = handle->handle->Device(&device);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+ // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
+ tensorflow::XlaDevice* xla_device =
+ dynamic_cast<tensorflow::XlaDevice*>(device);
+ if (xla_device != nullptr) {
+ tensorflow::XlaDevice::PaddedShapeFn shape_fn =
+ xla_device->metadata().padded_shape_fn();
+ xla::Shape padded_shape;
+ status->status = shape_fn(*tensor, &padded_shape);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ if (VLOG_IS_ON(3)) {
+ std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
+ if (!status->status.ok()) {
+ // Ignore the status here as we are simply logging.
+ status->status = tensorflow::Status::OK();
+ } else {
+ VLOG(3) << "Fully padded shape of ["
+ << tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
+ << padded_shape.DebugString();
+ }
+ }
+
+ if (xla::ShapeUtil::IsTuple(padded_shape)) {
+ if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
+ // Currently, the only case of XlaTensor containing a tuple shape is to
+ // represent 64 bit ints, doubles, and complex numbers (we don't support
+ // 64bit complex numbers).
+ status->status = tensorflow::errors::InvalidArgument(
+ "XlaTensors should only contain tuples of size 2. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+
+ // shape0 is not a const& because we will assign it to padded_shape below.
+ // It is illegal to assign a part of a message to itself.
+ xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
+ const xla::Shape& shape1 =
+ xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
+ if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "XlaTensors should not contain nested tuples. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+ if (!xla::ShapeUtil::Equal(shape0, shape1)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Subshapes of XlaTensors should be the same. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+
+ // Since the only case we handle here are two equal subshapes, we
+ // simply return one of them. The caller will interpret it as this
+ // shape directly storing the 64bit types. This approximation is good
+ // enough for this API's debugging use case.
+ padded_shape = shape0;
+ }
+
+ int rank = padded_shape.dimensions_size();
+ std::vector<int64> dev_dims;
+ dev_dims.reserve(rank);
+ if (rank == 1) {
+ // Rank 1 tensors might not have padded_shape.layout.minor_to_major set; use the dimension directly.
+ dev_dims.push_back(padded_shape.dimensions(0));
+ } else {
+ for (int i = rank - 1; i >= 0; --i) {
+ int64 dim_index = padded_shape.layout().minor_to_major(i);
+ dev_dims.push_back(padded_shape.dimensions(dim_index));
+ }
+ }
+ status->status = tensorflow::Status::OK();
+ return new TFE_TensorDebugInfo(dev_dims);
+ }
+#endif // TENSORFLOW_EAGER_USE_XLA
+
+ // If the tensor is not an XLA tensor, the device shape is
+ // the same as regular tensor shape.
+ std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ return new TFE_TensorDebugInfo(dev_dims);
+}
+
+TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
+ TFE_TensorDebugInfo* debug_info) {
+ delete debug_info;
+}
+
+TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
+ TFE_TensorDebugInfo* debug_info) {
+ return debug_info->dev_dims.size();
+}
+
+TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
+ TFE_TensorDebugInfo* debug_info, int dim_index) {
+ return debug_info->dev_dims[dim_index];
+}
+
+} // extern "C"
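A hedged worked example of the dimension reordering in the XLA branch above; the
sizes are invented for illustration and do not reflect any particular device's
padding:

    // Suppose padded_shape has rank 2 with
    //   dimensions()     = {4, 128}   // padded sizes of logical dims 0 and 1
    //   minor_to_major() = {0, 1}     // logical dim 0 varies fastest on device
    // The loop walks minor_to_major from the back, i.e. major-most first:
    //   i = 1: dim_index = minor_to_major(1) = 1, push dimensions(1) = 128
    //   i = 0: dim_index = minor_to_major(0) = 0, push dimensions(0) = 4
    // so dev_dims = {128, 4} and TFE_TensorDebugInfoOnDeviceDim(info, 0) == 128.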
diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc
new file mode 100644
index 0000000000..cddb9f6e00
--- /dev/null
+++ b/tensorflow/c/eager/c_api_debug_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <string.h>
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+TEST(CApiDebug, ScalarCPU) {
+ TFE_TensorHandle* h = TestScalarTensorHandle();
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
+
+ TFE_DeleteTensorDebugInfo(debug_info);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteStatus(status);
+}
+
+TEST(CApiDebug, 2DCPU) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
+ // Shape is the same for CPU tensors.
+ EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0));
+ EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1));
+
+ TFE_DeleteTensorDebugInfo(debug_info);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteStatus(status);
+}
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 2b8384d720..04a6efc47c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -107,6 +107,14 @@ struct TFE_TensorHandle {
tensorflow::TensorHandle* handle;
};
+struct TFE_TensorDebugInfo {
+ TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
+ : dev_dims(dims) {}
+
+ // Fully padded device dimensions, in TensorFlow's major-to-minor order.
+ std::vector<tensorflow::int64> dev_dims;
+};
+
struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 49646bb735..27ff5f7211 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -32,122 +33,6 @@ using tensorflow::string;
namespace {
-TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
- int64_t dims[] = {2, 2};
- double data[] = {1.0, 2.0, 3.0, 4.0};
- TF_Tensor* t = TF_AllocateTensor(
- TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_TensorHandle* TestMatrixTensorHandle() {
- int64_t dims[] = {2, 2};
- float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
- TF_Tensor* t = TF_AllocateTensor(
- TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_TensorHandle* TestMatrixTensorHandle3X2() {
- int64_t dims[] = {3, 2};
- double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
- TF_Tensor* t = TF_AllocateTensor(
- TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
- TF_Status* status = TF_NewStatus();
-
- TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, a, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, b, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteStatus(status);
- TFE_OpSetAttrBool(op, "transpose_a", 0);
- TFE_OpSetAttrBool(op, "transpose_b", 0);
- TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
-
- return op;
-}
-
-TFE_TensorHandle* TestAxisTensorHandle() {
- int64_t dims[] = {1};
- int data[] = {1};
- TF_Tensor* t = TF_AllocateTensor(
- TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
- TFE_TensorHandle* axis) {
- TF_Status* status = TF_NewStatus();
-
- TFE_Op* op = TFE_NewOp(ctx, "Min", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, input, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, axis, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpSetAttrBool(op, "keep_dims", 1);
- TFE_OpSetAttrType(op, "Tidx", TF_INT32);
- TF_DeleteStatus(status);
- TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
-
- return op;
-}
-
-// If there is a GPU device, returns true and sets 'gpu_device_name'
-// accordingly.
-bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
- std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
- TF_NewStatus(), TF_DeleteStatus);
- TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
- CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
- const int num_devices = TF_DeviceListCount(devices);
- for (int i = 0; i < num_devices; ++i) {
- const string device_type(TF_DeviceListType(devices, i, status.get()));
- CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
- const string device_name(TF_DeviceListName(devices, i, status.get()));
- CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
- if (device_type == "GPU") {
- *gpu_device_name = device_name;
- LOG(INFO) << "Found GPU device " << device_name;
- TF_DeleteDeviceList(devices);
- return true;
- }
- }
- TF_DeleteDeviceList(devices);
- return false;
-}
-
void BM_InitOp(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
@@ -536,7 +421,7 @@ void TensorHandleSilentCopy(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -583,7 +468,7 @@ void TensorHandleSilentCopyLocal(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -624,7 +509,7 @@ void SetAndGetOpDevices(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_OpSetDevice(matmul, "GPU:0", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
const char* device_name = TFE_OpGetDevice(matmul, status);
@@ -688,7 +573,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
- TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2();
+ TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
status);
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
new file mode 100644
index 0000000000..5607c9dcb0
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -0,0 +1,163 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api_test_util.h"
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+using tensorflow::string;
+
+TFE_TensorHandle* TestScalarTensorHandle() {
+ float data[] = {1.0f};
+ TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* TestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
+ int64_t dims[] = {3, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* TestMatrixTensorHandle3X2() {
+ int64_t dims[] = {3, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, b, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrBool(op, "transpose_a", 0);
+ TFE_OpSetAttrBool(op, "transpose_b", 0);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
+TFE_TensorHandle* TestAxisTensorHandle() {
+ int64_t dims[] = {1};
+ int data[] = {1};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+ TFE_TensorHandle* axis) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Min", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, input, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, axis, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpSetAttrBool(op, "keep_dims", 1);
+ TFE_OpSetAttrType(op, "Tidx", TF_INT32);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
+
+ return op;
+}
+
+bool GetDeviceName(TFE_Context* ctx, string* device_name,
+ const char* device_type) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ const int num_devices = TF_DeviceListCount(devices);
+ for (int i = 0; i < num_devices; ++i) {
+ const string dev_type(TF_DeviceListType(devices, i, status.get()));
+ CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+ const string dev_name(TF_DeviceListName(devices, i, status.get()));
+ CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+ if (dev_type == device_type) {
+ *device_name = dev_name;
+ LOG(INFO) << "Found " << device_type << " device " << *device_name;
+ TF_DeleteDeviceList(devices);
+ return true;
+ }
+ }
+ TF_DeleteDeviceList(devices);
+ return false;
+}
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
new file mode 100644
index 0000000000..474cae67c8
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
+#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include "tensorflow/core/platform/types.h"
+
+// Return a tensor handle containing a float scalar
+TFE_TensorHandle* TestScalarTensorHandle();
+
+// Return a tensor handle containing a 2x2 matrix of doubles
+TFE_TensorHandle* DoubleTestMatrixTensorHandle();
+
+// Return a tensor handle containing a 2x2 matrix of floats
+TFE_TensorHandle* TestMatrixTensorHandle();
+
+// Return a tensor handle containing a 3x2 matrix of doubles
+TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
+
+// Return a tensor handle containing a 3x2 matrix of floats
+TFE_TensorHandle* TestMatrixTensorHandle3X2();
+
+// Return a matmul op multiplying `a` by `b`.
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+
+// Return a 1-D INT32 tensor containing a single value 1.
+TFE_TensorHandle* TestAxisTensorHandle();
+
+// Return an op taking the minimum of `input` along the `axis` dimension.
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+ TFE_TensorHandle* axis);
+
+// If there is a device of type `device_type`, returns true
+// and sets 'device_name' accordingly.
+// `device_type` must be either "GPU" or "TPU".
+bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
+ const char* device_type);
+
+#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index ea9e036604..43648402f6 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -50,11 +50,12 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(
- XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
- name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+ DEVICE_CPU_XLA_JIT, options, name_prefix,
+ registration,
+ /*transfer_as_literal=*/false,
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index f13b46c532..ed007d603e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
return alloc_ptr;
}
+namespace {
+
+// Default PaddedShapeFn implementation that simply returns the unpadded
+// on-device shape. This is accurate for CPU and GPU devices that neither
+// transpose nor pad tensors.
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
+ const tensorflow::XlaTensor* xla_tensor =
+ tensorflow::XlaTensor::FromTensor(&tensor);
+ if (xla_tensor == nullptr) {
+ return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
+ }
+
+ const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+ *shape = shaped_buffer.on_device_shape();
+ return Status::OK();
+}
+
+} // namespace
+
/* static */ Status XlaDevice::Create(
const string& platform_name, const string& device_name, int device_ordinal,
const string& jit_device_name, const SessionOptions& options,
@@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device) {
+ const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+ platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ PaddedShapeFn padded_shape_fn)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ padded_shape_fn_(std::move(padded_shape_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -178,10 +201,11 @@ XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+ const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+ const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn),
+ shape_representation_fn, padded_shape_fn),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d5d345d43b..02e88ee679 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -45,13 +45,19 @@ namespace tensorflow {
class XlaDevice : public LocalDevice {
public:
+ // Given a tensor, sets `xla::Shape*` to the shape of the tensor's
+ // representation on device, fully padded. On error, the contents of
+ // `xla::Shape*` are undefined.
+ typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
+
// Wrapper class to store metadata about the XlaDevice, where it can be
// retrieved e.g., when lazily creating the XlaCompilationCache device.
class Metadata {
public:
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ PaddedShapeFn padded_shape_fn);
// The index of the device on this host.
int device_ordinal() const;
@@ -62,12 +68,14 @@ class XlaDevice : public LocalDevice {
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
+ const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ PaddedShapeFn padded_shape_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -81,6 +89,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
+ // If padded_shape_fn is empty, a default implementation that returns
+ // the on-host shape is used.
static Status Create(
const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
@@ -88,12 +98,16 @@ class XlaDevice : public LocalDevice {
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device);
+ const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
+ // Creates a new XLA Device.
+ // If padded_shape_fn is empty, a default implementation that returns
+ // the logical on-device shape without padding is used.
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
+ const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+ const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -110,6 +124,7 @@ class XlaDevice : public LocalDevice {
Tensor* tensor) override;
xla::LocalClient* client() const;
+ const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
// If not already set, create and set GpuDeviceInfo.
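To make the new hook concrete, a minimal sketch of a custom PaddedShapeFn that a
backend could pass to XlaDevice::Create; the round-up-to-a-multiple-of-8 rule and
the name ExamplePaddedShapeFn are invented for illustration and are not what any
real device does:

    #include "tensorflow/compiler/jit/xla_device.h"
    #include "tensorflow/compiler/tf2xla/shape_util.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Hypothetical example: report every dimension rounded up to a multiple of 8.
    Status ExamplePaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
      TF_RETURN_IF_ERROR(
          TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape));
      for (int i = 0; i < shape->dimensions_size(); ++i) {
        shape->set_dimensions(i, (shape->dimensions(i) + 7) / 8 * 8);
      }
      return Status::OK();
    }

    }  // namespace tensorflow

Such a function would be supplied as the padded_shape_fn argument to
XlaDevice::Create; passing an empty function keeps the DefaultPaddedShapeFn added
in xla_device.cc above.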
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 26842fbe5c..c0d86a28c7 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -49,7 +49,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device);
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4146996f63..661187f4a8 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -48,11 +48,12 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.compile_resource_ops = true;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(XlaDevice::Create(
- "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT,
- options, name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0,
+ DEVICE_INTERPRETER_XLA_JIT, options,
+ name_prefix, registration,
+ /*transfer_as_literal=*/false,
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index 310fe58e05..bda5e26f43 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -21,6 +21,7 @@ set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
+ "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/tape.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index a62af4a06c..ea604647fa 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,8 +27,15 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
+// forward declare
+struct EagerTensor;
+
namespace {
+// An instance of _EagerTensorProfiler that will receive callbacks about
+// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
+PyObject* eager_tensor_profiler = nullptr;
+
TFE_Context* GetContext(PyObject* ctx) {
TFE_Context* context =
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
@@ -253,8 +260,45 @@ typedef struct EagerTensor {
// to use a TF_Status object. However note that accesses to `status` are not
// thread-safe.
TF_Status* status;
+
+ PyObject* weakreflist; /* List of weak references */
} EagerTensor;
+namespace {
+
+// Returns true on success - successfully invoked or no profiler registered.
+// Returns false if some error occurred.
+bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
+ if (eager_tensor_profiler != nullptr) {
+#if PY_MAJOR_VERSION < 3
+ PyObject* created_method_name = PyString_InternFromString("created");
+#else
+ PyObject* created_method_name = PyUnicode_InternFromString("created");
+#endif
+ if (created_method_name == nullptr) {
+ return false;
+ }
+ PyObject* result = PyObject_CallMethodObjArgs(
+ eager_tensor_profiler, created_method_name, created_tensor, NULL);
+ if (result == nullptr) {
+ LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
+ // While we can potentially continue because the error is related to
+ // profiling, we choose to return an error because:
+ // - If profiling is used, the user likely wants to stop execution on
+ // profiling errors.
+ // - Error in profiling code might have left some state in an invalid
+ // form that can lead to an error later on. Better to fail fast.
+ Py_DECREF(created_method_name);
+ return false;
+ }
+ Py_DECREF(created_method_name);
+ Py_DECREF(result);
+ }
+ return true;
+}
+
+} // namespace
+
// tp_init for EagerTensor.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
self->id = get_uid();
@@ -266,6 +310,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
+ self->weakreflist = nullptr;
PyObject* value;
PyObject* context = nullptr;
PyObject* device = nullptr;
@@ -344,11 +389,22 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
if (handle == nullptr) return -1;
}
self->handle = handle.release();
+
+ if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
+ return -1;
+ }
+
return 0;
}
// tp_dealloc for EagerTensor.
void EagerTensor_dealloc(EagerTensor* self) {
+ // Clear weak references to self.
+ // Needs to happen before any actual destruction.
+ if (self->weakreflist != nullptr) {
+ PyObject_ClearWeakRefs((PyObject*)self);
+ }
+
TF_DeleteStatus(self->status);
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
@@ -574,43 +630,43 @@ static PyTypeObject _EagerTensorType = {
// clang-format off
PyVarObject_HEAD_INIT(nullptr, 0)
// clang-format on
- "EagerTensor", /* tp_name */
- sizeof(EagerTensor), /* tp_basicsize */
- 0, /* tp_itemsize */
- (destructor)EagerTensor_dealloc, /* tp_dealloc */
- nullptr, /* tp_print */
- nullptr, /* tp_getattr */
- nullptr, /* tp_setattr */
- nullptr, /* tp_compare */
- nullptr, /* tp_repr */
- nullptr, /* tp_as_number */
- nullptr, /* tp_as_sequence */
- nullptr, /* tp_as_mapping */
- nullptr, /* tp_hash */
- nullptr, /* tp_call */
- nullptr, /* tp_str */
- nullptr, /* tp_getattro */
- nullptr, /* tp_setattro */
- nullptr, /* tp_as_buffer */
- Py_TPFLAGS_DEFAULT, /* tp_flags */
- nullptr, /* tp_doc */
- nullptr, /* tp_traverse */
- nullptr, /* tp_clear */
- nullptr, /* tp_richcompare */
- 0, /* tp_weaklistoffset */
- nullptr, /* tp_iter */
- nullptr, /* tp_iternext */
- EagerTensor_methods, /* tp_methods */
- nullptr, /* tp_members */
- EagerTensor_getseters, /* tp_getset */
- nullptr, /* tp_base */
- nullptr, /* tp_dict */
- nullptr, /* tp_descr_get */
- nullptr, /* tp_descr_set */
- 0, /* tp_dictoffset */
- (initproc)EagerTensor_init, /* tp_init */
- nullptr, /* tp_alloc */
- nullptr, /* tp_new */
+ "EagerTensor", /* tp_name */
+ sizeof(EagerTensor), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ (destructor)EagerTensor_dealloc, /* tp_dealloc */
+ nullptr, /* tp_print */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_compare */
+ nullptr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ nullptr, /* tp_doc */
+ nullptr, /* tp_traverse */
+ nullptr, /* tp_clear */
+ nullptr, /* tp_richcompare */
+ offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
+ nullptr, /* tp_iter */
+ nullptr, /* tp_iternext */
+ EagerTensor_methods, /* tp_methods */
+ nullptr, /* tp_members */
+ EagerTensor_getseters, /* tp_getset */
+ nullptr, /* tp_base */
+ nullptr, /* tp_dict */
+ nullptr, /* tp_descr_get */
+ nullptr, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ (initproc)EagerTensor_init, /* tp_init */
+ nullptr, /* tp_alloc */
+ nullptr, /* tp_new */
};
#endif
@@ -641,6 +697,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
t->tensor_shape = Py_None;
t->handle = handle;
t->status = TF_NewStatus();
+ t->weakreflist = nullptr;
+
+ if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
+ return nullptr;
+ }
}
return reinterpret_cast<PyObject*>(t);
}
@@ -720,6 +781,18 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
return reinterpret_cast<PyObject*>(EagerTensorType);
}
+PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
+ Py_XDECREF(eager_tensor_profiler);
+
+ if (profiler == Py_None) {
+ eager_tensor_profiler = nullptr;
+ } else {
+ eager_tensor_profiler = profiler;
+ Py_INCREF(eager_tensor_profiler);
+ }
+ Py_RETURN_NONE;
+}
+
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
PyErr_SetString(PyExc_TypeError,
@@ -792,3 +865,37 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
return EagerTensorFromHandle(handle);
}
+
+PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
+ if (!EagerTensor_CheckExact(tensor)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ tensorflow::strings::StrCat("Expected an EagerTensor but got type \"",
+ Py_TYPE(tensor)->tp_name, "\"")
+ .c_str());
+ return nullptr;
+ }
+ TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
+
+ auto status = tensorflow::make_safe(TF_NewStatus());
+ TFE_TensorDebugInfo* debug_info =
+ TFE_TensorHandleTensorDebugInfo(handle, status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
+ TF_Message(status.get()))
+ .c_str());
+ return nullptr;
+ }
+
+ int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
+ PyObject* shape = PyTuple_New(rank);
+ for (int i = 0; i < rank; ++i) {
+ tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
+ PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
+ }
+ TFE_DeleteTensorDebugInfo(debug_info);
+
+ return shape;
+}
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 73fe80e8ca..c502fe9e85 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -115,6 +115,15 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
+// Sets `profiler` as the current profiler to receive callbacks about events
+// on eager tensors. Currently, the only reported event is creation.
+// `profiler` is expected to have a `created(self, eager_tensor)` method that
+// takes the created tensor as its single argument.
+// The previous profiler, if any, is unset and will not receive any more
+// callbacks.
+// To unset the profiler, pass Py_None as the value of `profiler`.
+PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
+
// Creates a new tape and adds it to the active set. `persistent` must be a
// PyBool_Type, i.e either Py_True or Py_False
PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
@@ -203,4 +212,8 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
// tensors in `tensors`.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
+// Returns the shape of this tensor's on-device representation.
+// The shape is represented as a Python tuple of integers.
+PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 42c708b024..500dc30cc3 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -33,6 +33,7 @@ limitations under the License.
%rename("%s") TFE_ContextAsyncClearError;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
+%rename("%s") TFE_Py_SetEagerTensorProfiler;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
@@ -60,6 +61,7 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
+%rename("%s") TFE_Py_TensorShapeOnDevice;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"