author     Mark Heffernan <meheff@google.com>  2017-08-10 16:58:07 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-08-10 17:01:41 -0700
commit     bb6f32fa7dad7fae416715752b10834d4a2b271a (patch)
tree       cc12b375d020b9302207713cc02fbc6d8c2e218f
parent     9103096c12faa1fdbdf806c2422c7d84fc2d0642 (diff)
Make HloAliasAnalysis updatable after changes to the HLO graph.
As part of this change, make HloAliasAnalysis a thinner layer which basically only holds a map from HloValue to HloBuffer and vice versa.

PiperOrigin-RevId: 164923041
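The "thinner layer" described above is essentially a bidirectional mapping between values and the buffers they occupy. Below is a minimal C++ sketch of that shape; it uses hypothetical integer ids in place of the real HloValue and HloBuffer classes and illustrates only the data structure, not the actual XLA implementation.

// Illustrative sketch only: plain ids stand in for XLA's HloValue/HloBuffer.
#include <unordered_map>
#include <vector>

using HloValueId = int;   // hypothetical stand-in for an HloValue
using HloBufferId = int;  // hypothetical stand-in for an HloBuffer

class ValueBufferMap {
 public:
  // Record that `value` lives in `buffer`, maintaining both directions.
  void Assign(HloValueId value, HloBufferId buffer) {
    value_to_buffer_[value] = buffer;
    buffer_to_values_[buffer].push_back(value);
  }

  // The buffer holding a given value.
  HloBufferId BufferFor(HloValueId value) const {
    return value_to_buffer_.at(value);
  }

  // All values that may share a given buffer (i.e., that alias one another).
  const std::vector<HloValueId>& ValuesIn(HloBufferId buffer) const {
    return buffer_to_values_.at(buffer);
  }

 private:
  std::unordered_map<HloValueId, HloBufferId> value_to_buffer_;
  std::unordered_map<HloBufferId, std::vector<HloValueId>> buffer_to_values_;
};

A structure this thin is straightforward to patch in place after edits to the HLO graph, which is presumably the updatability the commit title refers to.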
-rw-r--r--  tensorflow/c/eager/BUILD  67
-rw-r--r--  tensorflow/c/eager/c_api.cc  547
-rw-r--r--  tensorflow/c/eager/c_api.h  159
-rw-r--r--  tensorflow/c/eager/c_api_test.cc  471
-rw-r--r--  tensorflow/c/eager/runtime.cc  289
-rw-r--r--  tensorflow/c/eager/runtime.h  193
-rw-r--r--  tensorflow/c/eager/runtime_test.cc  160
-rw-r--r--  tensorflow/core/platform/platform.h  2
-rw-r--r--  tensorflow/python/eager/BUILD  254
-rw-r--r--  tensorflow/python/eager/context.py  333
-rw-r--r--  tensorflow/python/eager/core.py  75
-rw-r--r--  tensorflow/python/eager/core_test.py  483
-rw-r--r--  tensorflow/python/eager/custom_gradient.py  70
-rw-r--r--  tensorflow/python/eager/execute.py  241
-rw-r--r--  tensorflow/python/eager/function.py  502
-rw-r--r--  tensorflow/python/eager/gen_op.bzl  45
-rw-r--r--  tensorflow/python/eager/graph_only_ops.py  50
-rw-r--r--  tensorflow/python/eager/memory_trace.py  88
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen.cc  763
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen.h  39
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen_main.cc  46
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h  67
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc  377
-rw-r--r--  tensorflow/python/eager/tape.py  240
-rw-r--r--  tensorflow/python/eager/tensor.py  454
-rw-r--r--  tensorflow/python/eager/test.py  28
26 files changed, 6042 insertions(+), 1 deletion(-)
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
new file mode 100644
index 0000000000..18e1a68a93
--- /dev/null
+++ b/tensorflow/c/eager/BUILD
@@ -0,0 +1,67 @@
+# Experimental extensions to the C API for eager execution of kernels.
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api",
+ srcs = ["c_api.cc"],
+ hdrs = ["c_api.h"],
+ visibility = [
+ "//tensorflow:internal",
+ "//tensorflow/python/eager:__pkg__",
+ ],
+ deps = [
+ ":runtime",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_internal",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "c_api_test",
+ srcs = ["c_api_test.cc"],
+ deps = [
+ ":c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "runtime",
+ srcs = ["runtime.cc"],
+ hdrs = ["runtime.h"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_test(
+ name = "runtime_test",
+ srcs = ["runtime_test.cc"],
+ deps = [
+ ":runtime",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
new file mode 100644
index 0000000000..44e8c9f8e6
--- /dev/null
+++ b/tensorflow/c/eager/c_api.cc
@@ -0,0 +1,547 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/runtime.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"
+
+using tensorflow::int64;
+using tensorflow::string;
+
+namespace {
+bool IsCPU(tensorflow::Device* d) {
+ return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
+}
+
+string DeviceName(tensorflow::Device* d) {
+ return (d == nullptr) ? "cpu:0" : d->name();
+}
+} // namespace
+
+struct TFE_Context {
+ explicit TFE_Context(TF_Session* s) : session(s) {}
+
+ // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
+ TF_Session* session;
+
+ tensorflow::mutex functions_mu;
+ tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
+ tensorflow::OpRegistry::Global(), {}};
+
+ // One FunctionLibraryRuntime per device.
+ // func_libs[i] is the FunctionLibraryRuntime corresponding to
+ // session->devices[i].
+ std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;
+
+ std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
+ tensorflow::Fprint128Hasher>
+ kernel_cache;
+
+ tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
+ for (int i = 0; i < session->devices.size(); ++i) {
+ if (session->devices[i] == d) {
+ return func_libs[i].get();
+ }
+ }
+ return nullptr;
+ }
+
+ const std::vector<tensorflow::Device*>& devices() { return session->devices; }
+};
+
+struct TFE_TensorHandle {
+ TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
+ : t(t), d(d) {}
+
+ tensorflow::Tensor t;
+ // TODO(ashankar): d == nullptr iff local CPU
+ // This was expedient, but perhaps worth revisiting ('d' should always be a
+ // valid pointer?)
+ // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
+ // provided with the appropriate TFE_Context.
+ //
+ // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
+ // TFE_TensorHandle does not outlive the TFE_Context from which it came?
+ tensorflow::Device* d;
+};
+
+struct TFE_Op {
+ TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
+ : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
+
+ bool const is_function() const { return attr_types == nullptr; }
+
+ TFE_Context* ctx; // Must outlive the TFE_Op.
+ const char* name;
+ tensorflow::AttrBuilder attrs;
+ const tensorflow::AttrTypeMap* attr_types;
+ std::vector<tensorflow::Tensor> inputs;
+ std::vector<tensorflow::Device*> input_devices;
+ tensorflow::Device* device;
+};
+
+extern "C" {
+
+TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+ TF_Graph* graph = TF_NewGraph();
+ TF_Session* session = TF_NewSession(graph, opts, status);
+ if (status->status.ok()) {
+ if (session->device_mgr == nullptr || session->devices.empty()) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Provided TF_SessionOptions are not compatible with eager execution "
+ "(perhaps the TF_SessionOptions alluded to session execution in a "
+ "remote address space?)");
+ }
+ }
+ if (!status->status.ok()) {
+ TF_DeleteGraph(graph);
+ return nullptr;
+ }
+
+ TFE_Context* ret = new TFE_Context(session);
+ ret->func_libs.resize(ret->devices().size());
+ for (int i = 0; i < ret->devices().size(); ++i) {
+ ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
+ ret->session->device_mgr, opts->options.env, ret->devices()[i],
+ TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
+ }
+
+ return ret;
+}
+
+void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
+ status->status = tensorflow::Status::OK();
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+ TF_Graph* graph = ctx->session->graph;
+ TF_DeleteSession(ctx->session, status);
+ TF_DeleteGraph(graph);
+ delete ctx;
+}
+
+TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
+ return TF_SessionListDevices(ctx->session, status);
+}
+
+TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
+ return new TFE_TensorHandle(
+ tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
+ nullptr);
+}
+
+void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
+
+TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
+ return static_cast<TF_DataType>(h->t.dtype());
+}
+
+int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
+
+int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
+ return h->t.dim_size(dim_index);
+}
+
+const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
+ // This might be a bit confusing as a tensor on CPU can sometimes return
+ // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0".
+ // TODO(ashankar): Figure out which one would be nicer.
+ return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str();
+}
+
+TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
+ if (!IsCPU(h->d)) {
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ tensorflow::strings::StrCat(
+ "TFE_TensorHandle can be resolved iff it is on CPU (this "
+ "handle is on ",
+ h->d->name(),
+ "). Consider using TFE_TensorHandleCopyToDevice to get a "
+ "copy of the tensor on CPU")
+ .c_str());
+ return nullptr;
+ }
+ return tensorflow::TF_TensorFromTensor(h->t, status);
+}
+
+TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ const char* device_name,
+ TF_Status* status) {
+ tensorflow::Device* dstd = nullptr;
+ status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
+ if (!status->status.ok()) return nullptr;
+
+ tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
+ bool is_same_device =
+ (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
+ const bool dst_cpu = IsCPU(dstd);
+ if (is_same_device) {
+ return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
+ }
+ const bool src_cpu = IsCPU(srcd);
+ if (src_cpu == dst_cpu) {
+ TF_SetStatus(
+ status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "TFE_TensorHandleCopyToDevice requires either the source "
+ "TFE_TensorHandle be on or the destination device be on CPU "
+ "or be the same (they are ",
+ DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
+ .c_str());
+ return nullptr;
+ }
+ tensorflow::Tensor* src = &(h->t);
+ if (src_cpu) {
+ tensorflow::Tensor dst(
+ dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
+ src->shape());
+ tensorflow::Notification n;
+ dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
+ src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
+ status->status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
+ : nullptr;
+ }
+ CHECK(dst_cpu);
+ tensorflow::Tensor dst(src->dtype(), src->shape());
+ tensorflow::Notification n;
+ // TODO(ashankar): The Sync() call below may be more aggressive than
+ // necessary. It is based on knowledge of implementation details - that
+ // GPU devices are implemented using 3 streams - one for host->device copies,
+ // one for device->host copies and one for sending operations to the GPU.
+ // With that setup, Sync()ing across all 3 streams should be sufficient
+ // but more than necessary (since it waits for operations that might have
+ // nothing to do with this tensor to complete).
+ status->status = srcd->Sync();
+ if (!status->status.ok()) return nullptr;
+ srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
+ src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
+ [status, &n](const tensorflow::Status& s) {
+ status->status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
+ : nullptr;
+}
+
+TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
+ TF_Status* status) {
+ const char* name = op_or_function_name; // Shorthand
+ const tensorflow::AttrTypeMap* types;
+ status->status = tensorflow::AttrTypeMapForOp(name, &types);
+ if (status->status.ok()) return new TFE_Op(ctx, name, types);
+ if (TF_GetCode(status) == TF_NOT_FOUND) {
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ if (ctx->func_lib_def.Find(name) != nullptr) {
+ status->status = tensorflow::Status::OK();
+ return new TFE_Op(ctx, name, nullptr);
+ }
+ }
+ return nullptr;
+}
+
+void TFE_DeleteOp(TFE_Op* op) { delete op; }
+
+static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
+ TF_Status* status) {
+ // Questionable heuristic: Place the op on the same device as the first input
+ // placed outside of host memory?
+ if (IsCPU(op->device) && !IsCPU(device)) {
+ op->device = device;
+ }
+}
+
+void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name,
+ TF_Status* status) {
+ tensorflow::Device* d = nullptr;
+ status->status = ctx->session->device_mgr->LookupDevice(device_name, &d);
+ if (!status->status.ok()) return;
+ TFE_OpSetDeviceHelper(op, d, status);
+}
+
+void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
+ TFE_OpSetDeviceHelper(op, h->d, status);
+ if (!status->status.ok()) return;
+ op->inputs.push_back(h->t);
+ op->input_devices.push_back(h->d);
+ op->attrs.NumInputs(op->inputs.size());
+}
+
+TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
+ unsigned char* is_list, TF_Status* status) {
+ TF_AttrType ret;
+ if (op->is_function()) {
+ status->status = tensorflow::errors::Unimplemented(
+ "TODO(apassos): Support for attributes for TensorFlow functions is not "
+ "ready yet.");
+ return TF_ATTR_INT; // The compiler requires that we return something.
+ }
+ status->status =
+ tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
+ return ret;
+}
+
+void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
+ op->attrs.Set(attr_name, value);
+}
+
+void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
+ op->attrs.Set(attr_name, static_cast<int64>(value));
+}
+
+void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
+ op->attrs.Set(attr_name, value);
+}
+
+void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
+ op->attrs.Set(attr_name, (value == 0) ? false : true);
+}
+
+void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
+ op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
+}
+
+void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
+ const int num_dims, TF_Status* out_status) {
+ if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "Value specified for `", attr_name, "` has ", num_dims,
+ " dimensions which is over the limit of ",
+ tensorflow::TensorShape::MaxDimensions(), ".")
+ .c_str());
+ return;
+ }
+ tensorflow::TensorShapeProto proto;
+ if (num_dims < 0) {
+ proto.set_unknown_rank(true);
+ } else {
+ for (int d = 0; d < num_dims; ++d) {
+ proto.add_dim()->set_size(dims[d]);
+ }
+ }
+ op->attrs.Set(attr_name, proto);
+}
+
+#define TFE_OP_SET_ATTR_LIST(fn, type) \
+ void fn(TFE_Op* op, const char* attr_name, const type* values, \
+ int num_values) { \
+ op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
+ values, num_values)); \
+ }
+TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
+TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
+#undef TFE_OP_SET_ATTR_LIST
+
+void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
+ const int64_t* values, int num_values) {
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<const int64>(
+ reinterpret_cast<const int64*>(values), num_values));
+}
+
+void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
+ const TF_DataType* values, int num_values) {
+ op->attrs.Set(
+ attr_name,
+ tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
+ reinterpret_cast<const tensorflow::DataType*>(values), num_values));
+}
+
+void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
+ const unsigned char* values, int num_values) {
+ std::unique_ptr<bool[]> b(new bool[num_values]);
+ for (int i = 0; i < num_values; ++i) {
+ b[i] = values[i];
+ }
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
+}
+
+void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
+ const int64_t** dims, const int* num_dims,
+ int num_values, TF_Status* out_status) {
+ std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
+ new tensorflow::TensorShapeProto[num_values]);
+ for (int i = 0; i < num_values; ++i) {
+ const auto num_dims_i = num_dims[i];
+
+ if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "Value specified for `", attr_name, "` has ", num_dims_i,
+ " dimensions which is over the limit of ",
+ tensorflow::TensorShape::MaxDimensions(), ".")
+ .c_str());
+ return;
+ }
+ if (num_dims_i < 0) {
+ proto[i].set_unknown_rank(true);
+ } else {
+ const int64_t* dims_i = dims[i];
+ auto proto_i = &proto[i];
+ for (int d = 0; d < num_dims_i; ++d) {
+ proto_i->add_dim()->set_size(dims_i[d]);
+ }
+ }
+ }
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
+ proto.get(), num_values));
+}
+
+namespace {
+
+tensorflow::Status ValidateInputTypeAndPlacement(
+ tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
+ const tensorflow::OpKernel* kernel) {
+ const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
+ if (memtypes.size() != op->inputs.size()) {
+ return tensorflow::errors::InvalidArgument(
+ "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
+ }
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ const tensorflow::Device* expected_device =
+ memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
+ const tensorflow::Device* actual_device =
+ op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
+ if (expected_device != actual_device) {
+ return tensorflow::errors::InvalidArgument(
+ "cannot compute ", op->name, " as input #", i,
+ " was expected to be on ", expected_device->name(),
+ " but is actually on ", actual_device->name(),
+ " (operation running on ", op_device->name(), ")");
+ }
+ if (op->inputs[i].dtype() != kernel->input_type(i)) {
+ return tensorflow::errors::InvalidArgument(
+ "cannot compute ", op->name, " as input #", i,
+ " was expected to be a ",
+ tensorflow::DataType_Name(kernel->input_type(i)), " tensor but is a ",
+ tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+} // namespace
+
+void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+ TF_Status* status) {
+ TFE_Context* ctx = op->ctx;
+ // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
+ tensorflow::Device* device =
+ (op->device == nullptr) ? ctx->devices()[0] : op->device;
+ std::vector<tensorflow::Tensor> outputs(1);
+ const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
+ tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
+ tensorflow::KernelAndDevice* kernel =
+ tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+ if (kernel == nullptr) {
+ const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ kernel = new tensorflow::KernelAndDevice();
+ if (!op->is_function()) {
+ status->status =
+ tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
+ } else {
+ // Knowledge of the implementation of InitFn (and, in turn,
+ // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
+ // will be accessed, so grab on to the lock.
+ // See WARNING comment below - would be nice to rework to avoid this
+ // subtlety.
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = tensorflow::KernelAndDevice::InitFn(
+ ndef, ctx->func_lib(device), kernel);
+ }
+ if (!status->status.ok()) {
+ return;
+ }
+ tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
+ }
+ status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
+ kernel->kernel());
+ output_memory_types = &kernel->kernel()->output_memory_types();
+ if (!status->status.ok()) {
+ return;
+ }
+ // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
+ // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
+ // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
+ // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
+ // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
+ // This is quite subtle. Re-work things to make this better? (Would it make
+ // sense for FunctionLibraryRuntime to ensure thread-safe access to
+ // FunctionLibraryDefinition?).
+ status->status = kernel->Run(&op->inputs, &outputs);
+ if (!status->status.ok()) return;
+ *num_retvals = std::min<int>(*num_retvals, outputs.size());
+ for (int i = 0; i < *num_retvals; ++i) {
+ tensorflow::Device* d = IsCPU(device) ? nullptr : device;
+ if (d != nullptr && output_memory_types != nullptr &&
+ (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
+ d = nullptr;
+ }
+ retvals[i] = new TFE_TensorHandle(outputs[i], d);
+ }
+}
+
+void TFE_ContextAddFunctionDef(TFE_Context* ctx,
+ const char* serialized_function_def, size_t size,
+ TF_Status* status) {
+ tensorflow::FunctionDef function_def;
+ if (!function_def.ParseFromArray(serialized_function_def, size)) {
+ status->status =
+ tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
+ return;
+ }
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(function_def);
+}
+
+} // extern "C"
+
+TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
+ return new TFE_TensorHandle(t, nullptr);
+}
+
+const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
+ TFE_TensorHandle* h, TF_Status* status) {
+ if (h->d != nullptr) {
+ status->status = tensorflow::errors::FailedPrecondition(
+ "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
+ "a tensorflow::Tensor");
+ return nullptr;
+ }
+ return &h->t;
+}
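Before the header below, a brief illustration of the device-copy rules implemented in TFE_TensorHandleCopyToDevice above: a copy succeeds when source and destination are the same device, or when at least one side is CPU. This is a sketch modeled on the TensorHandleCopyBetweenDevices test later in this change; the device name "GPU:0" and the helper name are assumptions for illustration only.

// Round-trip a CPU-resident handle through a (hypothetical) "GPU:0" device.
TFE_TensorHandle* RoundTripThroughGPU(TFE_Context* ctx, TFE_TensorHandle* hcpu,
                                      TF_Status* status) {
  // Host -> device copy (allowed because the source is on CPU).
  TFE_TensorHandle* hgpu =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, "GPU:0", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Device -> host copy (allowed because the destination is CPU).
  TFE_TensorHandle* back =
      TFE_TensorHandleCopyToDevice(hgpu, ctx, "CPU:0", status);
  TFE_DeleteTensorHandle(hgpu);
  return (TF_GetCode(status) == TF_OK) ? back : nullptr;
}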
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
new file mode 100644
index 0000000000..476c9288f8
--- /dev/null
+++ b/tensorflow/c/eager/c_api.h
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_C_API_H_
+#define TENSORFLOW_C_EAGER_C_API_H_
+
+// C API extensions to experiment with eager execution of kernels.
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// "Context" under which operations/functions are executed. It encapsulates
+// things like the available devices, resource manager etc.
+//
+// TODO(ashankar): Merge with TF_Session?
+typedef struct TFE_Context TFE_Context;
+
+extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
+ TF_Status* status);
+extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
+extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
+ TF_Status* status);
+
+// A handle to a tensor on a device.
+//
+// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
+// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
+// placed in memory of different devices or remote address spaces.
+typedef struct TFE_TensorHandle TFE_TensorHandle;
+
+extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
+extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
+extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
+extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
+extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
+extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
+extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
+ TF_Status* status);
+
+// Create a new TFE_TensorHandle with the same contents as 'h' but placed
+// in the memory of the device name 'device_name'.
+// If source and destination are the same device, then this creates a new handle
+// that shares the underlying buffer. Otherwise, it currently requires at least
+// one of the source or destination devices to be CPU (i.e., for the source or
+// destination tensor to be placed in host memory).
+extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ const char* device_name,
+ TF_Status* status);
+
+// Description of the TensorFlow op to execute.
+//
+// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
+// TFE_DeleteOp() is called before TFE_DeleteContext().
+//
+// Very similar to TF_OperationDescription with some differences:
+// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput,
+// TF_AddInputList
+// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
+// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
+// the additional sanity checks there seem unnecessary;
+typedef struct TFE_Op TFE_Op;
+
+extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
+ TF_Status* status);
+extern void TFE_DeleteOp(TFE_Op* op);
+
+// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
+// parameter. Instead, the TFE_Context should be captured when creating the
+// TFE_Op.
+extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
+ const char* device_name, TF_Status* status);
+
+extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
+
+extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
+ unsigned char* is_list, TF_Status* status);
+
+extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
+ const char* value);
+extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
+extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
+extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
+ unsigned char value);
+extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
+ TF_DataType value);
+// If the number of dimensions is unknown, `num_dims` must be set to
+// -1 and `dims` can be null. If a dimension is unknown, the
+// corresponding entry in the `dims` array must be -1.
+extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
+ const int64_t* dims, const int num_dims,
+ TF_Status* out_status);
+
+extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
+ const char** value, int num_values);
+extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
+ const int64_t* values, int num_values);
+extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
+ const float* values, int num_values);
+extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
+ const unsigned char* values, int num_values);
+extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
+ const TF_DataType* values, int num_values);
+extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
+ const int64_t** dims, const int* num_dims,
+ int num_values, TF_Status* out_status);
+
+// Execute the operation defined by 'op' and return handles to computed
+// tensors in 'retvals'.
+//
+// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
+// and '*num_retvals' should be set to the size of this array.
+//
+// On return, 'num_retvals' will be set to the actual number of outputs
+// returned by the operation.
+extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
+ int* num_retvals, TF_Status* status);
+
+// Add a function (serialized FunctionDef protocol buffer) to ctx so
+// that it can be invoked using TFE_Execute.
+extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
+ const char* serialized_function_def,
+ size_t size, TF_Status* status);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#ifdef __cplusplus
+// A workaround to ease conversion to and from numpy objects and
+// TFE_TensorHandle's.
+//
+// TODO(ashankar): Figure out an alternative scheme that precludes the need for
+// these API-boundary breaking methods.
+namespace tensorflow {
+class Tensor;
+} // namespace tensorflow
+
+const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
+ TFE_TensorHandle* h, TF_Status* status);
+TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
+#endif
+
+#endif // TENSORFLOW_C_EAGER_C_API_H_
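As an orientation for the header above, a minimal end-to-end use of the API might look like the following. It is a sketch condensed from TEST(CAPI, Execute) in c_api_test.cc below; status checks after each call are omitted for brevity.

#include <string.h>

#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TF_DeleteSessionOptions(opts);

  // Wrap a 2x2 float matrix {1, 2, 3, 4} in a CPU-resident handle.
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(data));
  memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
  TFE_TensorHandle* m = TFE_NewTensorHandle(t);
  TF_DeleteTensor(t);

  // Describe the op, its inputs and its attributes, then execute m x m.
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrBool(matmul, "transpose_a", 0);
  TFE_OpSetAttrBool(matmul, "transpose_b", 0);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, retvals, &num_retvals, status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);

  // Bring the result back into host memory: {7, 10, 15, 22}.
  TF_Tensor* product = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  TF_DeleteTensor(product);

  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
  return 0;
}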
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
new file mode 100644
index 0000000000..6614df78d9
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -0,0 +1,471 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <string.h>
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using tensorflow::string;
+
+namespace {
+
+TFE_TensorHandle* TestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t);
+ TF_DeleteTensor(t);
+ return th;
+}
+
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, b, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrBool(op, "transpose_a", 0);
+ TFE_OpSetAttrBool(op, "transpose_b", 0);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_InitOp(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// for (auto _ : state) {
+// TFE_Op* matmul = MatMulOp(ctx, m, m);
+// TFE_DeleteOp(matmul);
+// }
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_InitOp);
+
+// void BM_Execute(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// TFE_Op* matmul = MatMulOp(ctx, m, m);
+// TFE_TensorHandle* retvals[1];
+// int num_retvals = 1;
+// for (auto _ : state) {
+// TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// }
+// TFE_DeleteOp(matmul);
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_Execute);
+
+TEST(CAPI, Context) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ TF_DeleteSessionOptions(opts);
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const int num_devices = TF_DeviceListCount(devices);
+ EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
+ for (int i = 0; i < num_devices; ++i) {
+ EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ }
+ TF_DeleteDeviceList(devices);
+ TF_DeleteStatus(status);
+}
+
+TEST(CAPI, TensorHandle) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle();
+ EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
+ ASSERT_EQ(16, TF_TensorByteSize(t));
+ float data[4] = {0};
+ memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+ EXPECT_EQ(1.0, data[0]);
+ EXPECT_EQ(2.0, data[1]);
+ EXPECT_EQ(3.0, data[2]);
+ EXPECT_EQ(4.0, data[3]);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(h);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevices) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TF_DeleteSessionOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ const int num_devices = TF_DeviceListCount(devices);
+
+ const char* kCPUDevice = "CPU:0";
+ for (int i = 0; i < num_devices; ++i) {
+ const string name(TF_DeviceListName(devices, i, status.get()));
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << i << " -- " << TF_Message(status.get());
+ continue;
+ }
+ auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
+ // Copy to device
+ TFE_TensorHandle* hdevice =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+ continue;
+ }
+ // Copy from device to the same device.
+ TFE_TensorHandle* hdevice2 =
+ TFE_TensorHandleCopyToDevice(hdevice, ctx, name.c_str(), status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+ continue;
+ }
+ TFE_DeleteTensorHandle(hdevice);
+ // Copy back to CPU
+ TFE_TensorHandle* hcopy =
+ TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+ continue;
+ }
+ TFE_DeleteTensorHandle(hdevice2);
+
+ // Ensure that the contents are the same!
+ TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
+ TFE_DeleteTensorHandle(hcopy);
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag;
+ continue;
+ }
+ EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
+ EXPECT_EQ(
+ 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
+ << tag;
+ TF_DeleteTensor(tcopy);
+ }
+
+ TF_DeleteDeviceList(devices);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_DeleteContext(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
+TEST(CAPI, Execute) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteSessionOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_Op* matmul = MatMulOp(ctx, m, m);
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(m);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+ TF_DeleteStatus(status);
+}
+
+string MatMulFunction() {
+ tensorflow::FunctionDef def;
+ CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+ " signature {"
+ " name: 'MatMulFunction'"
+ " input_arg {"
+ " name: 'a'"
+ " type: DT_FLOAT"
+ " }"
+ " output_arg {"
+ " name: 'm'"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " node_def {"
+ " name: 'matmul'"
+ " op: 'MatMul'"
+ " input: 'a'"
+ " input: 'a'"
+ " attr {"
+ " key: 'T'"
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " }"
+ " ret {"
+ " key: 'm'"
+ " value: 'matmul:product'"
+ " }",
+ &def));
+ return def.SerializeAsString();
+}
+
+TEST(CAPI, FunctionDefAndExecute) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteSessionOptions(opts);
+
+ string function_def = MatMulFunction();
+ TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+ status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_TensorHandle* retval[1] = {nullptr};
+ int num_retvals = 1;
+ TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, m, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Execute(op, &retval[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+ TFE_DeleteOp(op);
+ TFE_DeleteTensorHandle(m);
+ TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
+ TFE_DeleteTensorHandle(retval[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_ExecuteFunction(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// string function_def = MatMulFunction();
+// TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+// status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpAddInput(matmul, m, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_TensorHandle* retval[1] = {nullptr};
+// int num_retvals = 1;
+// for (auto _ : state) {
+// TFE_Execute(matmul, &retval[0], &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// }
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteTensorHandle(retval[0]);
+// TFE_DeleteContext(ctx, status);
+// EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_ExecuteFunction);
+
+// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
+// TF_Status* status) {
+// // Create the variable handle.
+// TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpSetAttrShape(op, "shape", {}, 0, status);
+// TFE_OpSetAttrString(op, "container", "");
+// TFE_OpSetAttrString(op, "shared_name", "");
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_TensorHandle* var_handle = nullptr;
+// int num_retvals = 1;
+// TFE_Execute(op, &var_handle, &num_retvals, status);
+// TFE_DeleteOp(op);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// CHECK_EQ(1, num_retvals);
+
+// // Assign 'value' to it.
+// op = TFE_NewOp(ctx, "AssignVariableOp", status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+
+// // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+// std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+// TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)),
+// TF_DeleteTensor);
+// memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+// value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
+
+// TFE_OpAddInput(op, value_handle.get(), status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+
+// num_retvals = 0;
+// TFE_Execute(op, nullptr, &num_retvals, status);
+// TFE_DeleteOp(op);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// CHECK_EQ(0, num_retvals);
+
+// return var_handle;
+// }
+
+// TEST(CAPI, Variables) {
+// // Variables use resource handles, so this is really a test for resource
+// // tensor handling.
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// int num_retvals = 1;
+// TFE_TensorHandle* value_handle = nullptr;
+// TFE_Execute(op, &value_handle, &num_retvals, status);
+// TFE_DeleteOp(op);
+
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// ASSERT_EQ(1, num_retvals);
+// EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
+// EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
+// float value = 0.0f;
+// TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
+// memcpy(&value, TF_TensorData(t), sizeof(float));
+// TF_DeleteTensor(t);
+// EXPECT_EQ(12.0, value);
+
+// TFE_DeleteTensorHandle(var_handle);
+// TFE_DeleteTensorHandle(value_handle);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+
+// void BM_ReadVariable(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// int num_retvals = 1;
+// TFE_TensorHandle* h = nullptr;
+// for (auto _ : state) {
+// TFE_Execute(op, &h, &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// CHECK_EQ(1, num_retvals);
+// CHECK(h);
+// CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+// CHECK_EQ(0, TFE_TensorHandleNumDims(h));
+// h = nullptr;
+// }
+// TFE_DeleteOp(op);
+
+// TFE_DeleteTensorHandle(var_handle);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_ReadVariable);
+
+} // namespace
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
new file mode 100644
index 0000000000..87e9f5377f
--- /dev/null
+++ b/tensorflow/c/eager/runtime.cc
@@ -0,0 +1,289 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/runtime.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+namespace {
+
+mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
+
+std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
+ static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
+ return m;
+}
+
+const uint32 kIsList = 1U << 31;
+
+} // namespace
+
+Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
+ mutex_lock l(g_op_name_to_attr_type_map_lock);
+ *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
+ if (*out != nullptr) return Status::OK();
+ const OpRegistrationData* op_reg_data = nullptr;
+ Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+ if (!s.ok()) return s;
+ std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
+ // TODO(agarwal): Avoid having to create this "registry" at runtime,
+ // perhaps can be done at op registration time?
+ for (const auto& attr : op_reg_data->op_def.attr()) {
+ string type = attr.type();
+ const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
+ if (is_list) {
+ type = type.substr(5, type.length() - 6);
+ }
+ uint32 t = is_list ? kIsList : 0;
+ if (type == "string") {
+ t |= TF_ATTR_STRING;
+ } else if (type == "int") {
+ t |= TF_ATTR_INT;
+ } else if (type == "float") {
+ t |= TF_ATTR_FLOAT;
+ } else if (type == "bool") {
+ t |= TF_ATTR_BOOL;
+ } else if (type == "type") {
+ t |= TF_ATTR_TYPE;
+ } else if (type == "shape") {
+ t |= TF_ATTR_SHAPE;
+ } else if (type == "tensor") {
+ t |= TF_ATTR_TENSOR;
+ } else {
+ return errors::Unimplemented(
+ "TODO(agarwal): Enable support for ops with attributes of type '",
+ type, "'");
+ }
+ gtl::InsertIfNotPresent(m.get(), attr.name(), t);
+ }
+ *out = m.get();
+ (*OpNameToAttrTypeMap())[op_name] = m.release();
+ return Status::OK();
+}
+
+Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+ TF_AttrType* out, unsigned char* is_list) {
+ CHECK(m);
+ auto* t = gtl::FindOrNull(*m, attr_name);
+ if (t == nullptr) {
+ return errors::InvalidArgument("Attribute '", attr_name,
+ "' does not exist for this operation");
+ }
+ *out = static_cast<TF_AttrType>(*t & ~kIsList);
+ if (*t & kIsList) {
+ *is_list = 1;
+ } else {
+ *is_list = 0;
+ }
+ return Status::OK();
+}
+
+#define DEFINE_SET_ATTR(value_type, value_field) \
+ template <> \
+ AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
+ value_field.push_back(std::make_pair(attr_name, value)); \
+ return *this; \
+ }
+
+DEFINE_SET_ATTR(StringPiece, string_attrs_);
+DEFINE_SET_ATTR(float, float_attrs_);
+DEFINE_SET_ATTR(int, int_attrs_);
+DEFINE_SET_ATTR(bool, bool_attrs_);
+DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
+
+#undef DEFINE_SET_ATTR
+
+AttrBuilder& AttrBuilder::NumInputs(int n) {
+ DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
+ num_inputs_ = n;
+ return *this;
+}
+
+const NodeDef& AttrBuilder::BuildNodeDef() {
+ if (node_def_finalized_) return *node_def_;
+ MayBeInitializeNodeDef();
+ for (int i = 0; i < num_inputs_; ++i) {
+ node_def_->add_input("dummy_input");
+ }
+ for (const auto& p : string_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : int_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : float_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : bool_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : type_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ node_def_finalized_ = true;
+ return *node_def_;
+}
+
+namespace {
+inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
+ const tensorflow::Fprint128& b) {
+ return {tensorflow::FingerprintCat64(a.low64, b.low64),
+ tensorflow::FingerprintCat64(a.high64, b.high64)};
+}
+
+void CombineUnordered(const tensorflow::Fprint128& a,
+ tensorflow::Fprint128* b) {
+ b->low64 += a.low64;
+ b->high64 += a.high64;
+}
+
+inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
+ const tensorflow::Fprint128& b) {
+ // TODO(agarwal): avoid ToString().
+ tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
+ return FingerprintCat128(a, b);
+}
+
+inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
+ return CacheKeyHelper(s, {b, b});
+}
+
+} // namespace
+
+tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
+ tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
+ f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
+ if (node_def_ != nullptr) {
+ // Some attributes are directly written to node_def_ instead of being
+ // stored explicitly.
+ string value;
+ for (const auto& attr : node_def_->attr()) {
+ attr.second.SerializeToString(&value);
+ CombineUnordered(
+ CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
+ }
+ // Note that node_def_ may be created but not finalized. This can happen
+ // when the creation was triggered by a call to Set, but BuildNodeDef has
+ // not been called.
+ if (node_def_finalized_) return f;
+ }
+ for (const auto& p : string_attrs_) {
+ // TODO(agarwal): avoid ToString().
+ CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
+ p.second.ToString())),
+ &f);
+ }
+ for (const auto& p : int_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
+ &f);
+ }
+ static std::hash<float> float_hasher;
+ for (const auto& p : float_attrs_) {
+ CombineUnordered(
+ CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
+ &f);
+ }
+ for (const auto& p : bool_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
+ }
+ for (const auto& p : type_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
+ &f);
+ }
+ return f;
+}
+
+void AttrBuilder::MayBeInitializeNodeDef() {
+ if (node_def_ == nullptr) {
+ node_def_.reset(new NodeDef());
+ node_def_->set_name(op_name_);
+ node_def_->set_op(op_name_);
+ }
+}
+
+// static
+Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
+ KernelAndDevice* out) {
+ OpKernel* k = nullptr;
+ Status s = CreateOpKernel(device->device_type().c_str(), device,
+ device->GetAllocator(AllocatorAttributes()),
+ nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
+ out->device_ = device;
+ out->kernel_.reset(k);
+ out->flib_ = nullptr;
+ return s;
+}
+
+// static
+Status KernelAndDevice::InitFn(const NodeDef& ndef,
+ FunctionLibraryRuntime* flib,
+ KernelAndDevice* out) {
+ OpKernel* k = nullptr;
+ Status s = flib->CreateKernel(ndef, &k);
+ out->device_ = flib->device();
+ out->kernel_.reset(k);
+ out->flib_ = flib;
+ return s;
+}
+
+Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
+ std::vector<Tensor>* output_tensors) {
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ for (Tensor& t : *input_tensors) {
+ inputs.push_back(TensorValue(&t));
+ }
+
+ std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
+ for (size_t i = 0; i < out_attrs.size(); ++i) {
+ out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
+ tensorflow::HOST_MEMORY);
+ }
+
+ OpKernelContext::Params params;
+ params.device = device_;
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = kernel_.get();
+ params.resource_manager = device_->resource_manager();
+ params.output_attr_array = gtl::vector_as_array(&out_attrs);
+ params.function_library = flib_;
+ params.slice_reader_cache = &slice_reader_cache_;
+ // TODO(apassos): use a thread pool.
+ std::function<void(std::function<void()>)> runner =
+ [](std::function<void()> f) { f(); };
+ params.runner = &runner;
+
+ OpKernelContext context(&params);
+ device_->Compute(kernel_.get(), &context);
+ if (!context.status().ok()) return context.status();
+
+ output_tensors->clear();
+ for (int i = 0; i < context.num_outputs(); ++i) {
+ output_tensors->push_back(Tensor(*context.mutable_output(i)));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
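A small sketch of the attribute-type lookup path defined above (and declared in runtime.h below), which is what TFE_OpGetAttrType in c_api.cc goes through: the per-op AttrTypeMap is built once and cached, and AttrTypeByName strips the kIsList high bit into the separate is_list flag. The function name here is a hypothetical example; error handling is abbreviated.

#include "tensorflow/c/eager/runtime.h"

void QueryMatMulAttrTypes() {
  const tensorflow::AttrTypeMap* types = nullptr;
  tensorflow::Status s = tensorflow::AttrTypeMapForOp("MatMul", &types);
  if (!s.ok()) return;

  TF_AttrType type;
  unsigned char is_list = 0;
  // "transpose_a" is a scalar bool attr, so on success type == TF_ATTR_BOOL
  // and is_list == 0 (the kIsList bit has already been masked off).
  s = tensorflow::AttrTypeByName(types, "transpose_a", &type, &is_list);
}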
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
new file mode 100644
index 0000000000..5916302ff4
--- /dev/null
+++ b/tensorflow/c/eager/runtime.h
@@ -0,0 +1,193 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
+#define TENSORFLOW_C_EAGER_RUNTIME_H_
+
+// Support for eager execution of TensorFlow kernels.
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+
+// Maps attribute name to an encoding of the type of the attribute value.
+// If the type is not a list type, the value is the same as the TF_AttrType type
+// of the value. Else, the highest order bit is on, and the rest of the bits
+// represent the TF_AttrType type of the values in the list.
+typedef std::unordered_map<string, uint32> AttrTypeMap;
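+
+// For illustration only: a hedged sketch of decoding one entry of this map,
+// assuming the bit layout described above ("attr_types" and "some_attr" are
+// hypothetical names; this is not part of the API):
+//
+//   uint32 encoded = attr_types->at("some_attr");
+//   bool is_list = (encoded & (1u << 31)) != 0;   // high bit set => list
+//   TF_AttrType type = static_cast<TF_AttrType>(encoded & ~(1u << 31));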
+
+// Returns the AttrTypeMap for the TensorFlow operation named op_name.
+Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
+
+// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
+Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+ TF_AttrType* out, unsigned char* is_list);
+
+// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
+// An AttrBuilder is a convenience class to help with that - providing a smaller
+// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
+// checks (like number of inputs matching the OpDef - we only care about
+// attributes here).
+//
+// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
+// ones make sense to replicate.
+
+// This is a helper class for creating a NodeDef. Additionally, this class
+// allows computing a cache key based on fingerprinting the attributes of this
+// NodeDef.
+//
+// Example usage:
+// AttrBuilder a;
+// a.NumInputs(2);
+// a.Set("T", TF_FLOAT);
+// uint64 cache_key = a.CacheKey("cpu:0");
+// const NodeDef& n = a.BuildNodeDef();
+//
+// Note that all calls to Set and NumInputs should happen before calling
+// BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
+// to CacheKey may cause different values to be returned by CacheKey.
+//
+// For performance reasons, the class internally delays the actual construction
+// of the NodeDef till BuildNodeDef is called, or Set is called with certain
+// uncommon types (see template specializations of Set to see which types
+// trigger a NodeDef creation).
+class AttrBuilder {
+ public:
+ explicit AttrBuilder(const char* op)
+ : op_name_(op),
+ num_inputs_(0),
+ node_def_(nullptr),
+ node_def_finalized_(false) {}
+
+ // Needed to work around call to ValidateNodeDef in CreateOpKernel.
+ AttrBuilder& NumInputs(int n);
+
+ template <class T>
+ AttrBuilder& Set(StringPiece attr_name, T&& value) {
+ MayBeInitializeNodeDef();
+ return SetInNodeDef(attr_name, value);
+ }
+
+ tensorflow::Fprint128 CacheKey(const string& device) const;
+
+ const NodeDef& BuildNodeDef();
+
+ private:
+ template <class T>
+ using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
+
+ void MayBeInitializeNodeDef();
+
+ template <class T>
+ AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) {
+ DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef.";
+ // Copied from NodeDefBuilder::Attr
+ const AttrValue* found = AttrSlice(*node_def_).Find(attr_name);
+ if (found == nullptr) {
+ AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get());
+ } else {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ // TODO(ashankar): Do what is done in
+ // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
+ }
+ return *this;
+ }
+
+ AttrVec<StringPiece> string_attrs_;
+ AttrVec<int> int_attrs_;
+ AttrVec<float> float_attrs_;
+ AttrVec<bool> bool_attrs_;
+ AttrVec<tensorflow::DataType> type_attrs_;
+ string op_name_;
+ int num_inputs_;
+ std::unique_ptr<NodeDef> node_def_;
+ bool node_def_finalized_;
+};
+
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
+ tensorflow::DataType&& value);
+
+// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
+//
+// Also see:
+// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+// and
+// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
+class KernelAndDevice {
+ public:
+ // Populates 'out' with a kernel appropriate for 'ndef'.
+ //
+ // Assumes that 'ndef' refers to a primitive op (as opposed to a function).
+ static Status InitOp(Device* device, const NodeDef& ndef,
+ KernelAndDevice* out);
+
+ // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
+ // TensorFlow function in the FunctionLibraryRuntime).
+ //
+ // The provided FunctionLibraryRuntime MUST outlive all calls to
+ // Run() on the returned KernelAndDevice.
+ //
+ // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
+ // The implementation of InitFn should work for both because
+ // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
+ // appropriate. However, for now we keep them separate because I haven't
+ // figured out thread-safety concerns around FunctionLibraryRuntime (in
+ // particular, how the underlying FunctionLibraryDefinition might be mutated
+ // by another thread as new functions are registered with it).
+ // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
+ // on to the caller (see locking in c_api.cc) for now. But I really should
+ // dig into this so that both InitOp and InitFn can be collapsed to
+ // FunctionLibraryRuntime::CreateKernel.
+ static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+ KernelAndDevice* out);
+
+ KernelAndDevice() : device_(nullptr), flib_(nullptr) {}
+
+ // TODO(ashankar): Handle list-valued inputs.
+ Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
+
+ const OpKernel* kernel() const { return kernel_.get(); }
+
+ private:
+ std::unique_ptr<OpKernel> kernel_;
+ tensorflow::Device* device_;
+ tensorflow::FunctionLibraryRuntime* flib_;
+ tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_C_EAGER_RUNTIME_H_
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
new file mode 100644
index 0000000000..3b38e24704
--- /dev/null
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/runtime.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Device* CPUDevice() {
+ return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+}
+
+TEST(AttrTypeMap, Lookup) {
+ const AttrTypeMap* m = nullptr;
+ Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m);
+ EXPECT_FALSE(s.ok());
+ s = AttrTypeMapForOp("MatMul", &m);
+ ASSERT_TRUE(s.ok()) << s;
+
+ TF_AttrType t;
+ unsigned char is_list = 1;
+ s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
+ EXPECT_FALSE(s.ok());
+ EXPECT_NE(is_list, 0);
+ s = AttrTypeByName(m, "transpose_a", &t, &is_list);
+ ASSERT_TRUE(s.ok()) << s;
+ EXPECT_EQ(TF_ATTR_BOOL, t);
+ EXPECT_EQ(is_list, 0);
+
+ s = AttrTypeMapForOp("Squeeze", &m);
+ ASSERT_TRUE(s.ok()) << s;
+ s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
+ ASSERT_TRUE(s.ok()) << s;
+ EXPECT_EQ(TF_ATTR_INT, t);
+ EXPECT_NE(is_list, 0);
+}
+
+TEST(KernelAndDevice, Run) {
+ Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
+ std::vector<Tensor> inputs;
+ inputs.push_back(t);
+ inputs.push_back(t);
+ NodeDef ndef(AttrBuilder("MatMul")
+ .Set("T", DT_FLOAT)
+ .Set("transpose_a", false)
+ .Set("transpose_b", false)
+ .NumInputs(inputs.size())
+ .BuildNodeDef());
+ std::unique_ptr<Device> device(CPUDevice());
+ KernelAndDevice kernel;
+ Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
+ ASSERT_TRUE(s.ok()) << s;
+ std::vector<Tensor> outputs;
+ s = kernel.Run(&inputs, &outputs);
+ ASSERT_TRUE(s.ok()) << s;
+ ASSERT_EQ(1, outputs.size());
+ const Tensor& out = outputs[0];
+ EXPECT_EQ(7, out.matrix<float>()(0, 0));
+ EXPECT_EQ(10, out.matrix<float>()(0, 1));
+ EXPECT_EQ(15, out.matrix<float>()(1, 0));
+ EXPECT_EQ(22, out.matrix<float>()(1, 1));
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_CreateGraph(benchmark::State& state) {
+// for (auto _ : state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// TF_CHECK_OK(root.status());
+// }
+// }
+// BENCHMARK(BM_CreateGraph);
+
+// void BM_RunGraph(benchmark::State& state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// SessionOptions opts;
+// opts.config.set_inter_op_parallelism_threads(1);
+// opts.config.set_intra_op_parallelism_threads(1);
+// ClientSession sess(root, opts);
+// std::vector<Tensor> outputs;
+// for (auto _ : state) {
+// outputs.clear();
+// TF_CHECK_OK(sess.Run({M}, &outputs));
+// }
+// }
+// BENCHMARK(BM_RunGraph);
+
+// void BM_CreateAndDestroySession(benchmark::State& state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// for (auto _ : state) {
+// ClientSession sess(root);
+// }
+// }
+// BENCHMARK(BM_CreateAndDestroySession);
+
+// void BM_KernelAndDeviceInit(benchmark::State& state) {
+// NodeDef ndef(AttrBuilder("MatMul")
+// .Set("T", DT_FLOAT)
+// .Set("transpose_a", false)
+// .Set("transpose_b", false)
+// .NumInputs(2)
+// .BuildNodeDef());
+// std::unique_ptr<Device> device(CPUDevice());
+// KernelAndDevice k;
+// for (auto _ : state) {
+// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
+// }
+// }
+// BENCHMARK(BM_KernelAndDeviceInit);
+
+// void BM_KernelAndDeviceRun(benchmark::State& state) {
+// Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
+// std::vector<Tensor> inputs;
+// inputs.push_back(t);
+// inputs.push_back(t);
+// std::vector<Tensor> outputs;
+// NodeDef ndef(AttrBuilder("MatMul")
+// .Set("T", DT_FLOAT)
+// .Set("transpose_a", false)
+// .Set("transpose_b", false)
+// .NumInputs(inputs.size())
+// .BuildNodeDef());
+// std::unique_ptr<Device> device(CPUDevice());
+// KernelAndDevice kernel;
+// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
+// for (auto _ : state) {
+// TF_CHECK_OK(kernel.Run(&inputs, &outputs));
+// }
+// }
+// BENCHMARK(BM_KernelAndDeviceRun);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h
index 0719070b97..12120c4ab9 100644
--- a/tensorflow/core/platform/platform.h
+++ b/tensorflow/core/platform/platform.h
@@ -49,7 +49,7 @@ limitations under the License.
#endif // !defined(RASPBERRY_PI)
#else
-// If no blargle platform specified, use:
+// If no platform specified, use:
#define PLATFORM_POSIX
#endif
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
new file mode 100644
index 0000000000..dd12db9f46
--- /dev/null
+++ b/tensorflow/python/eager/BUILD
@@ -0,0 +1,254 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+cc_library(
+ name = "pywrap_tfe_lib",
+ srcs = ["pywrap_tfe_src.cc"],
+ hdrs = ["pywrap_tfe.h"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/python:numpy_lib",
+ "//tensorflow/python:py_func_lib",
+ "//util/python:python_headers",
+ ],
+)
+
+py_library(
+ name = "core",
+ srcs = ["core.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":memory_trace",
+ ":tape",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:pywrap_tensorflow",
+ ],
+)
+
+py_library(
+ name = "tensor",
+ srcs = ["tensor.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//learning/brain/contrib/eager:__subpackages__"],
+ deps = [
+ ":context",
+ ":core",
+ ":tape",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:tensor_shape",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "context",
+ srcs = ["context.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//learning/brain/contrib/eager:__subpackages__"],
+ deps = [
+ "//tensorflow/python:device",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "tape",
+ srcs = ["tape.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "memory_trace",
+ srcs = ["memory_trace.py"],
+ srcs_version = "PY2AND3",
+)
+
+cuda_py_test(
+ name = "core_test",
+ srcs = ["core_test.py"],
+ additional_deps = [
+ ":context",
+ ":core",
+ ":execute",
+ "//tensorflow/python:pywrap_tensorflow",
+ ":tensor",
+ ":test",
+ "//third_party/py/numpy",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_library(
+ name = "test",
+ srcs = ["test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "execute",
+ srcs = ["execute.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":core",
+ ":tape",
+ ":tensor",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "@six_archive//:six",
+ ],
+)
+
+cc_library(
+ name = "python_eager_op_gen",
+ srcs = ["python_eager_op_gen.cc"],
+ hdrs = ["python_eager_op_gen.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/python:python_op_gen",
+ ],
+)
+
+cc_library(
+ name = "python_eager_op_gen_main",
+ srcs = [
+ "python_eager_op_gen_main.cc",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":python_eager_op_gen",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_binary(
+ name = "python_eager_op_gen_demo",
+ deps = [
+ ":python_eager_op_gen_main",
+ "//tensorflow/core:ops",
+ ],
+)
+
+py_library(
+ name = "custom_gradient",
+ srcs = ["custom_gradient.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":core",
+ ":tape",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "graph_only_ops",
+ srcs = ["graph_only_ops.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+py_library(
+ name = "framework_for_generated_wrappers",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/eager:execute",
+ ],
+)
+
+py_library(
+ name = "function",
+ srcs = ["function.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":graph_only_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:graph_to_function_def",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:core",
+ "//tensorflow/python/eager:execute",
+ "//tensorflow/python/eager:tape",
+ "//tensorflow/python/eager:tensor",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "pip_dependencies",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":core",
+ ":execute",
+ ":tensor",
+ ":test",
+ "//tensorflow/python:pywrap_tensorflow",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets.
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
new file mode 100644
index 0000000000..f31d008254
--- /dev/null
+++ b/tensorflow/python/eager/context.py
@@ -0,0 +1,333 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for TensorFlow's "Eager" mode of execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import threading
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.platform import app
+from tensorflow.python.util import compat
+from tensorflow.python.util import tf_contextlib
+
+GRAPH_MODE = 0
+EAGER_MODE = 1
+
+# Default execution mode.
+_default_mode = GRAPH_MODE
+
+
+# TODO(agarwal): better name ?
+class _EagerContext(threading.local):
+ """Thread local eager context."""
+
+ def __init__(self):
+ super(_EagerContext, self).__init__()
+ self.device_index = -1
+ self.mode = _default_mode
+ self.scope_name = ""
+ self.recording_summaries = False
+
+
+# TODO(agarwal): rename to EagerContext / EagerRuntime ?
+class Context(object):
+ """Environment in which eager operations execute."""
+
+ def __init__(self, graph=None):
+ self._eager_context = _EagerContext()
+ if not self.in_eager_mode():
+ raise ValueError("Trying to create a Context in GRAPH_MODE")
+ # Create a handle
+ opts = pywrap_tensorflow.TF_NewSessionOptions(target=compat.as_bytes(""),
+ config=None)
+ with errors.raise_exception_on_not_ok_status() as status:
+ self._handle = pywrap_tensorflow.TFE_NewContext(opts, status)
+ pywrap_tensorflow.TF_DeleteSessionOptions(opts)
+ # Store list of devices
+ self._devices = []
+ with errors.raise_exception_on_not_ok_status() as status:
+ device_list = pywrap_tensorflow.TFE_ContextListDevices(
+ self._handle, status)
+ try:
+ for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
+ with errors.raise_exception_on_not_ok_status() as status:
+ dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i, status)
+ self._devices.append(pydev.canonical_name(dev_name))
+ finally:
+ pywrap_tensorflow.TF_DeleteDeviceList(device_list)
+
+ self._summary_writer_resource = None
+ self._graph = graph or tf_ops.get_default_graph()
+
+ def __del__(self):
+ if self._handle is not None:
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TFE_DeleteContext(self._handle, status)
+
+ def __str__(self):
+ lines = [
+ "Eager TensorFlow environment with %d devices" % (len(self._devices))
+ ]
+ for i, d in enumerate(self._devices):
+ lines.append(" Device %d: %s" % (i, d))
+ return "\n".join(lines)
+
+ @tf_contextlib.contextmanager
+ def _mode(self, mode):
+ ctx = self._eager_context
+ old_mode = ctx.mode
+ ctx.mode = mode
+ try:
+ yield
+ finally:
+ ctx.mode = old_mode
+
+ def in_graph_mode(self):
+ """Returns True if current thread is in GRAPH mode."""
+ return self._eager_context.mode == GRAPH_MODE
+
+ def in_eager_mode(self):
+ """Returns True if current thread is in EAGER mode."""
+ return self._eager_context.mode == EAGER_MODE
+
+ @property
+ def scope_name(self):
+ """Returns scope name for the current thread."""
+ return self._eager_context.scope_name
+
+ @scope_name.setter
+ def scope_name(self, s):
+ """Sets scope name for the current thread."""
+ self._eager_context.scope_name = s
+
+ @property
+ def summary_writer_resource(self):
+ """Returns summary writer resource."""
+ return self._summary_writer_resource
+
+ @summary_writer_resource.setter
+ def summary_writer_resource(self, resource):
+ """Sets summary writer resource."""
+ self._summary_writer_resource = resource
+
+ @property
+ def recording_summaries(self):
+ """Returns True if recording summaries is enabled in current thread.."""
+ return self._eager_context.recording_summaries
+
+ @recording_summaries.setter
+ def recording_summaries(self, val):
+ """Enables recording summaries is enabled in current thread.."""
+ self._eager_context.recording_summaries = val
+
+ # TODO(agarwal): remove?
+ @property
+ def _device_index(self):
+ return self._eager_context.device_index
+
+ # TODO(agarwal): remove?
+ @_device_index.setter
+ def _device_index(self, val):
+ self._eager_context.device_index = val
+
+ @property
+ def device_name(self):
+ """Returns the device name for the current thread."""
+ index = self._device_index
+ return None if index < 0 else self._devices[index]
+
+ def devices(self):
+ """List of the names of devices available to execute operations."""
+ return self._devices
+
+ def num_gpus(self):
+ """The number of GPUs available to execute operations."""
+ # TODO(ashankar): Use TF_DeviceListType to count GPU devices.
+ return len(self._devices) - 1
+
+ def as_default(self):
+ """Returns a context manager to make self the default for this thread."""
+ return _default_context_stack.get_controller(self)
+
+
+class _DefaultContextStack(tf_ops._DefaultStack): # pylint: disable=protected-access
+ """A thread-local stack of Context objects."""
+
+ def __init__(self):
+ super(_DefaultContextStack, self).__init__()
+ self._global_default_context = None
+
+ def get_default(self):
+ """Returns a thread local object if present, else a global default."""
+ return (super(_DefaultContextStack, self).get_default() or
+ self.global_default_context)
+
+ @property
+ def global_default_context(self):
+ if self._global_default_context is None:
+ self._global_default_context = Context()
+ return self._global_default_context
+
+ def reset(self):
+ super(_DefaultContextStack, self).reset()
+ self._global_default_context = None
+
+
+_default_context_stack = _DefaultContextStack()
+
+
+def get_default_context():
+ """Returns a default Context object."""
+ return _default_context_stack.get_default()
+
+
+# TODO(agarwal): switch users to get_default_context and get rid of this
+# function.
+def context():
+ return get_default_context()
+
+
+def in_graph_mode():
+ """Returns True if current thread is in GRAPH mode for default context."""
+ return get_default_context().in_graph_mode()
+
+
+def in_eager_mode():
+ """Returns True if current thread is in EAGER mode for default context."""
+ return get_default_context().in_eager_mode()
+
+
+def graph_mode():
+ """Context-manager to enable GRAPH mode for current thread."""
+ return get_default_context()._mode(GRAPH_MODE) # pylint: disable=protected-access
+
+
+def eager_mode():
+ """Context-manager to enable EAGER mode for current thread."""
+ return get_default_context()._mode(EAGER_MODE) # pylint: disable=protected-access
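+
+
+# Illustrative usage of graph_mode()/eager_mode() (hedged sketch): temporarily
+# switch the mode of the current thread's default context.
+#
+#   with graph_mode():
+#     ...  # code here sees in_graph_mode() == True
+#   with eager_mode():
+#     ...  # code here sees in_eager_mode() == True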
+
+
+@contextlib.contextmanager
+def namescope(name):
+ """ContextManager for creating hierarchical name scopes."""
+ ctx = get_default_context()
+ old_name = ctx.scope_name
+ ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
+ try:
+ yield
+ finally:
+ ctx.scope_name = old_name
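+
+
+# Illustrative usage of namescope() (hedged sketch): nested scopes compose
+# with "/", e.g.
+#
+#   with namescope("outer"):
+#     with namescope("inner"):
+#       scope_name()  # -> "outer/inner"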
+
+
+def scope_name():
+ """Name of the current scope."""
+ return get_default_context().scope_name
+
+
+@tf_contextlib.contextmanager
+def device(name):
+ """Context-manager to force placement of operations and Tensors on a device.
+
+ For example:
+ ```python
+ with tfe.device('gpu:0'):
+ with tfe.device('cpu:0'):
+ shape = tfe.Tensor([], dtype=tf.int32)
+ x = ops.truncated_normal(shape, tf.float32)
+ ```
+  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
+  operation runs on GPU 0.
+
+ Args:
+ name: Name of the device (see get_default_context().devices()), or None to
+ enable automatic placement.
+
+ Yields:
+ Nothing.
+
+ Raises:
+ ValueError: If name does not correspond to a valid device.
+ """
+ device_index = -1
+ ctx = get_default_context()
+ if name is not None:
+ name = pydev.canonical_name(name)
+ all_devices = ctx.devices()
+ for i, d in enumerate(all_devices):
+ # TODO(ashankar): This will change when we have distributed support.
+ # At that point, should not look for a string suffix but be able to
+ # do a full string comparison.
+ if d.endswith(name):
+ device_index = i
+ break
+ if device_index < 0:
+ raise ValueError("device {} does not match the available devices ({})".
+ format(name, all_devices))
+ old_device_index = ctx._device_index # pylint: disable=protected-access
+ try:
+ ctx._device_index = device_index # pylint: disable=protected-access
+ yield
+ finally:
+ ctx._device_index = old_device_index # pylint: disable=protected-access
+
+
+@contextlib.contextmanager
+def record_summaries():
+ """Context-manager to enable recording of summaries."""
+ ctx = get_default_context()
+ old = ctx.recording_summaries
+ ctx.recording_summaries = True
+ try:
+ yield
+ finally:
+ ctx.recording_summaries = old
+
+
+def should_record_summary():
+ """True if a summary should be recorded now."""
+ c = get_default_context()
+ return c.recording_summaries and c.summary_writer_resource is not None
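+
+
+# Illustrative usage of record_summaries() (hedged sketch): summaries are only
+# considered recordable while the context manager is active and a summary
+# writer resource has been set on the default context.
+#
+#   with record_summaries():
+#     if should_record_summary():
+#       ...  # emit a summary here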
+
+
+def run(main=None, argv=None):
+ """Runs the program with an optional 'main' function and 'argv' list.
+
+ The program will run with eager execution enabled.
+
+ Args:
+ main: the main function to run
+ argv: the arguments to pass to it
+ """
+ enable_eager_execution()
+ app.run(main, argv)
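+
+
+# Illustrative entry point using run() (hedged sketch):
+#
+#   def main(_):
+#     ...  # runs with eager execution enabled
+#
+#   if __name__ == "__main__":
+#     run(main)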
+
+
+# TODO(apassos): This should not be a part of the public API.
+def enable_eager_execution():
+ """Enables, for the rest of the lifetime of this program, eager execution.
+
+  If not called immediately on startup, this risks creating breakage and bugs.
+ """
+ global _default_mode
+ assert _default_mode == GRAPH_MODE
+ _default_mode = EAGER_MODE
diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py
new file mode 100644
index 0000000000..64c615fb63
--- /dev/null
+++ b/tensorflow/python/eager/core.py
@@ -0,0 +1,75 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for TensorFlow's "Eager" mode of execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import memory_trace
+from tensorflow.python.framework import errors
+
+# Trace of execution and memory usage.
+_active_trace = None
+
+
+def _status_to_exception(code, message):
+ try:
+ error_class = errors.exception_type_from_error_code(code)
+ return error_class(None, None, message)
+ except KeyError:
+ return errors.UnknownError(None, None, message, code)
+
+
+class _NotOkStatusException(Exception):
+ """Exception class to handle not ok Status."""
+
+ def __init__(self, message, code):
+ super(_NotOkStatusException, self).__init__()
+ self.message = message
+ self.code = code
+
+ def __str__(self):
+ e = _status_to_exception(self.code, self.message)
+ return "%s: %s" % (e.__class__.__name__, e)
+
+
+pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
+
+
+def enable_tracing():
+ """Enables tracing of execution and memory usage.
+
+ WARNING: tracing is not thread-safe.
+ """
+ global _active_trace
+ _active_trace = memory_trace.MemoryTrace(
+ len(context.get_default_context().devices()))
+
+
+def flush_trace():
+ """Flushes the active trace, if it exists.
+
+ WARNING: tracing is not thread-safe.
+ """
+ if _active_trace is not None:
+ _active_trace.flush_trace()
+
+
+def active_trace():
+ """Returns the current global active trace of execution and memory usage."""
+ return _active_trace
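+
+
+# Illustrative tracing workflow (hedged sketch; tracing is not thread-safe):
+#
+#   enable_tracing()
+#   ...  # execute some eager ops
+#   flush_trace()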
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
new file mode 100644
index 0000000000..7e76236ee5
--- /dev/null
+++ b/tensorflow/python/eager/core_test.py
@@ -0,0 +1,483 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for core."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import core
+from tensorflow.python.eager import execute
+from tensorflow.python.eager import tensor
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+
+
+def truncated_normal(shape):
+ return execute.execute(
+ 'TruncatedNormal',
+ 1,
+ inputs=[shape],
+ attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
+ shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]
+
+
+class TFETest(test_util.TensorFlowTestCase):
+
+ def testContext(self):
+ ctx = context.Context()
+ self.assertFalse(ctx.in_graph_mode())
+ self.assertTrue(ctx.in_eager_mode())
+ self.assertEqual('', ctx.scope_name)
+ self.assertEqual(-1, ctx._device_index) # pylint: disable=protected-access
+ self.assertFalse(ctx.recording_summaries)
+ self.assertIsNone(ctx.summary_writer_resource)
+ del ctx
+
+ def testDefaultContext(self):
+ orig = context.get_default_context()
+ self.assertIs(context.get_default_context(), orig)
+ c0 = context.Context()
+ self.assertIs(context.get_default_context(), orig)
+ context_manager_0 = c0.as_default()
+ self.assertIs(context.get_default_context(), orig)
+ with context_manager_0 as c0:
+ self.assertIs(context.get_default_context(), c0)
+ with context.Context().as_default() as c1:
+ self.assertIs(context.get_default_context(), c1)
+ self.assertIs(context.get_default_context(), c0)
+ self.assertIs(context.get_default_context(), orig)
+
+ def testContextWithThreads(self):
+
+ def run_fn(ctx1):
+ ctx2 = context.get_default_context()
+ # Default context created in different threads are different.
+ self.assertIsNot(ctx1, ctx2)
+ # Check that default values of the context created in a different thread
+ # are set correctly.
+ self.assertFalse(ctx2.in_graph_mode())
+ self.assertTrue(ctx2.in_eager_mode())
+ self.assertEqual('', ctx2.scope_name)
+ self.assertEqual(-1, ctx2._device_index) # pylint: disable=protected-access
+ self.assertFalse(ctx2.recording_summaries)
+ self.assertIsNone(ctx2.summary_writer_resource)
+
+ ctx1 = context.get_default_context()
+ t = threading.Thread(target=run_fn, args=(ctx1,))
+ t.start()
+ t.join()
+
+ def testScalarTensor(self):
+ t = tensor.Tensor(3)
+ self.assertEqual(t.numpy(), tensor.Tensor(np.array(3)).numpy())
+ self.assertEqual(dtypes.int32, t.dtype)
+ self.assertEqual(0, t.shape.ndims)
+ self.assertAllEqual([], t.shape.as_list())
+
+ def testTensorAndNumpyMatrix(self):
+ expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+ actual = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]])
+ self.assertAllEqual(expected, actual.numpy())
+ self.assertEqual(np.float32, actual.numpy().dtype)
+ self.assertEqual(dtypes.float32, actual.dtype)
+ self.assertAllEqual([2, 2], actual.shape.as_list())
+
+ def testFloatDowncast(self):
+ # Unless explicitly specified, float64->float32
+ t = tensor.Tensor(3.0)
+ self.assertEqual(dtypes.float32, t.dtype)
+ t = tensor.Tensor(3.0, dtype=dtypes.float64)
+ self.assertEqual(dtypes.float64, t.dtype)
+
+ def testBool(self):
+ t = tensor.Tensor(False)
+ if t:
+ self.assertFalse(True)
+
+ def testIntDowncast(self):
+ t = tensor.Tensor(3)
+ self.assertEqual(dtypes.int32, t.dtype)
+ t = tensor.Tensor(3, dtype=dtypes.int64)
+ self.assertEqual(dtypes.int64, t.dtype)
+ t = tensor.Tensor(2**33)
+ self.assertEqual(dtypes.int64, t.dtype)
+
+ def testTensorCreationFailure(self):
+ with self.assertRaises(Exception):
+      # Should fail because each row of the Python object has a different
+ # number of columns.
+ self.assertEqual(None, tensor.Tensor([[1], [1, 2]]))
+
+ def testTensorPlacement(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ x = tensor.Tensor(1.).as_gpu_tensor()
+ with context.device('gpu:0'):
+ y = tensor.Tensor(2.)
+    # Add would fail if y were not on GPU.
+ result = execute.execute(
+ 'Add', 1, inputs=[x, y],
+ attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy()
+ self.assertEqual(3, result)
+
+ def testNumpyOrderHandling(self):
+ n = np.array([[1, 2], [3, 4]], order='F')
+ t = tensor.Tensor(n)
+ self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
+
+ def testCopyBetweenDevices(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ x = tensor.Tensor([[1., 2.], [3., 4.]])
+ x = x.as_cpu_tensor()
+ x = x.as_gpu_tensor()
+ x = x.as_gpu_tensor()
+ x = x.as_cpu_tensor()
+
+ # Invalid device
+ with self.assertRaises(errors.InvalidArgumentError):
+ x.as_gpu_tensor(context.context().num_gpus() + 1)
+
+ def testNumpyForceCPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ cpu = tensor.Tensor([[1., 2.], [3., 4.]])
+ c2g = cpu.as_gpu_tensor()
+ self.assertAllEqual(c2g.numpy(), cpu.numpy())
+
+ def testCopyFromCPUToCPU(self):
+ ta = tensor.Tensor([[1, 2], [3, 4]])
+ tb = ta.as_cpu_tensor()
+
+ self.assertNotEqual(ta._handle, tb._handle)
+ self.assertAllEqual(ta.numpy(), tb.numpy())
+
+ def testRegisterExceptionClass(self):
+ with self.assertRaises(TypeError):
+ pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
+ pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
+
+ # TODO(agarwal): add tests passing incorrect typed values to attrs.
+ def testExecuteBasic(self):
+ three = tensor.Tensor(3)
+ five = tensor.Tensor(5)
+ product = execute.execute(
+ 'Mul',
+ num_outputs=1,
+ inputs=[three, five],
+ attrs=('T', three.dtype.as_datatype_enum))[0]
+ self.assertEqual(15, product.numpy())
+
+ def testExecuteTooManyNumOutputs(self):
+ # num_outputs provided is 50, but only one output is produced.
+ # That should be okay.
+ product = execute.execute(
+ 'Mul',
+ num_outputs=50,
+ inputs=[tensor.Tensor(3), tensor.Tensor(5)],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
+ self.assertEqual(15, product.numpy())
+
+ def testMatMulGPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+ three = tensor.Tensor([[3.]]).as_gpu_tensor()
+ five = tensor.Tensor([[5.]]).as_gpu_tensor()
+ product = execute.execute(
+ 'MatMul',
+ num_outputs=1,
+ inputs=[three, five],
+ attrs=('transpose_a', False, 'transpose_b', False, 'T',
+ three.dtype.as_datatype_enum))[0]
+ self.assertEqual([[15.0]], product.numpy())
+
+ def testExecuteStringAttr(self):
+ checked_three = execute.execute(
+ 'CheckNumerics',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.)],
+ attrs=('message', 'just checking', 'T',
+ dtypes.float32.as_datatype_enum))[0]
+ self.assertEqual([[3]], checked_three.numpy())
+
+ def testExecuteStringAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ _ = execute.execute(
+ 'CheckNumerics',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.)],
+ attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))
+
+ def testExecuteFloatAttr(self):
+ almost_equal = execute.execute(
+ 'ApproximateEqual',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
+ attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
+ self.assertTrue(almost_equal.numpy())
+
+ def testExecuteFloatAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ _ = execute.execute(
+ 'ApproximateEqual',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
+ attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))
+
+ def testExecuteIntAttr(self):
+ total = execute.execute(
+ 'AddN',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3), tensor.Tensor(4)],
+ attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
+ self.assertEqual(7, total.numpy())
+
+ def testExecuteIntAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ _ = execute.execute(
+ 'AddN',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3), tensor.Tensor(4)],
+ attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))
+
+ # Looks like we don't have an existing op with list(bool) attrs.
+ def testExecuteBoolAttr(self):
+ product = execute.execute(
+ 'MatMul',
+ num_outputs=1,
+ inputs=[tensor.Tensor([[3]]),
+ tensor.Tensor([[5]])],
+ attrs=('transpose_a', True, 'transpose_b', False, 'T',
+ dtypes.int32.as_datatype_enum))[0]
+ self.assertEqual([[15]], product.numpy())
+
+ def testExecuteShapeAttr(self):
+ execute.execute(
+ 'VarHandleOp',
+ num_outputs=1,
+ inputs=[],
+ attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
+ 'container', '', 'shared_name', ''))
+
+ def testExecuteShapeAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'VarHandleOp',
+ num_outputs=1,
+ inputs=[],
+ attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
+ 'container', '', 'shared_name', ''))
+
+ def testExecuteListStringAttr(self):
+ execute.execute(
+ 'TensorSummary',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.0)],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'description',
+ 'tensor_summary', 'labels', ['3',
+ 'summary'], 'display_name', 'test'))
+
+ def testExecuteListStringAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'TensorSummary',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.0)],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
+ 'labels', 3, 'display_name', 'test'))
+
+ def testExecuteListStringAttrBadListValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'TensorSummary',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3.0)],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
+ 'labels', [3], 'display_name', 'test'))
+
+ def testExecuteListFloatAttr(self):
+ b = execute.execute(
+ 'Bucketize',
+ num_outputs=1,
+ inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0,
+ 6.0]))[0]
+ self.assertAllEqual([0, 1, 2], b.numpy())
+
+ def testExecuteListFloatAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Bucketize',
+ num_outputs=1,
+ inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))
+
+ def testExecuteListFloatAttrBadListValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Bucketize',
+ num_outputs=1,
+ inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
+ ['4.0', '6.0']))
+
+ def testExecuteListIntAttr(self):
+ b = execute.execute(
+ 'Squeeze',
+ num_outputs=1,
+ inputs=[tensor.Tensor([[[3.0]]])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
+ self.assertAllEqual([3], b.numpy())
+
+ def testExecuteListIntAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Squeeze',
+ num_outputs=1,
+ inputs=[tensor.Tensor([[[3.0]]])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
+
+ def testExecuteListIntAttrBadListValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Squeeze',
+ num_outputs=1,
+ inputs=[tensor.Tensor([[[3.0]]])],
+ attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
+ ['0', '2']))
+
+ def testExecuteListTypeListShapeAttr(self):
+ execute.execute(
+ 'Barrier',
+ num_outputs=1,
+ inputs=[],
+ attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
+ [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
+
+ def testExecuteListTypeAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Barrier',
+ num_outputs=1,
+ inputs=[],
+ attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
+ [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
+
+ def testExecuteListTypeAttrBadListValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Barrier',
+ num_outputs=1,
+ inputs=[],
+ attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
+ 'container', '', 'shared_name', ''))
+
+ def testExecuteListShapeAttrBadValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Barrier',
+ num_outputs=1,
+ inputs=[],
+ attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
+ [1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))
+
+ def testExecuteListShapeAttrBadListValue(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Barrier',
+ num_outputs=1,
+ inputs=[],
+ attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
+ [1], 'capacity', -1, 'container', '', 'shared_name', ''))
+
+ def testExecuteMultipleOutputs(self):
+ split_dim = 1
+ value = [[0, 1, 2], [3, 4, 5]]
+ x1, x2, x3 = execute.execute(
+ 'Split',
+ num_outputs=3,
+ inputs=[tensor.Tensor(split_dim),
+ tensor.Tensor(value)],
+ attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
+ self.assertAllEqual([[0], [3]], x1.numpy())
+ self.assertAllEqual([[1], [4]], x2.numpy())
+ self.assertAllEqual([[2], [5]], x3.numpy())
+
+ def testExecuteBadNumOutputsArgument(self):
+ with self.assertRaises(TypeError):
+ execute.execute(
+ 'Relu', [],
+ inputs=[tensor.Tensor(3.0)],
+ attrs=('T', dtypes.float32.as_datatype_enum))
+
+ def testExecuteUnknownOp(self):
+ with self.assertRaises(errors.NotFoundError):
+ execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
+
+ def testExecuteUnknownAttr(self):
+ with self.assertRaises(errors.InvalidArgumentError):
+ execute.execute(
+ 'Identity',
+ num_outputs=1,
+ inputs=[tensor.Tensor(3)],
+ attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
+
+ def testComposition(self):
+
+ def add(x, y):
+ return execute.execute(
+ 'Add',
+ num_outputs=1,
+ inputs=[x, y],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
+
+ x = tensor.Tensor(1)
+ three_x = add(add(x, x), x)
+    self.assertEqual(dtypes.int32, three_x.dtype)
+    self.assertEqual(3, three_x.numpy())
+
+ def testOperationWithNoInputsRunsOnDevice(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+ shape = tensor.Tensor([], dtype=dtypes.int32)
+
+ # x: Run the "TruncatedNormal" op CPU and copy result to GPU.
+ x = truncated_normal(shape).as_gpu_tensor()
+ # y: Explicitly run the "TruncatedNormal" op on GPU.
+ with context.device('gpu:0'):
+ y = truncated_normal(shape)
+ # Add would fail if x and y were not on the same device.
+ execute.execute(
+ 'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))
+
+ def testInvalidDevice(self):
+ with self.assertRaises(ValueError):
+ with context.device('pu:0'):
+ _ = tensor.Tensor(1)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
new file mode 100644
index 0000000000..afa328b9fe
--- /dev/null
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -0,0 +1,70 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Decorator to overrides the gradient for a function."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from autograd import core as ag_core
+
+from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor as _tensor
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.util import nest
+
+
+def _watch_value_from_tape(tensor):
+ for t in tape._tape_stack.stack: # pylint: disable=protected-access
+ w = t.value.tensors.get(tape.tensor_id(tensor), None)
+ if w is not None:
+ return w
+ return tensor
+
+
+def custom_gradient(f):
+ """Decorator to define a function with a custom gradient.
+
+ The input function is expected to return the tuple
+ (results, gradient_function)
+
+ The output function will return results while possibly recording the
+ gradient_function and inputs in the tape.
+
+ Args:
+ f: function to be decorated.
+
+ Returns:
+ decorated function.
+ """
+
+ def decorated(*args, **kwargs):
+ """Decorated function with custom gradient."""
+ input_tensors = [_watch_value_from_tape(x) for x in args
+ if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
+ or ag_core.isnode(x)]
+ result, grad_fn = f(*args, **kwargs)
+
+ flat_result = nest.flatten(result)
+ flat_result = [ag_core.getval(x) for x in flat_result]
+ flat_result = tape.record_operation(
+ flat_result,
+ input_tensors,
+ [],
+ grad_fn)
+ flat_result = list(flat_result)
+ return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
+
+ return decorated
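+
+
+# Illustrative shape of a decorated function (hedged sketch of the structure
+# described in the docstring above):
+#
+#   @custom_gradient
+#   def my_function(x):
+#     result = ...            # forward computation using x
+#     def grad_fn(dout):
+#       return [...]          # gradients w.r.t. the inputs captured above
+#     return result, grad_fn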
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
new file mode 100644
index 0000000000..b8178e1388
--- /dev/null
+++ b/tensorflow/python/eager/execute.py
@@ -0,0 +1,241 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions called by the generated code to execute an eager-mode op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from autograd import core as ag_core
+import six
+
+from google.protobuf import text_format
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import core
+from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import compat
+
+
+def execute(op_name, num_outputs, inputs, attrs=None, name=None):
+ """Execute a TensorFlow operation.
+
+ Args:
+ op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
+ execute.
+ num_outputs: The number of outputs of the operation to fetch.
+ (Explicitly provided instead of being inferred for performance
+ reasons).
+ inputs: A list of inputs to the operation. Each entry should be a Tensor, or
+ a value which can be passed to the Tensor constructor to create one.
+ attrs: A tuple with alternating string attr names and attr values for this
+ operation.
+ name: Customized name for the operation.
+
+ Returns:
+ None if there are no outputs, a single Tensor object if there is one output
+ and a list of Tensor objects if there are multiple outputs.
+
+ Raises:
+ An exception on error.
+ """
+ ctx = context.get_default_context()
+ # TODO(apassos) move this to convert_to_tensor
+ inputs = [ag_core.getval(x) for x in inputs]
+ # pylint: disable=protected-access
+ input_handles = [c._handle for c in inputs]
+ device_name = ctx.device_name
+ try:
+ outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
+ op_name, input_handles, attrs,
+ num_outputs)
+ # pylint: enable=protected-access
+ except core._NotOkStatusException as e: # pylint: disable=protected-access
+ raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
+ # pylint: enable=protected-access
+
+ tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
+ if core.active_trace() is not None:
+ trace_name = name if name else op_name
+ for t in tensors:
+ # pylint: disable=protected-access
+ core.active_trace().record_tensor(trace_name,
+ tape.tensor_id(t),
+ t._device_name(),
+ t.shape.num_elements())
+ # pylint: enable=protected-access
+ return tensors
+
+
+def record_gradient(unused_op_name, unused_inputs, unused_attrs, results,
+ unused_name):
+ """Import backprop if you want gradients recorded."""
+ return results
+
+
+def make_float(v, arg_name):
+ if not isinstance(v, compat.real_types):
+ raise TypeError("Expected float for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return float(v)
+
+
+def make_int(v, arg_name):
+ if isinstance(v, six.string_types):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ try:
+ return int(v)
+ except (ValueError, TypeError):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+
+
+def make_str(v, arg_name):
+ if not isinstance(v, compat.bytes_or_text_types):
+ raise TypeError("Expected string for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return compat.as_bytes(v) # Convert unicode strings to bytes.
+
+
+def make_bool(v, arg_name):
+ if not isinstance(v, bool):
+ raise TypeError("Expected bool for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return v
+
+
+def make_type(v, arg_name):
+ try:
+ v = dtypes.as_dtype(v).base_dtype
+ except TypeError:
+ raise TypeError("Expected DataType for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ i = v.as_datatype_enum
+ return i
+
+
+def make_shape(v, arg_name):
+ """Convert v into a list."""
+ # Args:
+ # v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
+ # arg_name: String, for error messages.
+
+ # Returns:
+ # None if the rank is unknown, otherwise a list of ints (or Nones in the
+ # position where the dimension is unknown).
+ try:
+ shape = tensor_shape.as_shape(v)
+ except TypeError as e:
+ raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
+ except ValueError as e:
+ raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
+ if shape.ndims is None:
+ return None
+ else:
+ return shape.as_list()
+
+
+def make_tensor(v, arg_name):
+ """Ensure v is a TensorProto."""
+ if isinstance(v, tensor_pb2.TensorProto):
+ return v
+ elif isinstance(v, six.string_types):
+ pb = tensor_pb2.TensorProto()
+ text_format.Merge(v, pb)
+ return pb
+ raise TypeError(
+ "Don't know how to convert %s to a TensorProto for argument '%s'" %
+ (repr(v), arg_name))
+
+
+def args_to_matching_eager(l, default_dtype=None):
+ """Convert sequence `l` to eager same-type Tensors."""
+ # TODO(josh11b): Could we do a better job if we also passed in the
+ # allowed dtypes when that was known?
+
+ # Is some input already a Tensor with a dtype?
+ dtype = None
+ for t in l:
+ if isinstance(ag_core.getval(t), tensor.Tensor):
+ dtype = t.dtype
+ break
+
+ if dtype is None:
+ # TODO(josh11b): At the moment, I don't think this can fail, but at some
+ # point we likely should have some logic to prevent bad conversions.
+ dtype = default_dtype
+
+ if dtype is None:
+ # Infer a dtype based on the first value, and use that dtype for the
+ # remaining values.
+ ret = []
+ for t in l:
+ ret.append(tensor.convert_to_eager_tensor(t, dtype))
+ if dtype is None:
+ dtype = ret[-1].dtype
+ else:
+ ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l]
+
+ return dtype, ret
+
+
+def convert_to_mixed_eager_tensors(values):
+ v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
+ for t in values]
+ types = [t.dtype for t in v]
+ return types, v
+
+
+def args_to_mixed_eager_tensors(lists):
+ """Converts a list of same-length lists of values to eager tensors."""
+ assert len(lists) > 1
+
+ # Generate an error if len(lists[i]) is not the same for all i.
+  # One converted output list per input list.
+  lists_ret = [[] for _ in lists]
+  for l in lists[1:]:
+    if len(l) != len(lists[0]):
+      raise ValueError(
+        "Expected list arguments to be the same length: %d != %d (%r vs. %r)"
+        % (len(lists[0]), len(l), lists[0], l))
+
+ # Convert the first element of each list first, then the second element, etc.
+ types = []
+ for i in range(len(lists[0])):
+ dtype = None
+ # If any list has a Tensor, use that dtype
+ for l in lists:
+ if isinstance(ag_core.getval(l[i]), tensor.Tensor):
+ dtype = l[i].dtype
+ break
+ if dtype is None:
+ # Convert the first one and use its dtype.
+ lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
+ dtype = lists_ret[0][i].dtype
+ for j in range(1, len(lists)):
+ lists_ret[j].append(
+ tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
+ else:
+ # Convert everything to the found dtype.
+ for j in range(len(lists)):
+ lists_ret[j].append(
+ tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
+ types.append(dtype)
+ return types, lists_ret
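+
+
+# Example (illustrative): args_to_mixed_eager_tensors converts several
+# same-length lists position by position, so each position may use a different
+# dtype (taken from any Tensor present, or inferred from the first list):
+#
+#   types, (xs, ys) = args_to_mixed_eager_tensors([[1, 2.0], [3, 4.0]])
+#   # xs and ys are lists of eager Tensors; types holds one dtype per position.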
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
new file mode 100644
index 0000000000..e4866b6105
--- /dev/null
+++ b/tensorflow/python/eager/function.py
@@ -0,0 +1,502 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=unidiomatic-typecheck
+"""Defun decorator for defining graph-mode functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import threading
+
+from autograd import core as ag_core
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import execute
+from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor
+from tensorflow.python.eager.graph_only_ops import graph_placeholder
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.util import nest
+
+# Thread-local storage for tfe Tensors which are referenced while evaluating a
+# graph-mode function.
+_scoped_captures = threading.local()
+# _scoped_captures.tensors is either None or a map from tfe.Tensor id to a pair
+# of a tfe tensor and its corresponding placeholder to pass as a function
+# argument. The value should be None unless we're in function definition
+# context.
+_scoped_captures.tensors = None
+
+
+@contextlib.contextmanager
+def capture_tensors(captures):
+ old = _scoped_captures.__dict__.get("tensors", None)
+ try:
+ _scoped_captures.tensors = captures
+ yield
+ finally:
+ _scoped_captures.tensors = old
+
+
+def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
+ """Captures a tfe Tensor while building a graph mode function.
+
+ Creates a placeholder to pass the tensor as an argument.
+
+  Args:
+    value: A tfe.Tensor object.
+ dtype: The datatype of the value produced by the node in the graph.
+ name: Name of the node in the graph.
+ as_ref: Ignored (required by register_tensor_conversion_function).
+
+ Returns:
+ A placeholder which will, at runtime, have the value of this tensor.
+
+ Raises:
+ ValueError: if called outside a defun context.
+ """
+ _ = as_ref
+ tensor_map = _scoped_captures.tensors
+ if tensor_map is None:
+ raise ValueError(
+ "Trying to use tfe.Tensor objects in a graph outside graph mode. "
+ "To build a graph use tfe.defun or tfe.func_to_object.")
+ captured_value = tensor_map.get(tape.tensor_id(value), None)
+ if captured_value is None:
+ captured_value = graph_placeholder(
+ dtype=dtype or value.dtype, shape=value.shape, name=name)
+ if captured_value.dtype == dtypes.resource:
+ captured_value._handle_data = value._handle_data # pylint: disable=protected-access
+ tensor_map[tape.tensor_id(value)] = (value, captured_value)
+ else:
+ captured_value = captured_value[1]
+ return captured_value
+
+
+# TODO(apassos): it'd be really nice if we could scope this registration.
+# Note that we register this at a higher priority than ops.Tensor since we want
+# to handle subclass specific conversion before a superclass conversion.
+ops.register_tensor_conversion_function(
+ tensor.Tensor, _convert_to_graph_constant, priority=-1)
+
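+# Sketch of the capture mechanism (illustrative only): while tracing a defun in
+# graph mode, using a tfe.Tensor triggers the conversion function above, which
+# creates a placeholder and records the (tensor, placeholder) pair:
+#
+#   captures = {}
+#   with context.graph_mode(), ops.Graph().as_default(), \
+#       capture_tensors(captures):
+#     ph = ops.convert_to_tensor(some_eager_tensor)  # hypothetical tfe.Tensor
+#   # `captures` now maps id(some_eager_tensor) -> (some_eager_tensor, ph); the
+#   # eager value is fed for that placeholder when the traced function runs.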
+
+class _CapturingContext(object):
+ """Tracks references to Tensors outside this context while it is active."""
+
+ def __init__(self):
+ # known_ops are ops which are created while this context is active
+ self.known_ops = set()
+
+ # captured_tensors are all tensors referenced to by ops in this context but
+ # not produced in it
+ self.captured_tensors = set()
+
+ def AddOp(self, op): # pylint: disable=invalid-name
+ if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
+ raise ValueError("tfe.defun cannot capture variables created without "
+ "using tf.get_variable. Op: %s" % op)
+ self.known_ops.add(op)
+ for i in op.inputs:
+ if i.op not in self.known_ops:
+ self.captured_tensors.add(i)
+
+ def __enter__(self):
+ self._g = ops.get_default_graph()
+ self._old = self._g._get_control_flow_context() # pylint: disable=protected-access
+ self._g._set_control_flow_context(self) # pylint: disable=protected-access
+
+ def __exit__(self, _, __, ___): # pylint: disable=invalid-name
+ self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
+
+
+def _forward_name(n):
+ """The name of a generated forward defun named n."""
+ return "__forward_%s_%s" % (n, ops.uid())
+
+
+def _backward_name(n):
+ """The name of a generated backward defun named n."""
+ return "__backward_%s_%s" % (n, ops.uid())
+
+
+def _inference_name(n):
+ """The name of a forward-but-no-gradient defun named n."""
+ return "__inference_%s_%s" % (n, ops.uid())
+
+
+class _DefinedFunction(object):
+ """Mocks the interface of tf _DefinedFunction."""
+
+ def __init__(self, fdef):
+ self.definition = fdef
+ self.name = fdef.signature.name
+ self.grad_func_name = None
+ self.python_grad_func = None
+
+
+def _map_sequence_obj_to_idx(sequence):
+ """Maps objs in the sequence from id(obj) to sequence index."""
+ return {id(x): i for i, x in enumerate(sequence)}
+
+
+class _GraphModeFunction(object):
+ """Callable object representing a graph-mode function.
+
+ Args:
+ input_placeholders: list of placeholder values to feed when calling
+ the wrapped function.
+    extra_inputs: Tensor inputs closed over by this function definition and
+      passed as extra arguments; tracked so that gradients are computed
+      correctly.
+ fdef: the function definition we want to call.
+ graph: the graph from which the fdef operations were pulled. Used as
+ a context when computing gradients.
+ operations: the subset of operations in the graph used in the function
+ definition.
+ func_outputs: the python outputs of the graph-mode function, with
+ tensorflow.Tensor objects to be replaced by tfe values when called.
+ func_outputs_to_fdef_outputs: Maps id(obj) in func_outputs to index of
+ fdef's outputs. It allows mapping fdef output tensors to nested
+ func_outputs structure.
+ output_shapes: List of shapes of all tensors which are output by the
+ internal function.
+ """
+
+ def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
+ func_outputs, func_outputs_to_fdef_outputs, output_shapes):
+ assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
+ len(input_placeholders), len(fdef.signature.input_arg))
+ self._input_placeholders = input_placeholders
+ self._extra_inputs = list(extra_inputs)
+ self._graph = graph
+ self._has_backprop = False
+ self._func_name = fdef.signature.name
+ self._fdef = _DefinedFunction(fdef)
+ self._num_outputs = len(fdef.signature.output_arg)
+ self._ops = operations
+ self._func_outputs = func_outputs
+ if (isinstance(func_outputs, (ops.Tensor, type(None))) or
+ ag_core.isnode(func_outputs)):
+ self._returns = [func_outputs]
+ else:
+ self._returns = list(func_outputs)
+    self._returns_to_fdef_outputs = func_outputs_to_fdef_outputs
+ self._output_shapes = output_shapes
+
+ def _compute_backprop(self):
+ """Computes the backprop function object for this function."""
+ self._has_backprop = True
+ with self._graph.as_default(), context.graph_mode():
+ c = _CapturingContext()
+ with c:
+ filtered_outputs = [
+ ag_core.getval(x) for x in self._returns if x is not None
+ ]
+ self._out_grad_placeholders = [
+ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
+ ]
+ in_gradients = gradients_impl.gradients(
+ filtered_outputs,
+ self._input_placeholders,
+ grad_ys=self._out_grad_placeholders)
+ shapes = [x.shape for x in in_gradients if x is not None]
+ captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
+ forward_function_def = graph_to_function_def.graph_to_function_def(
+ self._graph, self._ops, self._input_placeholders,
+ filtered_outputs + captures)
+ self._forward_fdef = _DefinedFunction(forward_function_def)
+ _register_with_name(_forward_name(self._func_name), forward_function_def)
+ backward_outputs = [x for x in in_gradients if x is not None]
+ all_inputs = self._out_grad_placeholders + captures
+ backward_function_def = graph_to_function_def.graph_to_function_def(
+ self._graph, [x.op for x in self._out_grad_placeholders
+ ] + list(sorted(c.known_ops, key=lambda x: x.name)),
+ all_inputs, backward_outputs)
+ _register_with_name(_backward_name(self._func_name), backward_function_def)
+ self._backward_function = _GraphModeFunction(
+ all_inputs, [], backward_function_def, self._graph, c.known_ops,
+ in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
+
+ def _backprop_call(self, args):
+ """Calls the wrapped function and records the result on a tape."""
+ all_args = args + self._extra_inputs
+ signature = self._forward_fdef.definition.signature
+ if context.in_graph_mode():
+ g = ops.get_default_graph()
+ g._add_function(self._forward_fdef) # pylint: disable=protected-access
+ unwrapped_args = [ag_core.getval(x) for x in all_args]
+ op = g.create_op(
+ signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
+ [dtypes.DType(x.type) for x in signature.output_arg],
+ op_def=signature,
+ name="FunctionCall",
+ compute_shapes=False)
+ outputs = op.outputs
+ outputs = [outputs] if isinstance(
+ outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
+ for i, s in enumerate(self._output_shapes):
+ outputs[i].set_shape(s)
+ else:
+ outputs = execute.execute(
+ signature.name,
+ num_outputs=len(signature.output_arg),
+ inputs=all_args)
+ real_outputs = outputs[:len(self._returns)]
+ side_outputs = outputs[len(self._returns):]
+ watched_extra_inputs = []
+ for t in self._extra_inputs:
+ tid = tape.tensor_id(t)
+      # Use a distinct loop variable so the for-else below still appends the
+      # original extra input `t` when no tape is watching it.
+      for entry in tape._tape_stack.stack:  # pylint: disable=protected-access
+        w = entry.value.tensors.get(tid, None)
+ if w is not None:
+ watched_extra_inputs.append(w)
+ break
+ else: # Note: for-else here done on purpose
+ watched_extra_inputs.append(t)
+ real_outputs = tape.record_operation(real_outputs,
+ (args + watched_extra_inputs),
+ side_outputs, self._backward_function)
+
+ return self._build_call_outputs(self._returns, real_outputs)
+
+ def __call__(self, *args):
+ """Executes the passed function in eager mode."""
+ tensor_inputs = [
+ x for x in nest.flatten(args)
+ if isinstance(x, (tensor.Tensor, ops.Tensor,
+ tensor.LazyZero)) or ag_core.isnode(x)
+ ]
+ if tape.should_record(tensor_inputs) or any(
+ tape.any_tape_has(t) for t in self._extra_inputs):
+ if not self._has_backprop:
+ self._compute_backprop()
+ return self._backprop_call(tensor_inputs)
+
+ if context.in_graph_mode():
+ g = ops.get_default_graph()
+ g._add_function(self._fdef) # pylint: disable=protected-access
+ signature = self._fdef.definition.signature
+ args = list(tensor_inputs) + self._extra_inputs
+ op = g.create_op(
+ signature.name, [ops.convert_to_tensor(x) for x in args],
+ [dtypes.DType(x.type) for x in signature.output_arg],
+ op_def=signature,
+ name="FunctionCall",
+ compute_shapes=False)
+ result = op.outputs
+ for i, s in enumerate(self._output_shapes):
+ result[i].set_shape(s)
+ else:
+ tensor_inputs = [
+ x.tensor() if isinstance(x, tensor.LazyZero) else x
+ for x in tensor_inputs
+ ]
+ result = execute.execute(
+ self._func_name,
+ num_outputs=self._num_outputs,
+ inputs=tensor_inputs + self._extra_inputs)
+
+ return self._build_call_outputs(self._returns, result)
+
+ def _build_call_outputs(self, func_outputs, result):
+ """Maps the fdef output list to actual output structure.
+
+ Args:
+ func_outputs: The outputs originally defined by the graph function. It
+ could potentially be a nested structure.
+ result: Output lists defined by FunctionDef.
+ Returns:
+ The actual call output.
+ """
+ if self._func_outputs is None:
+ return None
+ if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
+ return result[0]
+
+ outputs = []
+ for o in func_outputs:
+ vo = ag_core.getval(o)
+ if isinstance(vo, ops.Tensor):
+        outputs.append(result[self._returns_to_fdef_outputs[id(vo)]])
+ elif type(vo) in (tuple, list):
+ outputs.append(self._build_call_outputs(o, result))
+ else:
+ outputs.append(o)
+
+ return tuple(outputs) if type(func_outputs) is tuple else outputs
+
+
+def _get_defun_inputs(args):
+ """Maps the inputs args to graph inputs."""
+ ret = []
+ for a in args:
+ a = ag_core.getval(a)
+ if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
+ ret.append(graph_placeholder(a.dtype, a.shape))
+ elif type(a) in (tuple, list):
+ ret.append(_get_defun_inputs(a))
+ else:
+ ret.append(a)
+ return tuple(ret) if type(args) is tuple else ret
+
+
+def _defun_internal(name, func, args, kwds):
+ """Defines and returns graph-mode version of func."""
+ with context.graph_mode():
+ tmp_graph = ops.Graph()
+ with tmp_graph.as_default():
+ func_inputs = _get_defun_inputs(args)
+
+ captures = {}
+ with capture_tensors(captures):
+ func_outputs = func(*func_inputs, **kwds)
+ ids = list(sorted(captures.keys()))
+ if ids:
+      extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+ outputs_list = nest.flatten(func_outputs)
+ output_shapes = [x.shape for x in outputs_list if x is not None]
+
+ flat_inputs = [
+ x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
+ ]
+ all_inputs = flat_inputs + list(extra_placeholders)
+
+ func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
+ inference_function_def = graph_to_function_def.graph_to_function_def(
+ tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
+ # Register any other functions defined in the graph
+ # TODO(ashankar): Oh lord, forgive me for this lint travesty.
+ for f in tmp_graph._functions.values(): # pylint: disable=protected-access
+ # TODO(ashankar): What about the gradient registry?
+ _register_with_name(f.name, f.definition)
+ _register_with_name(_inference_name(name), inference_function_def)
+
+ return _GraphModeFunction(
+ all_inputs, extra_inputs, inference_function_def, tmp_graph,
+ tmp_graph.get_operations(), func_outputs,
+ _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
+
+
+# Defun uses this instead of Tensor as a cache key. Using dtype because
+# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
+# performance reasons, as much TensorFlow code specializes on known shapes to
+# produce slimmer graphs.
+_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
+_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
+
+
+def _cache_key(x):
+ """Cache key for tfe functions."""
+ x = ag_core.getval(x)
+ if isinstance(x, tensor.Tensor):
+ return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
+ if isinstance(x, tensor.LazyZero):
+ return _TensorDtype(x.dtype, tuple(x.shape.as_list())) # pylint: disable=protected-access
+ if isinstance(x, np.ndarray):
+ return ("array", x.shape, tuple(x.reshape(-1)))
+ if type(x) in (list, tuple):
+ return tuple([_cache_key(a) for a in x])
+ return x
+
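+# Example (illustrative): cache keys depend only on dtype and shape for
+# tensors, so calls with same-typed, same-shaped inputs reuse one traced
+# function:
+#
+#   _cache_key(tensor.Tensor([1.0, 2.0]))  # -> _TensorDtype(dtype, (2,))
+#   _cache_key(3)                          # -> 3 (plain values key directly)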
+
+def register_function_def(fdef):
+ fdef_string = fdef.SerializeToString()
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TFE_ContextAddFunctionDef(
+ context.get_default_context()._handle, # pylint: disable=protected-access
+ fdef_string,
+ len(fdef_string),
+ status)
+
+
+def _register_with_name(name, fdef):
+ """Registers the function `fdef` with the name `name`."""
+ fdef.signature.name = name
+ register_function_def(fdef)
+
+
+# TODO(apassos): better error messages for non-hashable arguments.
+def named_defun(func, name):
+ """Defines a function with a given name.
+
+ See the documentation for `defun` for more information on the semantics of the
+ function.
+
+ Args:
+ func: the function to be wrapped.
+ name: the name given to it.
+
+ Returns:
+ the wrapped function.
+ """
+ arguments_to_functions = {}
+
+ def decorated(*args, **kwds):
+ """Decorated version of func."""
+ # Macroexpand on non-Tensor arguments
+ cache_key = tuple(_cache_key(x) for x in args)
+ assert all(not isinstance(x, tensor.Tensor) for x in kwds.values())
+ cache_key = (cache_key, tuple(kwds.items()))
+
+ if cache_key not in arguments_to_functions:
+ arguments_to_functions[cache_key] = _defun_internal(
+ name, func, args, kwds)
+ return arguments_to_functions[cache_key](*args)
+
+ return decorated
+
+
+def defun(func):
+ """Decorator to compile func into graph_mode.
+
+ defun converts a function that constructs a TensorFlow graph into a function
+  that executes the graph. TensorFlow graphs typically execute faster and with
+  a lower memory footprint than executing each of the operations that make up
+  the function individually, as the TensorFlow runtime can optimize the graph
+  and execute sub-operations in parallel.
+
+ func must be a Python function that constructs a TensorFlow graph,
+ typically using functions in the tensorflow module.
+
+ Arguments to func can be either tfe.Tensor objects or Python
+ objects. Non-Tensor python objects are treated as constants, and new function
+ definitions are created internally based on their values.
+
+  func must return a tf.Tensor (NOT a tfe.Tensor) or a list of tf.Tensors (NOT
+  tfe.Tensors). TODO(apassos) make the wrapped tfe ops return tf.Tensors when in
+ graph mode.
+
+ TODO(apassos): deal with captured global state. Deal with control flow.
+
+ Args:
+ func: function to be compiled.
+
+ Returns:
+ A callable that will execute the compiled function (and return zero
+ or more tfe.Tensor objects)
+ """
+ return named_defun(func, func.__name__)
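+
+
+# Example usage (illustrative only; assumes eager execution is enabled and the
+# tfe `tensor` module shown above is importable):
+#
+#   @defun
+#   def double(x):
+#     return x * 2.0
+#
+#   # The first call traces a graph for inputs of this dtype and shape; later
+#   # calls with matching dtype/shape reuse the cached graph function.
+#   y = double(tensor.Tensor([[2.0, 0.0], [0.0, 2.0]]))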
diff --git a/tensorflow/python/eager/gen_op.bzl b/tensorflow/python/eager/gen_op.bzl
new file mode 100644
index 0000000000..f9f6456bba
--- /dev/null
+++ b/tensorflow/python/eager/gen_op.bzl
@@ -0,0 +1,45 @@
+"""For eager-mode Python."""
+
+load("//tensorflow:tensorflow.bzl", "clean_dep", "tf_copts")
+
+def tfe_gen_op_wrapper_py(name,
+ out=None,
+ visibility=None,
+ deps=[],
+ generated_target_name=None):
+ """Generate an eager-mode Python op wrapper for an op library."""
+ # Construct a cc_binary containing the specified ops.
+ tool_name = "gen_" + name + "_py_wrappers_cc"
+ if not deps:
+ deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
+ native.cc_binary(
+ name=tool_name,
+ linkopts=["-lm"],
+ copts=tf_copts(),
+ linkstatic=1,
+ deps=([
+ clean_dep("//tensorflow/python/eager:python_eager_op_gen_main")
+ ] + deps),
+ visibility=[clean_dep("//visibility:public")],)
+
+ # Invoke the previous cc_binary to generate a python file.
+ if not out:
+ out = "gen_" + name + ".py"
+
+ native.genrule(
+ name=name + "_pygenrule",
+ outs=[out],
+ tools=[tool_name],
+ cmd=("$(location " + tool_name + ") > $@"))
+
+ # Make a py_library out of the generated python file.
+ if not generated_target_name:
+ generated_target_name = name
+ native.py_library(
+ name=generated_target_name,
+ srcs=[out],
+ srcs_version="PY2AND3",
+ visibility=visibility,
+ deps=[
+ clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"),
+ ],)
diff --git a/tensorflow/python/eager/graph_only_ops.py b/tensorflow/python/eager/graph_only_ops.py
new file mode 100644
index 0000000000..bd7d08faed
--- /dev/null
+++ b/tensorflow/python/eager/graph_only_ops.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Graph-only versions of a few op functions, for internal use only."""
+
+# Must be separate from array_ops to avoid a cyclic dependency.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+
+
+def graph_zeros_like(tensor):
+ """Graph-only version of tf.zeros_like(), for internal use only."""
+ g = ops._get_graph_from_inputs([tensor]) # pylint: disable=protected-access
+ with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ dtype = tensor.dtype.base_dtype.as_datatype_enum
+ dtype_value = attr_value_pb2.AttrValue(type=dtype)
+ op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
+ attrs={"T": dtype_value}, name=name)
+ result, = op.outputs
+ return result
+
+
+def graph_placeholder(dtype, shape, name=None):
+ """Graph-only version of tf.placeholder(), for internal use only."""
+ dtype = dtype.base_dtype.as_datatype_enum
+ dtype_value = attr_value_pb2.AttrValue(type=dtype)
+ shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
+ g = ops.get_default_graph()
+ with ops.name_scope(name, "placeholder", []) as name:
+ op = g.create_op("Placeholder", [], [dtype], input_types=[],
+ attrs={"dtype": dtype_value, "shape": shape}, name=name)
+ result, = op.outputs
+ return result
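+
+
+# Example (illustrative only; assumes the usual dtypes/tensor_shape imports):
+#
+#   g = ops.Graph()
+#   with g.as_default():
+#     x = graph_placeholder(dtypes.float32, tensor_shape.TensorShape([2, 2]))
+#     z = graph_zeros_like(x)
+#
+# Neither call runs any computation; both only add nodes to `g`.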
diff --git a/tensorflow/python/eager/memory_trace.py b/tensorflow/python/eager/memory_trace.py
new file mode 100644
index 0000000000..0baf922408
--- /dev/null
+++ b/tensorflow/python/eager/memory_trace.py
@@ -0,0 +1,88 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to trace per-device memory consumption across time over execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+TraceEntry = collections.namedtuple(
+ "TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"])
+TensorData = collections.namedtuple(
+ "TensorData", ["op_name", "tensor_size", "device"])
+
+
+class MemoryTrace(object):
+ """Records a trace of memory usage over operation execution."""
+
+  def __init__(self, n_devices):
+    self.trace = []
+ self.tensor_to_data = {}
+ self.current_device_mem_usage = [0] * n_devices
+
+ def record_tensor(self, op_name, tensor_id, device, size):
+ self.current_device_mem_usage[device] += size
+ self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
+ self.trace.append(TraceEntry(op_name,
+ tensor_id,
+ self.current_device_mem_usage[:],
+ device,
+ size))
+
+ def delete_tensor(self, tensor_id):
+ if tensor_id not in self.tensor_to_data:
+ return
+ data = self.tensor_to_data.pop(tensor_id)
+ self.current_device_mem_usage[data.device] -= data.tensor_size
+ self.trace.append(TraceEntry(data.op_name,
+ tensor_id,
+ self.current_device_mem_usage[:],
+ data.device,
+ -data.tensor_size))
+
+ def flush_trace(self):
+ """Prints the formatted trace recorded so far."""
+ longest_op_name = max(len(t.op_name) for t in self.trace)
+ longest_op_name = max(longest_op_name, len("op_name"))
+ longest_heap_size = max(max(len(str(d)) for d in t.mem_usage)
+ for t in self.trace)
+ longest_heap_size = max(longest_heap_size, len("d0"))
+ longest_id_len = max(len(str(t.tensor_id)) for t in self.trace)
+ longest_id_len = max(longest_id_len, 2)
+ first_line = []
+ first_line.append("+/-")
+ first_line.append("op_name".ljust(longest_op_name))
+ first_line.append("id".ljust(longest_id_len))
+ for i in range(len(self.current_device_mem_usage)):
+ first_line.append(("d"+str(i)).ljust(longest_heap_size))
+ first_line.append("size")
+ print(" | ".join(first_line))
+ for t in self.trace:
+ line = []
+ if t.size > 0:
+ line.append("+ ")
+ else:
+ line.append("- ")
+ line.append(t.op_name.ljust(longest_op_name))
+ line.append(str(t.tensor_id).ljust(longest_id_len))
+ for d in t.mem_usage:
+ line.append(str(d).ljust(longest_heap_size))
+ line.append(str(t.size))
+ print(" | ".join(line))
+ self.trace = []
+ print()
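+
+
+# Example usage (illustrative): record allocations and a deletion for a single
+# device, then print the accumulated table:
+#
+#   trace = MemoryTrace(n_devices=1)
+#   trace.record_tensor("MatMul", tensor_id=1, device=0, size=4096)
+#   trace.record_tensor("Add", tensor_id=2, device=0, size=1024)
+#   trace.delete_tensor(1)
+#   trace.flush_trace()  # prints one "+"/"-" row per recorded event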
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
new file mode 100644
index 0000000000..493f549b35
--- /dev/null
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -0,0 +1,763 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/eager/python_eager_op_gen.h"
+
+#include <stdio.h>
+#include <sstream>
+#include <unordered_map>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/framework/python_op_gen_internal.h"
+
+namespace tensorflow {
+namespace {
+
+const int kRightMargin = 78;
+
+string AttrVarName(const string& attr_name,
+ std::unordered_map<string, string>* attr_expressions) {
+ const string var = strings::StrCat("_attr_", attr_name);
+ if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
+ return var;
+}
+
+void AddInferredAttr(const string& attr_name, const string& value_expression,
+ string* result,
+ std::unordered_map<string, string>* attr_expressions) {
+ strings::StrAppend(result, " ", AttrVarName(attr_name, attr_expressions),
+ " = ", value_expression, "\n");
+}
+
+string VectorToTuple(const std::vector<string>& l) {
+ if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
+ string ret = "(";
+ for (int i = 0; i < l.size(); ++i) {
+ if (i > 0) {
+ strings::StrAppend(&ret, ", ");
+ }
+ strings::StrAppend(&ret, l[i]);
+ }
+ strings::StrAppend(&ret, ")");
+ return ret;
+}
+
+void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
+ const string& var, string* result) {
+ for (int i = 0; i < output_sizes.size(); ++i) {
+ if (!output_sizes[i].empty()) {
+ strings::StrAppend(result, prefix, var, " = ");
+ if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
+ if (i + 1 < output_sizes.size()) {
+ // Special case i == 0 to avoid "0 +" in the generated code.
+ if (i == 0) {
+ strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
+ var, "[", output_sizes[i], ":]");
+ } else {
+ strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
+ output_sizes[i], "]] + ", var, "[", i, " + ",
+ output_sizes[i], ":]");
+ }
+ } else {
+ strings::StrAppend(result, "[", var, "[", i, ":]]");
+ }
+ strings::StrAppend(result, "\n");
+ }
+ }
+}
+
+string TensorPBString(const TensorProto& pb) {
+ // Note: This gets used in the argument list, and so must survive naive
+ // word wrapping.
+ return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
+}
+
+class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
+ public:
+ GenEagerPythonOp(const OpDef& op_def, const string& function_name)
+ : python_op_gen_internal::GenPythonOp(op_def, function_name) {
+ op_name_ = function_name_;
+ op_name_.Consume("_");
+ }
+ ~GenEagerPythonOp() override {}
+
+ string Code() override;
+
+ protected:
+ void ExpectListArg(const string& arg_name);
+ void AddEagerInferredAttrs();
+ void AddEagerInputCasts();
+ void AddEagerAttrs();
+ void AddEagerExecute(const string& num_outputs_expr);
+
+ void AddAttrForArg(const string& attr, int arg_index) {
+ gtl::InsertIfNotPresent(&inferred_attrs_, attr,
+ op_def_.input_arg(arg_index).name());
+ auto iter = attr_to_args_.find(attr);
+ if (iter == attr_to_args_.end()) {
+ attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
+ } else {
+ iter->second.push_back(arg_index);
+ }
+ }
+
+ // Returns a string expression representing a flattened list of all
+ // the inputs given by `*input_indices` (or all inputs if
+ // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
+ string FlattenInputs(const std::vector<int>* input_indices,
+ std::vector<string>* output_sizes) const;
+
+ StringPiece op_name_;
+ typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
+ AttrToArgMap attr_to_args_;
+ std::unordered_map<string, string> attr_expressions_;
+};
+
+string GetEagerPythonOp(const OpDef& op_def, const string& function_name) {
+ return GenEagerPythonOp(op_def, function_name).Code();
+}
+
+string GenEagerPythonOp::FlattenInputs(
+ const std::vector<int>* input_indices,
+ std::vector<string>* output_sizes) const {
+ string inputs;
+ enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
+ const int n = input_indices != nullptr ? input_indices->size()
+ : op_def_.input_arg_size();
+ for (int j = 0; j < n; ++j) {
+ const int i = input_indices ? (*input_indices)[j] : j;
+ const auto& arg(op_def_.input_arg(i));
+ const bool is_list =
+ !arg.type_list_attr().empty() || !arg.number_attr().empty();
+ if (is_list) {
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, "] + ");
+ } else if (inputs_state == WAS_LIST_INPUT) {
+ strings::StrAppend(&inputs, " + ");
+ }
+ strings::StrAppend(&inputs, "list(", param_names_[i], ")");
+ inputs_state = WAS_LIST_INPUT;
+ if (output_sizes != nullptr) {
+ if (!arg.number_attr().empty()) {
+ output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
+ } else {
+ output_sizes->emplace_back(
+ strings::StrCat("len(", param_names_[i], ")"));
+ }
+ }
+ } else {
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, ", ");
+ } else if (inputs_state == WAS_LIST_INPUT) {
+ strings::StrAppend(&inputs, " + [");
+ } else {
+ strings::StrAppend(&inputs, "[");
+ }
+ strings::StrAppend(&inputs, param_names_[i]);
+ inputs_state = WAS_SOLO_INPUT;
+ if (output_sizes != nullptr) output_sizes->emplace_back();
+ }
+ }
+ if (inputs_state == STARTING) return "[]";
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, "]");
+ }
+ return inputs;
+}
+
+string GenEagerPythonOp::Code() {
+ // This has all the input args followed by those attrs that don't have
+ // defaults.
+ std::vector<string> args_no_default;
+ // The parameters with defaults (these have to be listed after those without).
+ // No input args are included, just attrs.
+ std::vector<std::pair<string, string>> args_with_defaults;
+ for (int i = 0; i < op_def_.input_arg_size(); ++i) {
+ const auto& arg(op_def_.input_arg(i));
+ args_no_default.push_back(arg.name());
+ if (!arg.type_attr().empty()) {
+ AddAttrForArg(arg.type_attr(), i);
+ } else if (!arg.type_list_attr().empty()) {
+ AddAttrForArg(arg.type_list_attr(), i);
+ }
+ if (!arg.number_attr().empty()) {
+ AddAttrForArg(arg.number_attr(), i);
+ }
+ }
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ // Do not add inferred attrs to the Python function signature.
+ if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
+ if (attr.has_default_value()) {
+ if (attr.type() == "tensor") {
+ args_with_defaults.emplace_back(
+ attr.name(),
+ strings::StrCat("_execute.make_tensor(",
+ TensorPBString(attr.default_value().tensor()),
+ ", \"", attr.name(), "\")"));
+ } else if (attr.type() == "list(tensor)") {
+ std::vector<string> pbtxt;
+ for (const auto& pb : attr.default_value().list().tensor()) {
+ pbtxt.emplace_back(TensorPBString(pb));
+ }
+ args_with_defaults.emplace_back(
+ attr.name(),
+ strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
+ "\") for _pb in ", VectorToTuple(pbtxt), "]"));
+ } else {
+ args_with_defaults.emplace_back(
+ attr.name(), python_op_gen_internal::AttrValueToPython(
+ attr.type(), attr.default_value(), "_dtypes."));
+ }
+ } else {
+ args_no_default.push_back(attr.name());
+ }
+ }
+ }
+
+ // Save the list of attr parameters (attrs that won't be inferred),
+ // those with defaults go at the end.
+ // Get the attrs in the order we want by taking the attrs without defaults
+ // from the end of args_no_default, and adding args_no_default.
+ attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
+ args_with_defaults.size());
+ attrs_.insert(attrs_.end(),
+ args_no_default.begin() + op_def_.input_arg_size(),
+ args_no_default.end());
+ for (const auto& a : args_with_defaults) {
+ attrs_.push_back(a.first);
+ }
+
+ param_names_.reserve(args_no_default.size() + args_with_defaults.size());
+ string parameters;
+ for (const string& name : args_no_default) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param = python_op_gen_internal::AvoidPythonReserved(name);
+ strings::StrAppend(&parameters, param);
+ param_names_.push_back(param);
+ }
+ for (const auto& name_default : args_with_defaults) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param =
+ python_op_gen_internal::AvoidPythonReserved(name_default.first);
+ strings::StrAppend(&parameters, param, "=", name_default.second);
+ param_names_.push_back(param);
+ }
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ strings::StrAppend(&parameters, "name=None");
+
+ AddDefLine(parameters);
+ AddDocStringDescription();
+ AddDocStringArgs();
+ AddDocStringInputs();
+ AddDocStringAttrs();
+ strings::StrAppend(
+ &result_,
+ " name: A name for the operation (optional, only for graph mode).\n");
+ AddOutputGlobals();
+ AddDocStringOutputs();
+ strings::StrAppend(&result_, " \"\"\"\n");
+
+ // Function body.
+
+ // Validate list inputs, infer length attrs.
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ if (attr.type() == "int") {
+ auto arg_list = attr_to_args_.find(attr.name());
+ if (arg_list != attr_to_args_.end()) {
+ // Inferred int attrs are the lengths of inputs. Validate those
+ // inputs are lists and have the same length.
+ for (auto iter = arg_list->second.begin();
+ iter != arg_list->second.end(); ++iter) {
+ const string& arg_name = param_names_[*iter];
+ ExpectListArg(arg_name);
+ if (iter == arg_list->second.begin()) {
+ AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
+ &result_, &attr_expressions_);
+ } else {
+ const auto& attr_var = attr_expressions_[attr.name()];
+ strings::StrAppend(&result_, " if len(", arg_name,
+ ") != ", attr_var,
+ ":\n"
+ " raise ValueError(\n"
+ " \"List argument '",
+ arg_name, "' to '", op_name_,
+ "' Op with length %d \"\n"
+ " \"must match length %d of argument '",
+ inferred_attrs_[attr.name()],
+ "'.\" %\n"
+ " (len(",
+ arg_name, "), ", attr_var, "))\n");
+ }
+ }
+ }
+ }
+ }
+
+ // Values for non-inferred attrs.
+ for (int i = 0; i < attrs_.size(); ++i) {
+ const string& attr_name = attrs_[i];
+ const string& param = param_names_[i + op_def_.input_arg_size()];
+ const auto& attr = *FindAttr(attr_name, op_def_);
+ StringPiece attr_type = attr.type();
+ attr_expressions_[attr_name] = param;
+ const int default_index = i - (attrs_.size() - args_with_defaults.size());
+ if (default_index >= 0) {
+ const string& default_value = args_with_defaults[default_index].second;
+ strings::StrAppend(&result_, " if ", param, " is None:\n");
+ strings::StrAppend(&result_, " ", param, " = ", default_value, "\n");
+ }
+ if (attr_type.starts_with("list(")) {
+ ExpectListArg(param);
+ }
+
+ if (attr_type == "string") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param,
+ ", \"", param, "\")\n");
+ } else if (attr_type == "list(string)") {
+ strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"",
+ param, "\") for _s in ", param, "]\n");
+ } else if (attr_type == "int") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param,
+ ", \"", param, "\")\n");
+ } else if (attr_type == "list(int)") {
+ strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"",
+ param, "\") for _i in ", param, "]\n");
+ } else if (attr_type == "float") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_float(",
+ param, ", \"", param, "\")\n");
+ } else if (attr_type == "list(float)") {
+ strings::StrAppend(&result_, " ", param,
+ " = [_execute.make_float(_f, \"", param,
+ "\") for _f in ", param, "]\n");
+ } else if (attr_type == "bool") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param,
+ ", \"", param, "\")\n");
+ } else if (attr_type == "list(bool)") {
+ strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"",
+ param, "\") for _b in ", param, "]\n");
+ } else if (attr_type == "type") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param,
+ ", \"", param, "\")\n");
+ } else if (attr_type == "list(type)") {
+ strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"",
+ param, "\") for _t in ", param, "]\n");
+ } else if (attr_type == "shape") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_shape(",
+ param, ", \"", param, "\")\n");
+ } else if (attr_type == "list(shape)") {
+ strings::StrAppend(&result_, " ", param,
+ " = [_execute.make_shape(_s, \"", param,
+ "\") for _s in ", param, "]\n");
+ } else if (attr_type == "tensor") {
+ strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(",
+ param, ", \"", param, "\")\n");
+ } else if (attr_type == "list(tensor)") {
+ strings::StrAppend(&result_, " ", param,
+ " = [_execute.make_tensor(_t, \"", param,
+ "\") for _t in ", param, "]\n");
+ } else if (attr_type != "func") {
+ return strings::StrCat("# No definition for ", function_name_,
+ " since we don't support attrs with type\n"
+ "# '",
+ attr_type, "' right now.\n\n");
+ }
+ }
+
+ // Figure out the list of inputs.
+ const string inputs = FlattenInputs(nullptr, nullptr);
+
+ // Handle graph-mode case
+ strings::StrAppend(&result_,
+ " if _context.in_graph_mode():\n"
+ " _, _, _op = _op_def_lib._apply_op_helper(\n");
+ AddBodyNoReturn(" ");
+ if (num_outs_ > 0) {
+ strings::StrAppend(&result_, " _result = _op.outputs[:]\n");
+ // Special case handling for stateful op with single list output
+ // that might be empty.
+ if (num_outs_ == 1 && op_def_.is_stateful() &&
+ (!op_def_.output_arg(0).number_attr().empty() ||
+ !op_def_.output_arg(0).type_list_attr().empty())) {
+ // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
+ // a constraint indicating that this can never be empty.
+ strings::StrAppend(&result_,
+ " if not _result:\n"
+ " return _op\n");
+ }
+ strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
+
+ // Compute graph-mode attrs.
+ if (op_def_.attr_size() > 0) {
+ string attr_values;
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ if (i > 0) strings::StrAppend(&attr_values, ", ");
+ const auto& attr_name(op_def_.attr(i).name());
+ strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
+ attr_name, "\")");
+ }
+ strings::StrAppend(&attr_values, ")");
+ strings::StrAppend(&result_,
+ WordWrap(" _attrs = (", attr_values, kRightMargin),
+ "\n");
+ } else {
+ strings::StrAppend(&result_, " _attrs = None\n");
+ }
+ } else {
+ strings::StrAppend(&result_, " return _op\n");
+ }
+
+ // Handle eager-mode case
+ strings::StrAppend(&result_, " else:\n");
+
+ // Expression representing the number of outputs.
+ int num_fixed_outputs = 0;
+ string num_outputs_expr;
+ // If output i is list output, output_sizes[i] will be set to a
+ // string with the python expression that will evaluate to its
+ // length. output_sizes[i] is empty for non-list outputs.
+ std::vector<string> output_sizes(num_outs_);
+ for (int i = 0; i < num_outs_; ++i) {
+ const auto& arg(op_def_.output_arg(i));
+ if (!arg.number_attr().empty()) {
+ if (!num_outputs_expr.empty()) {
+ strings::StrAppend(&num_outputs_expr, " + ");
+ }
+ output_sizes[i] = attr_expressions_[arg.number_attr()];
+ strings::StrAppend(&num_outputs_expr, output_sizes[i]);
+ } else if (!arg.type_list_attr().empty()) {
+ if (!num_outputs_expr.empty()) {
+ strings::StrAppend(&num_outputs_expr, " + ");
+ }
+ // Have to be careful to use an expression that works in both
+ // graph and eager paths here.
+ const auto iter = inferred_attrs_.find(arg.type_list_attr());
+ if (iter == inferred_attrs_.end()) {
+ output_sizes[i] = strings::StrCat(
+ "len(", attr_expressions_[arg.type_list_attr()], ")");
+ } else {
+ output_sizes[i] = strings::StrCat("len(", iter->second, ")");
+ }
+ strings::StrAppend(&num_outputs_expr, output_sizes[i]);
+ } else {
+ ++num_fixed_outputs;
+ }
+ }
+ if (num_fixed_outputs > 0) {
+ if (!num_outputs_expr.empty()) {
+ strings::StrAppend(&num_outputs_expr, " + ");
+ }
+ strings::StrAppend(&num_outputs_expr, num_fixed_outputs);
+ } else if (num_outputs_expr.empty()) {
+ num_outputs_expr = "0";
+ }
+
+ bool eager_allowed = true;
+ for (const auto& arg : op_def_.input_arg()) {
+ if (arg.is_ref()) eager_allowed = false;
+ }
+ for (const auto& arg : op_def_.output_arg()) {
+ if (arg.is_ref()) eager_allowed = false;
+ }
+
+ if (eager_allowed) {
+ AddEagerInferredAttrs();
+ AddEagerInputCasts();
+ strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
+ AddEagerAttrs();
+ AddEagerExecute(num_outputs_expr);
+ } else {
+ strings::StrAppend(&result_,
+ " raise RuntimeError(\n"
+ " \"",
+ op_name_, " op does not support eager execution.\")\n");
+ }
+
+ if (num_outs_ > 0) {
+ strings::StrAppend(&result_, " _result = _execute.record_gradient(\n",
+ " \"", op_def_.name(),
+ "\", _inputs_flat, _attrs, _result, name)\n");
+ if (num_outs_ == 1 && !output_sizes[0].empty()) {
+ // Single list result.
+ } else if (num_outs_ == 1) {
+ // Execute returns a single-element list which we need to destructure.
+ strings::StrAppend(&result_, " _result, = _result\n");
+ } else {
+ // Have multiple outputs, so we will need to reformat the return
+ // value of execute() to be a list with one entry per op output
+ // (that entry will be a list of tensors if that output is of list
+ // type).
+ // For list outputs, convert the right subrange of _result into a list.
+ Unflatten(" ", output_sizes, "_result", &result_);
+ // Convert to a named tuple.
+ strings::StrAppend(&result_, " _result = _", op_def_.name(),
+ "Output._make(_result)\n");
+ }
+ }
+ strings::StrAppend(&result_, " return _result\n\n");
+ return prelude_ + result_;
+}
+
+void GenEagerPythonOp::ExpectListArg(const string& arg_name) {
+ strings::StrAppend(&result_, " if not isinstance(", arg_name,
+ ", (list, tuple)):\n"
+ " raise TypeError(\n"
+ " \"Expected list for '",
+ arg_name,
+ "' argument to \"\n"
+ " \"'",
+ op_name_, "' Op, not %r.\" % ", arg_name, ")\n");
+}
+
+void GenEagerPythonOp::AddEagerInferredAttrs() {
+ // Figure out values for inferred attrs, and cast to eager tensors.
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ auto arg_list = attr_to_args_.find(attr.name());
+ if (arg_list != attr_to_args_.end()) {
+ if (attr.type() == "type") {
+ std::vector<string> output_sizes;
+ const string flattened =
+ FlattenInputs(&arg_list->second, &output_sizes);
+ string conversion =
+ strings::StrCat("_execute.args_to_matching_eager(", flattened);
+ if (attr.has_default_value()) {
+ strings::StrAppend(
+ &conversion, ", ",
+ python_op_gen_internal::AttrValueToPython(
+ attr.type(), attr.default_value(), "_dtypes."));
+ }
+ strings::StrAppend(&conversion, ")");
+ const string var_name = AttrVarName(attr.name(), &attr_expressions_);
+ if (output_sizes.size() == 1) {
+ // Avoid creating a temporary variable in the case where
+ // we can easily assign to the right value directly.
+ const string inputs_var = param_names_[arg_list->second.front()];
+ if (output_sizes.front().empty()) {
+ strings::StrAppend(&result_, " ", var_name, ", (", inputs_var,
+ ",) = ", conversion, "\n");
+ } else {
+ strings::StrAppend(&result_, " ", var_name, ", ", inputs_var,
+ " = ", conversion, "\n");
+ }
+ } else {
+ const string inputs_var = strings::StrCat("_inputs_", attr.name());
+ strings::StrAppend(&result_, " ", var_name, ", ", inputs_var,
+ " = ", conversion, "\n");
+ // Convert from a flat list of eager tensors back to the
+ // parameter variables.
+ Unflatten(" ", output_sizes, inputs_var, &result_);
+ std::vector<string> p;
+ for (int j : arg_list->second) {
+ p.emplace_back(param_names_[j]);
+ }
+ strings::StrAppend(&result_, " ", VectorToTuple(p), " = ",
+ inputs_var, "\n");
+ }
+ strings::StrAppend(&result_, " ", var_name, " = ", var_name,
+ ".as_datatype_enum\n");
+ } else if (attr.type() == "list(type)") {
+        // NOTE: We ignore default values for these attrs, since it is
+        // unclear how they would be used, and the one existing use case,
+        // parse_single_sequence_example, only needs them for
+        // backwards compatibility.
+ const string var_name = AttrVarName(attr.name(), &attr_expressions_);
+ string inputs_var;
+ string conversion;
+ if (arg_list->second.size() > 1) {
+ // If you have more than one list(tensor) argument, their types
+ // have to match.
+ std::vector<string> lists;
+ for (auto iter = arg_list->second.begin();
+ iter != arg_list->second.end(); ++iter) {
+ lists.push_back(param_names_[*iter]);
+ }
+ inputs_var = VectorToTuple(lists);
+ conversion = "_execute.args_to_mixed_eager_tensors";
+ } else {
+ // For one list(tensor) argument, we just convert every
+ // element of the list to an eager tensor.
+ inputs_var = param_names_[arg_list->second.front()];
+ conversion = "_execute.convert_to_mixed_eager_tensors";
+ }
+ strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ",
+ conversion, "(", inputs_var, ")\n");
+ strings::StrAppend(&result_, " ", var_name,
+ " = [_t.as_datatype_enum for _t in ", var_name,
+ "]\n");
+ }
+ }
+ }
+}
+
+void GenEagerPythonOp::AddEagerInputCasts() {
+ // Cast remaining args to eager tensors
+ for (int i = 0; i < op_def_.input_arg_size(); ++i) {
+ const auto& arg(op_def_.input_arg(i));
+ if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
+ const string& param = param_names_[i];
+ const string fn = arg.number_attr().empty() ? "" : "n_";
+ const string dtype =
+ python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
+ strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn,
+ "to_eager_tensor(", param, ", ", dtype, ")\n");
+ }
+}
+
+void GenEagerPythonOp::AddEagerAttrs() {
+ // Compute eager attrs
+ if (op_def_.attr_size() > 0) {
+ string attr_values;
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ if (i > 0) strings::StrAppend(&attr_values, ", ");
+ const auto& attr_name(op_def_.attr(i).name());
+ strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
+ attr_expressions_[attr_name]);
+ }
+ strings::StrAppend(&attr_values, ")");
+ strings::StrAppend(
+ &result_, WordWrap(" _attrs = (", attr_values, kRightMargin), "\n");
+ } else {
+ strings::StrAppend(&result_, " _attrs = None\n");
+ }
+}
+
+void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
+ const string return_prefix = " _result = _execute.execute(";
+ const string return_args =
+ strings::StrCat("\"", op_def_.name(), "\", ", num_outputs_expr,
+ ", inputs=_inputs_flat, attrs=_attrs, name=name)");
+ strings::StrAppend(&result_,
+ // Wrap the arguments, and indent to the (.
+ WordWrap(return_prefix, return_args, kRightMargin), "\n");
+}
+
+string GetEagerPythonOps(const OpList& ops,
+ const std::vector<string>& hidden_ops,
+ bool require_shapes) {
+ string result;
+ // Header
+ // TODO(josh11b): Mention the library for which wrappers are being generated.
+ strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
+
+This file is MACHINE GENERATED! Do not edit.
+"""
+
+import collections as _collections
+
+from tensorflow.python.eager import execute as _execute
+from tensorflow.python.eager import context as _context
+from tensorflow.python.eager import core as _core
+from tensorflow.python.eager import tensor as _tensor
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.framework import tensor_shape as _tensor_shape
+
+from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
+# Needed to trigger the call to _set_call_cpp_shape_fn.
+from tensorflow.python.framework import common_shapes as _common_shapes
+from tensorflow.python.framework import op_def_registry as _op_def_registry
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework import op_def_library as _op_def_library
+
+)");
+
+ // We'll make a copy of ops that filters out descriptions.
+ OpList cleaned_ops;
+ auto out = cleaned_ops.mutable_op();
+ out->Reserve(ops.op_size());
+ for (const auto& op_def : ops.op()) {
+ bool is_hidden = false;
+ for (const string& hidden : hidden_ops) {
+ if (op_def.name() == hidden) {
+ is_hidden = true;
+ break;
+ }
+ }
+
+ string function_name;
+ python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
+ &function_name);
+ if (is_hidden) function_name = strings::StrCat("_", function_name);
+
+    // When users create custom python wrappers, they may link in the
+    // default op registry by accident, and because they can't
+    // enumerate all 'hidden' symbols, this guard prevents
+    // instantiating a Python reserved word in their wrapper.
+ if (python_op_gen_internal::IsPythonReserved(function_name)) {
+ continue;
+ }
+
+ strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name));
+
+ if (!require_shapes) {
+ strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
+ "\")(None)\n\n");
+ }
+
+ auto added = out->Add();
+ *added = op_def;
+ RemoveNonDeprecationDescriptionsFromOpDef(added);
+ }
+
+ result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
+ op_list = _op_def_pb2.OpList()
+ op_list.ParseFromString(op_list_proto_bytes)
+ _op_def_registry.register_op_list(op_list)
+ op_def_lib = _op_def_library.OpDefLibrary()
+ op_def_lib.add_op_list(op_list)
+ return op_def_lib
+)");
+
+ result.append("# ");
+ auto ops_text = ProtoDebugString(cleaned_ops);
+ str_util::StripTrailingWhitespace(&ops_text);
+ result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
+ result.append("\n");
+ strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
+ str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
+ return result;
+}
+
+} // namespace
+
+void PrintEagerPythonOps(const OpList& ops,
+ const std::vector<string>& hidden_ops,
+ bool require_shapes) {
+ printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
+}
+
+string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
+ string op_list_str(op_list_buf, op_list_len);
+ OpList ops;
+ ops.ParseFromString(op_list_str);
+ return GetEagerPythonOps(ops, {}, false);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h
new file mode 100644
index 0000000000..9a7ed28cf9
--- /dev/null
+++ b/tensorflow/python/eager/python_eager_op_gen.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
+#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// hidden_ops should be a list of Op names that should get a leading _
+// in the output. Prints the output to stdout.
+void PrintEagerPythonOps(const OpList& ops,
+ const std::vector<string>& hidden_ops,
+ bool require_shapes);
+
+// Get the python wrappers for a list of ops in a OpList.
+// `op_list_buf` should be a pointer to a buffer containing
+// the binary encoded OpList proto, and `op_list_len` should be the
+// length of that buffer.
+string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc
new file mode 100644
index 0000000000..9e4aa97ccc
--- /dev/null
+++ b/tensorflow/python/eager/python_eager_op_gen_main.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/eager/python_eager_op_gen.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/init_main.h"
+
+namespace tensorflow {
+namespace {
+
+void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
+ OpList ops;
+ OpRegistry::Global()->Export(false, &ops);
+ PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */);
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ if (argc == 1) {
+ tensorflow::PrintAllPythonOps({});
+ } else {
+ return -1;
+ }
+ return 0;
+}
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
new file mode 100644
index 0000000000..bd7f445055
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
+#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include <Python.h>
+
+typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
+ TFE_InputTensorHandles;
+typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
+ TFE_OutputTensorHandles;
+
+// Execute a TensorFlow operation.
+//
+// 'device_name': Name of the device on which to execute the operation, or NULL
+// for automatic selection.
+// 'op_name': Name of the TensorFlow op to execute.
+// 'inputs': An array of TFE_TensorHandle*'s of size 'num_inputs'. These tensors
+// will be provided as input to the operation.
+// 'attrs': A Python tuple alternating names and attr values.
+// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
+//            placed. On success, its elements will be filled in and the
+// caller takes ownership of each returned TFE_TensorHandle.
+// 'outputs' MUST be sized to be at least as large as the number
+// of tensors produced by the operation and will be resized to
+// the actual number of tensors produced.
+void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
+ const char* op_name, TFE_InputTensorHandles* inputs,
+ PyObject* attrs, TFE_OutputTensorHandles* outputs,
+ TF_Status* out_status);
+
+// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
+//
+// The two may share underlying storage, so changes to one may be reflected in
+// the other.
+PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status);
+
+// Convert a Python numpy.ndarray object to a TFE_TensorHandle.
+//
+// The two may share underlying storage, so changes to one may be reflected in
+// the other.
+TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj);
+
+// Registers e as the Exception class for handling a non-OK Status. Returns
+// Py_None if registration succeeds, else raises a TypeError and returns NULL.
+PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
+
+// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using the
+// class registered via TFE_Py_RegisterExceptionClass) and returns -1.
+int TFE_Py_MayBeRaiseException(TF_Status* status);
+
+#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
new file mode 100644
index 0000000000..507fdfb35d
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -0,0 +1,377 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Must be included first.
+#include "tensorflow/python/lib/core/numpy.h"
+
+#include "tensorflow/python/eager/pywrap_tfe.h"
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/python/lib/core/py_func.h"
+
+using tensorflow::string;
+
+namespace {
+
+#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
+ bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
+ type* value) { \
+ if (check_fn(py_value)) { \
+ *value = static_cast<type>(parse_fn(py_value)); \
+ return true; \
+ } else { \
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, \
+ tensorflow::strings::StrCat( \
+ "Expecting " #type " value for attr ", key, ", got ", \
+ py_value->ob_type->tp_name) \
+ .c_str()); \
+ return false; \
+ } \
+ }
+
+#if PY_MAJOR_VERSION >= 3
+PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
+PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
+PARSE_VALUE(ParseStringValue, const char*, PyUnicode_Check, PyUnicode_AsUTF8)
+#else
+PARSE_VALUE(ParseStringValue, const char*, PyString_Check, PyString_AsString)
+PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
+PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
+#endif
+PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
+
+#undef PARSE_VALUE
+
+bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
+ unsigned char* value) {
+ *value = PyObject_IsTrue(py_value);
+ return true;
+}
+
+const char* ParseProtoValue(const string& key, const char* proto_name,
+ PyObject* py_value, size_t* size,
+ TF_Status* status) {
+ char* output = nullptr;
+ Py_ssize_t py_size;
+#if PY_MAJOR_VERSION >= 3
+ if (!PyUnicode_Check(py_value) ||
+ (output = PyUnicode_AsUTF8AndSize(py_value, &py_size)) == nullptr) {
+#else
+ if (!PyString_Check(py_value) ||
+ (PyString_AsStringAndSize(py_value, &output, &py_size) < 0)) {
+#endif
+ TF_SetStatus(
+ status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat("Expecting a string (serialized ",
+ proto_name, ") value for attr ", key)
+ .c_str());
+ return nullptr;
+ }
+ *size = static_cast<size_t>(py_size);
+ return output;
+}
+
+bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
+ TF_AttrType type, TF_Status* status) {
+ if (!PySequence_Check(py_list)) {
+ TF_SetStatus(
+ status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
+ ", got ", py_list->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+ const int num_values = PySequence_Size(py_list);
+
+#define PARSE_LIST(c_type, parse_fn) \
+ std::unique_ptr<c_type[]> values(new c_type[num_values]); \
+ for (int i = 0; i < num_values; ++i) { \
+ auto py_value = PySequence_ITEM(py_list, i); \
+ if (!parse_fn(key, py_value, status, &values[i])) return false; \
+ }
+
+ if (type == TF_ATTR_STRING) {
+ PARSE_LIST(const char*, ParseStringValue);
+ TFE_OpSetAttrStringList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_INT) {
+ PARSE_LIST(int64_t, ParseInt64Value);
+ TFE_OpSetAttrIntList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_FLOAT) {
+ PARSE_LIST(float, ParseFloatValue);
+ TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_BOOL) {
+ PARSE_LIST(unsigned char, ParseBoolValue);
+ TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_TYPE) {
+ PARSE_LIST(int, ParseIntValue);
+ TFE_OpSetAttrTypeList(op, key,
+ reinterpret_cast<const TF_DataType*>(values.get()),
+ num_values);
+ } else if (type == TF_ATTR_SHAPE) {
+ // Make one pass through the input counting the total number of
+ // dims across all the input lists.
+ int total_dims = 0;
+ for (int i = 0; i < num_values; ++i) {
+ auto py_value = PySequence_ITEM(py_list, i);
+ if (py_value != Py_None) {
+ if (!PySequence_Check(py_value)) {
+ TF_SetStatus(
+ status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+                  "Expecting None or sequence value for element ", i,
+ " of attr ", key, ", got ", py_value->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+ const auto size = PySequence_Size(py_value);
+ total_dims += size;
+ }
+ }
+ // Allocate a buffer that can fit all of the dims together.
+ std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
+ // Copy the input dims into the buffer and set dims to point to
+ // the start of each list's dims.
+ std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
+ std::unique_ptr<int[]> num_dims(new int[num_values]);
+ int64_t* offset = buffer.get();
+ for (int i = 0; i < num_values; ++i) {
+ auto py_value = PySequence_ITEM(py_list, i);
+ if (py_value == Py_None) {
+ dims[i] = nullptr;
+ num_dims[i] = -1;
+ } else {
+ const auto size = PySequence_Size(py_value);
+ dims[i] = offset;
+ num_dims[i] = size;
+ for (int j = 0; j < size; ++j) {
+ auto inner_py_value = PySequence_ITEM(py_value, j);
+ if (inner_py_value == Py_None) {
+ *offset = -1;
+ } else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
+ return false;
+ }
+ ++offset;
+ }
+ }
+ }
+ TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
+ status);
+ if (TF_GetCode(status) != TF_OK) return false;
+ } else {
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ tensorflow::strings::StrCat("Attr ", key,
+ " has unhandled list type ", type)
+ .c_str());
+ return false;
+ }
+#undef PARSE_LIST
+ return true;
+}
+
+bool SetOpAttrScalar(TFE_Op* op, const char* key, PyObject* py_value,
+ TF_AttrType type, TF_Status* status) {
+ if (type == TF_ATTR_STRING) {
+ const char* value;
+ if (!ParseStringValue(key, py_value, status, &value)) return false;
+ TFE_OpSetAttrString(op, key, value);
+ } else if (type == TF_ATTR_INT) {
+ int64_t value;
+ if (!ParseInt64Value(key, py_value, status, &value)) return false;
+ TFE_OpSetAttrInt(op, key, value);
+ } else if (type == TF_ATTR_FLOAT) {
+ float value;
+ if (!ParseFloatValue(key, py_value, status, &value)) return false;
+ TFE_OpSetAttrFloat(op, key, value);
+ } else if (type == TF_ATTR_BOOL) {
+ unsigned char value;
+ if (!ParseBoolValue(key, py_value, status, &value)) return false;
+ TFE_OpSetAttrBool(op, key, value);
+ } else if (type == TF_ATTR_TYPE) {
+ int value;
+ if (!ParseIntValue(key, py_value, status, &value)) return false;
+ TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
+ } else if (type == TF_ATTR_SHAPE) {
+ if (py_value == Py_None) {
+ TFE_OpSetAttrShape(op, key, nullptr, -1, status);
+ } else {
+ if (!PySequence_Check(py_value)) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+                         "Expecting None or sequence value for attr ", key,
+ ", got ", py_value->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+ const auto num_dims = PySequence_Size(py_value);
+ std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
+ for (int i = 0; i < num_dims; ++i) {
+ auto inner_py_value = PySequence_ITEM(py_value, i);
+ if (inner_py_value == Py_None) {
+ dims[i] = -1;
+ } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) {
+ return false;
+ }
+ }
+ TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
+ }
+ if (TF_GetCode(status) != TF_OK) return false;
+ } else {
+ TF_SetStatus(
+ status, TF_UNIMPLEMENTED,
+ tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
+ .c_str());
+ return false;
+ }
+ return true;
+}
+
+void SetOpAttrs(TFE_Op* op, PyObject* attrs, TF_Status* out_status) {
+ if (attrs == Py_None) return;
+ if (!PyTuple_Check(attrs)) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting an attrs tuple.");
+ return;
+ }
+ Py_ssize_t len = PyTuple_GET_SIZE(attrs);
+ if ((len & 1) != 0) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+ "Expecting attrs tuple to have even length.");
+ return;
+ }
+ // Parse attrs
+ for (Py_ssize_t i = 0; i < len; i += 2) {
+ PyObject* py_key = PyTuple_GET_ITEM(attrs, i);
+ PyObject* py_value = PyTuple_GET_ITEM(attrs, i + 1);
+#if PY_MAJOR_VERSION >= 3
+ const char* key = PyUnicode_AsUTF8(py_key);
+#else
+ const char* key = PyString_AsString(py_key);
+#endif
+ unsigned char is_list = 0;
+ const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
+ if (TF_GetCode(out_status) != TF_OK) return;
+ if (is_list != 0) {
+ if (!SetOpAttrList(op, key, py_value, type, out_status)) return;
+ } else {
+ if (!SetOpAttrScalar(op, key, py_value, type, out_status)) return;
+ }
+ }
+}
+} // namespace
+
+void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
+ const char* op_name, TFE_InputTensorHandles* inputs,
+ PyObject* attrs, TFE_OutputTensorHandles* outputs,
+ TF_Status* out_status) {
+ TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
+ if (TF_GetCode(out_status) != TF_OK) return;
+ if (device_name != nullptr) {
+ TFE_OpSetDevice(op, ctx, device_name, out_status);
+ }
+ if (TF_GetCode(out_status) == TF_OK) {
+ for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
+ ++i) {
+ TFE_OpAddInput(op, inputs->at(i), out_status);
+ }
+ }
+ if (TF_GetCode(out_status) == TF_OK) {
+ SetOpAttrs(op, attrs, out_status);
+ }
+ if (TF_GetCode(out_status) == TF_OK) {
+ int num_outputs = outputs->size();
+ TFE_Execute(op, outputs->data(), &num_outputs, out_status);
+ outputs->resize(num_outputs);
+ }
+ if (TF_GetCode(out_status) != TF_OK) {
+ TF_SetStatus(out_status, TF_GetCode(out_status),
+ tensorflow::strings::StrCat(TF_Message(out_status),
+ " [Op:", op_name, "]")
+ .c_str());
+ }
+ TFE_DeleteOp(op);
+}
+
+PyObject* TFE_Py_TensorHandleToNumpy(TFE_TensorHandle* h, TF_Status* status) {
+ const tensorflow::Tensor* t =
+ TFE_TensorHandleUnderlyingTensorInHostMemory(h, status);
+ if (TF_GetCode(status) != TF_OK) {
+ Py_RETURN_NONE;
+ }
+ PyObject* ret = nullptr;
+ auto cppstatus = tensorflow::ConvertTensorToNdarray(*t, &ret);
+ if (!cppstatus.ok()) {
+ TF_SetStatus(status, TF_Code(cppstatus.code()),
+ cppstatus.error_message().c_str());
+ }
+ if (ret != nullptr) return ret;
+ Py_RETURN_NONE;
+}
+
+namespace {
+// Python subclass of Exception that is raised on a non-OK Status.
+tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
+PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
+} // namespace
+
+TFE_TensorHandle* TFE_Py_NumpyToTensorHandle(PyObject* obj) {
+ tensorflow::Tensor t;
+ auto cppstatus = tensorflow::ConvertNdarrayToTensor(obj, &t);
+ if (cppstatus.ok()) {
+ return TFE_NewTensorHandle(t);
+ } else {
+ tensorflow::mutex_lock l(exception_class_mutex);
+ auto msg = tensorflow::strings::StrCat(
+ "failed to convert numpy ndarray to a Tensor (",
+ cppstatus.error_message(), ")");
+ if (exception_class != nullptr) {
+ PyErr_SetObject(exception_class,
+ Py_BuildValue("si", msg.c_str(), TF_INVALID_ARGUMENT));
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, msg.c_str());
+ }
+ }
+ return nullptr;
+}
+
+PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
+ tensorflow::mutex_lock l(exception_class_mutex);
+ if (exception_class != nullptr) {
+ Py_DECREF(exception_class);
+ }
+ if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
+ exception_class = nullptr;
+ PyErr_SetString(PyExc_TypeError,
+ "TFE_Py_RegisterExceptionClass: "
+ "Registered class should be subclass of Exception.");
+ return nullptr;
+ } else {
+ Py_INCREF(e);
+ exception_class = e;
+ Py_RETURN_NONE;
+ }
+}
+
+int TFE_Py_MayBeRaiseException(TF_Status* status) {
+ if (TF_GetCode(status) == TF_OK) return 0;
+ tensorflow::mutex_lock l(exception_class_mutex);
+ if (exception_class != nullptr) {
+ PyErr_SetObject(exception_class, Py_BuildValue("si", TF_Message(status),
+ TF_GetCode(status)));
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, TF_Message(status));
+ }
+ return -1;
+}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
new file mode 100644
index 0000000000..1cab4346b0
--- /dev/null
+++ b/tensorflow/python/eager/tape.py
@@ -0,0 +1,240 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradient tape utilites."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from autograd import container_types
+from autograd import core as ag_core
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
+
+
+def tensor_id(t):
+ """Returns a unique identifier for this Tensor."""
+ t = ag_core.getval(t)
+ return t._id # pylint: disable=protected-access
+
+
+class ImplicitTape(object):
+ """Global object which can watch tensors and wrap them with autograd."""
+
+ def __init__(self):
+ self.tensors = {}
+ self.gradients = []
+
+ def __eq__(self, other):
+ return self is other
+
+ def __hash__(self):
+ return id(self)
+
+
+@ag_core.primitive
+def _watch_with_tape_internal(_, tensor):
+  """Primitive that ties a tensor to an ImplicitTape progenitor."""
+ return tensor
+
+
+def _watch_with_tape(tape, tensor):
+ """Wraps a watched Tensor and keeps track of it in the implicit tape."""
+ w = _watch_with_tape_internal(tape, tensor)
+ if ag_core.isnode(tape):
+ tape.value.tensors[tensor_id(tensor)] = w
+ return w
+
+
+def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
+ """Gradient for _watch_with_tape_internal."""
+ del ans, gvs, tape
+
+ def mut_add(implicit_tape):
+ t = ag_core.getval(tensor)
+ implicit_tape.gradients.append((t, g))
+ return implicit_tape
+
+ return ag_core.SparseObject(vs, mut_add)
+
+_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0)
+_watch_with_tape_internal.defvjp(
+ lambda g, ans, vs, gvs, tape, tensor: g,
+ argnum=1)
+
+
+class ImplicitTapeVSpace(ag_core.VSpace):
+ """VSpace needed to have ImplicitTape be a valid progenitor."""
+
+ def zeros(self):
+ return ImplicitTape()
+
+
+class ImplicitTapeNode(ag_core.Node):
+ """Node to wrap ImplicitTape in."""
+
+ def __eq__(self, other):
+ return self is other
+
+ def __hash__(self):
+ return id(self)
+
+ag_core.register_node(ImplicitTapeNode, ImplicitTape)
+ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape)
+
+
+# TODO(apassos) try to not do this.
+class NoneVSpace(ag_core.VSpace):
+ """VSpace for python None."""
+
+ def __init__(self, _):
+ self.size = 0
+
+
+ag_core.register_vspace(NoneVSpace, type(None))
+
+
+class _TapeStack(threading.local):
+
+ def __init__(self):
+ super(_TapeStack, self).__init__()
+ self._stack = []
+
+ @property
+ def stack(self):
+ return self._stack
+
+ @tf_contextlib.contextmanager
+ def replace_stack(self, new_stack):
+ old = self._stack
+ self._stack = new_stack
+ yield
+ self._stack = old
+
+
+# The global tape stack.
+_tape_stack = _TapeStack()
+
+
+def push_new_tape():
+ """Pushes a new tape onto the tape stack."""
+ progenitor = ag_core.new_progenitor(ImplicitTape())
+ _tape_stack.stack.append(progenitor)
+ ag_core.active_progenitors.add(progenitor)
+
+
+def watch(tensor):
+ """Marks this tensor to be watched by all tapes in the stack.
+
+ Args:
+ tensor: tensor to be watched.
+
+ Returns:
+ The tensor, potentially wrapped by all tapes in the stack.
+ """
+ for t in _tape_stack.stack:
+ tensor = _watch_with_tape(t, tensor)
+ return tensor
+
+
+def pop_tape():
+ """Pops the top tape in the stack, if any."""
+ if _tape_stack.stack:
+ return _tape_stack.stack.pop()
+ return None
+
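+# A minimal usage sketch of the tape-stack helpers above (illustration only,
+# not part of this module; the code that consumes the popped tape's gradients
+# lives elsewhere, e.g. in the backprop machinery, and `some_recorded_op` is a
+# hypothetical op whose kernel calls record_operation below):
+#
+#   push_new_tape()           # start recording onto a fresh ImplicitTape
+#   x = watch(x)              # x is now wrapped and tracked by active tapes
+#   y = some_recorded_op(x)   # traced by autograd while tapes are active
+#   t = pop_tape()            # stop recording; t.value is the ImplicitTape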
+
+def any_tape_has(tensor):
+ for t in _tape_stack.stack:
+ if tensor_id(tensor) in t.value.tensors:
+ return True
+ return False
+
+
+def should_record(tensors):
+  """Returns true if any tape in the stack watches any of these tensors."""
+ return any(ag_core.isnode(x) for x in tensors)
+
+
+class _EagerSequenceNode(container_types.SequenceNode):
+ """Eager version of SequenceNode, to live in EagerSequenceVSpace."""
+ pass
+
+
+class _EagerSequenceVSpace(container_types.SequenceVSpace):
+ """Changes equality on SequenceVSpace to conform to tfe requirements."""
+
+ def __init__(self, value):
+ self.shape = [ag_core.vspace(x) for x in value]
+ self.size = sum(s.size for s in self.shape)
+ self.sequence_type = type(value)
+
+ def __eq__(self, other):
+ if type(self) != type(other): # pylint: disable=unidiomatic-typecheck
+ return False
+ if len(self.shape) != len(other.shape):
+ # TODO(apassos) function gradients sometimes return gradients for side
+ # inputs which breaks this assertion. Understand how to fix it.
+ return True
+ for ss, os in zip(self.shape, other.shape):
+ if ss != os:
+ if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace):
+ continue
+ if ss.dtype == dtypes.resource or os.dtype == dtypes.resource:
+ continue
+ return False
+ return True
+
+
+class _EagerList(list):
+ """Type used to bypass SequenceVSpace."""
+
+ def __init__(self, value):
+ super(_EagerList, self).__init__(value)
+ for v in value:
+ assert not ag_core.isnode(v)
+
+ag_core.register_vspace(_EagerSequenceVSpace, _EagerList)
+ag_core.register_node(_EagerSequenceNode, _EagerList)
+
+
+@ag_core.primitive
+def _record_operation(output_tensors, input_tensors, side_outputs,
+ backward_function):
+ del input_tensors, side_outputs, backward_function
+ return _EagerList(output_tensors)
+
+
+def record_operation(o, i, s, b):
+ """Primitive to trigger autograd tracing on outputs from inputs."""
+ inputs = container_types.make_sequence(_EagerList, *i)
+ return _record_operation(o, inputs, s, b)
+
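+# Illustration only (not part of this module): a hypothetical kernel wrapper
+# would pass its output tensors, input tensors, any side outputs needed by the
+# backward pass, and a backward function mapping output gradients (plus those
+# side outputs) to input gradients, e.g.:
+#
+#   outputs = record_operation(
+#       [y], [x], [saved_value],
+#       lambda dy, saved: [dy * saved])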
+
+def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
+ side_outputs, backward_function):
+ """Gradient for _record_operation."""
+ del ans, vs, gvs, output_tensors, input_tensors
+ backward_args = tuple(g) + tuple(side_outputs)
+ if ag_core.isnode(backward_args):
+ backward_args = list(backward_args)
+ tensors = nest.flatten(backward_function(*backward_args))
+ return _EagerList([ag_core.getval(t) for t in tensors])
+
+_record_operation.defvjp(_record_operation_vjp, argnum=1)
diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py
new file mode 100644
index 0000000000..86ac243ae3
--- /dev/null
+++ b/tensorflow/python/eager/tensor.py
@@ -0,0 +1,454 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for TensorFlow's "Eager" mode of execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from autograd import core as ag_core
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import context
+from tensorflow.python.eager import core
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import tensor_shape
+
+
+# TODO(agarwal): rename to TensorHandle.
+class Tensor(tf_ops.Tensor):
+ """A TensorFlow Eager Tensor."""
+
+ def __init__(self, value, dtype=None):
+ """Creates a Tensor object from a Python object or numpy array.
+
+    May share storage with the numpy array, in which case changes to the numpy
+    object will be reflected in the Tensor.
+
+ Arguments:
+ value: A numpy.array or a Python object to create a Tensor for.
+ dtype: TensorFlow dtype for the returned Tensor. If None, one will be
+ automatically selected.
+ """
+ # TODO(ashankar): Evaluate if we can and perhaps share code with
+ # tf.constant defined in
+ # https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
+ self._id = tf_ops.uid()
+ if not isinstance(value, np.ndarray):
+ npt = None if dtype is None else dtype.as_numpy_dtype
+ value = np.array(value, dtype=npt)
+ if dtype is None:
+ value = _maybe_modify_numpy_dtype_determination(value)
+ elif dtype is not None:
+ npt = dtype.as_numpy_dtype
+ if npt != value.dtype:
+ value = value.astype(npt)
+ try:
+ value = np.asarray(value, order="C")
+ self._handle = pywrap_tensorflow.TFE_Py_NumpyToTensorHandle(value)
+ except core._NotOkStatusException as e: # pylint: disable=protected-access
+ raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
+
+ # Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
+ # memory. This change approximates the same behavior for eager execution -
+ # keeping int32 tensors in host memory.
+ #
+ # We do so to preclude the need for callers into such kernels from having to
+ # explicitly place the int32 tensors in host memory. For example, prior to
+ # this change one needed:
+ #
+ # with tfe.device('/gpu:0'):
+ # ... # code here
+ # with tfe.device('/cpu:0'):
+ # shape = tfe.Tensor(...)
+ # y = tfe.ops.random_uniform(.., shape)
+ #
+ # Without the CPU device block tfe.ops.random_uniform would fail since the
+ # kernel expects the shape in host memory.
+ #
+ # After this change, we simplify the code:
+ #
+ # with tfe.device('/gpu:0'):
+    #     y = tfe.ops.random_uniform(.., tfe.Tensor(...))
+ #
+ # The approximation is not exact since if there are GPU kernels which do not
+ # require host memory for int32 tensors, there will be a discrepancy between
+ # eager execution and TensorFlow graphs. However, as of July 2017, there
+ # were no known GPU kernels that kept int32 tensors in device memory.
+ if _in_gpu_device() and value.dtype != np.int32:
+ ctx = context.get_default_context()
+ # pylint: disable=protected-access
+ device_name = ctx.device_name
+ with errors.raise_exception_on_not_ok_status() as status:
+ self._handle = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
+ self._handle, ctx._handle, device_name, status)
+ # pylint: enable=protected-access
+
+ self._dtype = dtypes.as_dtype(
+ pywrap_tensorflow.TFE_TensorHandleDataType(self._handle))
+
+    # This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
+    # be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
+ # tensors, this will contain a serialized HandleData proto with shape
+ # inference metadata about shapes and dtypes of resources accessible from
+ # this handle.
+ self._handle_data = None
+ if core.active_trace() is not None:
+ core.active_trace().record_tensor("MANUAL",
+ tape.tensor_id(self),
+ self.device,
+ self.shape.num_elements())
+
+ def __del__(self):
+ if (pywrap_tensorflow is not None
+ and pywrap_tensorflow.TFE_DeleteTensorHandle is not None):
+ pywrap_tensorflow.TFE_DeleteTensorHandle(self._handle)
+ if core.active_trace() is not None:
+ core.active_trace().delete_tensor(tape.tensor_id(self))
+
+ def __str__(self):
+ if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
+ n = self.numpy().reshape(-1)
+ if self.shape.num_elements() > 5:
+ return "tfe.Tensor(%s..., shape=%s, dtype=%s)" % (n[:5], self.shape,
+ self.dtype.name)
+ else:
+ return "tfe.Tensor(%s, dtype=%s)" % (
+ np.array_str(self.numpy()).replace("\n", ""), self.dtype.name)
+ return "tfe.Tensor(<unprintable>, shape=%s dtype=%s)" % (self.shape,
+ self.dtype.name)
+
+ def __repr__(self):
+ if self.dtype.is_numpy_compatible and self.shape.num_elements() > 0:
+ n = self.numpy()
+ # TODO(apassos): understand why self.numpy() sometimes returns not
+ # an array.
+ if isinstance(n, np.ndarray):
+ n = n.reshape(-1)
+ if self.shape.num_elements() > 5:
+ return "<tfe.Tensor at %s shape=%s dtype=%s>(%s..., min=%s, max=%s)" % (
+ self._id, self.shape, self.dtype.name, n[:5], np.min(n), np.max(n))
+ else:
+ return "<tfe.Tensor at %s shape=%s dtype=%s>(%s)" % (self._id,
+ self.shape,
+ self.dtype.name, n)
+ return "<tfe.Tensor at %s shape=%s dtype=%s>" % (self._id, self.shape,
+ self.dtype.name)
+
+ @staticmethod
+ def _override_operator(name, func):
+ setattr(Tensor, name, func)
+
+ def numpy(self):
+ """Returns a numpy array with the same contents as the Tensor.
+
+ The contents of the Tensor must be backed by host memory. The
+    as_cpu_tensor() method can be used to ensure that this is true.
+
+ TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
+    buffer but instead always explicitly copy? Note that currently it may or
+    may not copy, depending on whether the numpy data is properly aligned.
+
+ Returns:
+ A numpy array that may share memory with the Tensor object. Any changes
+ to one may be reflected in the other.
+ """
+ # TODO(ashankar): This with status business seems expensive. Profile/avoid?
+ cpu = self.as_cpu_tensor()
+ with errors.raise_exception_on_not_ok_status() as status:
+ return pywrap_tensorflow.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access
+
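+  # Illustration only (not part of this class), assuming eager execution has
+  # been enabled: pulling a Tensor's contents back to host memory with the
+  # numpy() method above.
+  #
+  #   t = Tensor([[1.0, 2.0], [3.0, 4.0]])
+  #   a = t.numpy()            # numpy array, possibly sharing storage with t
+  #   c = t.as_cpu_tensor()    # explicit host-memory copy of t
+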
+ def _copy(self, ctx, device_name):
+ """Copies tensor to dest device."""
+ # pylint: disable=protected-access
+ # Creates a new tensor on the dest device.
+ with errors.raise_exception_on_not_ok_status() as status:
+ h = pywrap_tensorflow.TFE_TensorHandleCopyToDevice(
+ self._handle, ctx._handle, device_name, status)
+ new_tensor = _tensor_from_handle(h)
+ if core.active_trace() is not None:
+ core.active_trace().record_tensor("COPY",
+ tape.tensor_id(new_tensor),
+ new_tensor.device,
+ new_tensor.shape.num_elements())
+ return new_tensor
+ # pylint: enable=protected-access
+
+ @property
+ def device(self):
+ return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def shape(self):
+ """The shape of this Tensor as a TensorShape object."""
+ n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
+ # As of May 2017, TFE_TensorHandle objects were always backed by concrete
+ # tensors (which have a valid, known shape). There were vague plans to
+ # change this so that the Tensor class can also represent Tensors that have
+ # not yet been computed.
+ # If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
+ # and also handle -1s returned by TFE_TensorHandleDim.
+ assert n >= 0, "See comment in source code"
+ return tensor_shape.TensorShape(
+ [pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
+ for x in range(n)])
+
+ def get_shape(self):
+ """Alias of Tensor.shape."""
+ return self.shape
+
+ def _shape_tuple(self):
+ """The shape of this Tensor, as a tuple.
+
+    This is more performant than tuple(shape().as_list()) as it avoids
+    creating two lists and one extra object. Marked private for now because,
+    from an API perspective, it would be better to have a single performant way of
+ getting a shape rather than exposing shape() and shape_tuple()
+ (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
+ but ideally one would work things out and remove the need for this method.
+ """
+ n = pywrap_tensorflow.TFE_TensorHandleNumDims(self._handle)
+ # As of May 2017, TFE_TensorHandle objects were always backed by concrete
+ # tensors (which have a valid, known shape). There were vague plans to
+ # change this so that the Tensor class can also represent Tensors that have
+ # not yet been computed.
+ # If that happens, handle that (e.g., if n < 0: return tensor_shape(None))
+ # and also handle -1s returned by TFE_TensorHandleDim.
+ assert n >= 0, "See comment in source code"
+ return tuple(
+ pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
+ for x in range(n))
+
+ def _shape_as_list(self):
+ """The shape of the tensor as a list."""
+ return list(self._shape_tuple())
+
+ def as_cpu_tensor(self):
+ """A copy of this Tensor with contents backed by host memory."""
+ return self._copy(context.get_default_context(), "CPU:0")
+
+ def as_gpu_tensor(self, gpu_index=0):
+ """A copy of this Tensor with contents backed by memory on the GPU.
+
+ Arguments:
+      gpu_index: Identifies which GPU to place the contents of the returned
+        Tensor on.
+
+ Returns:
+ A GPU-memory backed Tensor object initialized with the same contents
+ as this Tensor.
+ """
+ return self._copy(context.get_default_context(), "GPU:" + str(gpu_index))
+
+ def __bool__(self):
+ if self._shape_tuple() != (): # pylint: disable=g-explicit-bool-comparison
+ raise ValueError(
+ "Non-scalar tensor %s cannot be converted to boolean." % repr(self))
+ if self.dtype != dtypes.bool:
+ raise ValueError(
+ "Non-boolean tensor %s cannot be converted to boolean." % repr(self))
+ return bool(self.as_cpu_tensor().numpy())
+
+ def __nonzero__(self):
+ return self.__bool__()
+
+ # Methods not supported / implemented for Eager Tensors.
+ @property
+ def op(self):
+ raise NotImplementedError("op not supported for Eager Tensors.")
+
+ @property
+ def graph(self):
+ raise NotImplementedError("graph not supported for Eager Tensors.")
+
+ @property
+ def name(self):
+ raise NotImplementedError("name not supported for Eager Tensors.")
+
+ def set_shape(self, shape):
+ raise NotImplementedError("set_shape not supported for Eager Tensors.")
+
+ @property
+ def value_index(self):
+ raise NotImplementedError("value_index not supported for Eager Tensors.")
+
+ def consumers(self):
+ raise NotImplementedError("consumers not supported for Eager Tensors.")
+
+ def _add_consumer(self, consumer):
+ raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
+
+ def _as_node_def_input(self):
+ raise NotImplementedError(
+ "_as_node_def_input not supported for Eager Tensors.")
+
+ def _as_tf_output(self):
+ raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
+
+ def eval(self, feed_dict=None, session=None):
+ raise NotImplementedError("eval not supported for Eager Tensors.")
+
+
+class IndexedSlices(object):
+ """A sparse representation of a set of tensor slices at given indices.
+
+ This class is a simple wrapper for a pair of `Tensor` objects:
+
+ * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
+ * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
+
+ An `IndexedSlices` is typically used to represent a subset of a larger
+ tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
+ The values in `indices` are the indices in the first dimension of
+ the slices that have been extracted from the larger tensor.
+
+ The dense tensor `dense` represented by an `IndexedSlices` `slices` has
+
+ ```python
+ dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
+ ```
+
+ The `IndexedSlices` class is used principally in the definition of
+ gradients for operations that have sparse gradients
+ (e.g. @{tf.gather}).
+ """
+
+ def __init__(self, values, indices, dense_shape):
+ """Creates an `IndexedSlices`."""
+ self._values = values
+ self._indices = indices
+ assert indices.shape[0] == values.shape[0]
+ self._dense_shape = dense_shape
+
+ @property
+ def values(self):
+ """A `Tensor` containing the values of the slices."""
+ return self._values
+
+ @property
+ def indices(self):
+ """A 1-D `Tensor` containing the indices of the slices."""
+ return self._indices
+
+ @property
+ def dense_shape(self):
+ """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
+ return self._dense_shape
+
+
+class _Op(object):
+  """Fake op for LazyZero to make its Python API tf.Tensor-like."""
+
+ def __init__(self):
+ self.type = "Zeros"
+
+
+class LazyZero(object):
+ """Lazily-instantiated zero-valued Tensor used as autograd accumulator."""
+
+ def __init__(self, shape, dtype):
+ self.shape = shape
+ self.dtype = dtype
+ self.op = _Op()
+
+ def __add__(self, other):
+ return other
+
+ def __radd__(self, other):
+ return other
+
+ def numpy(self):
+ return np.zeros(self.shape, self.dtype)
+
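+# Illustration only (not part of this module): LazyZero behaves as an additive
+# identity, so accumulating a gradient into it simply returns that gradient,
+# and a concrete zero array is only materialized on demand:
+#
+#   z = LazyZero(shape=[2], dtype=np.float32)
+#   z + g is g        # True for any g; no zero tensor is allocated
+#   z.numpy()         # array([0., 0.], dtype=float32)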
+
+def convert_to_eager_tensor(t, dtype=None):
+ if isinstance(ag_core.getval(t), Tensor):
+ if dtype is not None and t.dtype != dtype:
+ raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
+ return t
+ return Tensor(t, dtype=dtype)
+
+
+def convert_n_to_eager_tensor(values, dtype):
+ return [convert_to_eager_tensor(t, dtype) for t in values]
+
+
+def _tensor_from_handle(handle):
+ """'Private' constructor for the Tensor object.
+
+ The existence of a 'handle' is an implementation detail that should be hidden
+ from users of this module. Functions within this module do need to create a
+ Tensor object from a handle though.
+
+ One option would be to have an __init__(self, handle) method on the
+ Tensor class, but that would make the existence and use of a handle
+ 'public'.
+
+ Instead, this function avoids exposing a Tensor.__init__ that understands
+ handles and yet allows functions within this module to create Tensor
+ objects from a handle.
+
+ Arguments:
+ handle: A valid TFE_TensorHandle object.
+
+ Returns:
+ A Tensor object.
+ """
+ # pylint: disable=protected-access
+ t = Tensor.__new__(Tensor)
+ t._id = tf_ops.uid()
+ t._handle = handle
+ t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
+ t._handle_data = None
+ return t
+ # pylint: enable=protected-access
+
+
+# TODO(ashankar): use actual device type.
+def _in_gpu_device():
+ return context.get_default_context()._device_index > 0 # pylint: disable=protected-access
+
+
+def _maybe_modify_numpy_dtype_determination(np_array):
+ """Tweak numpy dtype determination.
+
+  numpy prefers int64 and float64, while we prefer int32 and float32.
+ (int32 is often used as the "shape" input to various operations,
+ many of which only support int32 shapes).
+ This preference is copied from tensor_util.make_tensor_proto
+ (https://goto.google.com/numpy_prefs_156503903)
+
+ Args:
+ np_array: A numpy ndarray
+ Returns:
+ A numpy ndarray whose dtype may have been modified.
+ """
+ if np_array.dtype == np.float64:
+ return np_array.astype(np.float32)
+ if np_array.dtype == np.int64:
+ # Downcast iff there is no precision loss.
+ downcasted = np_array.astype(np.int32)
+ if np.array_equal(downcasted, np_array):
+ return downcasted
+ return np_array
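+
+
+# For illustration only (not part of this module), the dtype tweak above acts
+# roughly as follows:
+#
+#   _maybe_modify_numpy_dtype_determination(np.array([1, 2]))     # -> int32
+#   _maybe_modify_numpy_dtype_determination(np.array([2 ** 40]))  # stays int64
+#   _maybe_modify_numpy_dtype_determination(np.array([1.5]))      # -> float32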
diff --git a/tensorflow/python/eager/test.py b/tensorflow/python/eager/test.py
new file mode 100644
index 0000000000..3d8af7e056
--- /dev/null
+++ b/tensorflow/python/eager/test.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for testing tfe code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context as _context
+from tensorflow.python.platform import test as _test
+from tensorflow.python.platform.test import * # pylint: disable=wildcard-import
+
+
+def main(argv=None):
+ _context.enable_eager_execution()
+ _test.main(argv)
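+
+
+# A minimal sketch (not part of this module) of how an eager test could use
+# this module, assuming TestCase is re-exported by the wildcard import above:
+#
+#   from tensorflow.python.eager import test
+#
+#   class SquareTest(test.TestCase):
+#
+#     def testSquare(self):
+#       ...
+#
+#   if __name__ == "__main__":
+#     test.main()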