author    Alexandre Passos <apassos@google.com>            2017-08-10 14:19:55 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-08-10 14:22:58 -0700
commit    13eb3b90e9ed8778ffd2b1bf6401677938b1ec39 (patch)
tree      40a2e7e926f3ed9fa0b99f88056bacc471547be7 /tensorflow/c
parent    7dfabcc01c9c752747c473346bb3f8c1cd290ad1 (diff)
Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.
PiperOrigin-RevId: 164902588
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- tensorflow/c/eager/BUILD            |  67
-rw-r--r-- tensorflow/c/eager/c_api.cc         | 561
-rw-r--r-- tensorflow/c/eager/c_api.h          | 159
-rw-r--r-- tensorflow/c/eager/c_api_test.cc    | 463
-rw-r--r-- tensorflow/c/eager/runtime.cc       | 289
-rw-r--r-- tensorflow/c/eager/runtime.h        | 193
-rw-r--r-- tensorflow/c/eager/runtime_test.cc  | 160
7 files changed, 1892 insertions(+), 0 deletions(-)
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
new file mode 100644
index 0000000000..18e1a68a93
--- /dev/null
+++ b/tensorflow/c/eager/BUILD
@@ -0,0 +1,67 @@
+# Experimental extensions to the C API for eager execution of kernels.
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api",
+ srcs = ["c_api.cc"],
+ hdrs = ["c_api.h"],
+ visibility = [
+ "//tensorflow:internal",
+ "//tensorflow/python/eager:__pkg__",
+ ],
+ deps = [
+ ":runtime",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_internal",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "c_api_test",
+ srcs = ["c_api_test.cc"],
+ deps = [
+ ":c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "runtime",
+ srcs = ["runtime.cc"],
+ hdrs = ["runtime.h"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_test(
+ name = "runtime_test",
+ srcs = ["runtime_test.cc"],
+ deps = [
+ ":runtime",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
new file mode 100644
index 0000000000..05e09ea120
--- /dev/null
+++ b/tensorflow/c/eager/c_api.cc
@@ -0,0 +1,561 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/runtime.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"
+
+using tensorflow::int64;
+using tensorflow::string;
+
+namespace {
+bool IsCPU(tensorflow::Device* d) {
+ return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
+}
+
+string DeviceName(tensorflow::Device* d) {
+ return (d == nullptr) ? "cpu:0" : d->name();
+}
+} // namespace
+
+struct TFE_Context {
+ explicit TFE_Context(TF_Session* s) : session(s) {}
+
+ // TFE_Context is an extension of TF_Session, and a TF_Session needs a TF_Graph.
+ TF_Session* session;
+
+ tensorflow::mutex functions_mu;
+ tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
+ tensorflow::OpRegistry::Global(), {}};
+
+ // One FunctionLibraryRuntime per device.
+ // func_libs[i] is the FunctionLibraryRuntime corresponding to
+ // session->devices[i].
+ std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs;
+
+ std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
+ tensorflow::Fprint128Hasher>
+ kernel_cache;
+
+ tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
+ for (int i = 0; i < session->devices.size(); ++i) {
+ if (session->devices[i] == d) {
+ return func_libs[i].get();
+ }
+ }
+ return nullptr;
+ }
+
+ const std::vector<tensorflow::Device*>& devices() { return session->devices; }
+};
+
+struct TFE_TensorHandle {
+ TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
+ : t(t), d(d) {}
+
+ tensorflow::Tensor t;
+ // TODO(ashankar): d == nullptr iff local CPU
+ // This was expedient, but perhaps worth revisiting ('d' should always be a
+ // valid pointer?)
+ // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
+ // provided with the appropriate TFE_Context.
+ //
+ // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
+ // TFE_TensorHandle does not outlive the TFE_Context from which it came?
+ tensorflow::Device* d;
+};
+
+struct TFE_Op {
+ TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
+ : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
+
+ bool is_function() const { return attr_types == nullptr; }
+
+ TFE_Context* ctx; // Must outlive the TFE_Op.
+ const char* name;
+ tensorflow::AttrBuilder attrs;
+ const tensorflow::AttrTypeMap* attr_types;
+ std::vector<tensorflow::Tensor> inputs;
+ std::vector<tensorflow::Device*> input_devices;
+ tensorflow::Device* device;
+};
+
+extern "C" {
+
+TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+ TF_Graph* graph = TF_NewGraph();
+ TF_Session* session = TF_NewSession(graph, opts, status);
+ if (status->status.ok()) {
+ if (session->device_mgr == nullptr || session->devices.empty()) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Provided TF_SessionOptions are not compatible with eager execution "
+ "(perhaps the TF_SessionOptions alluded to session execution in a "
+ "remote address space?)");
+ }
+ }
+ if (!status->status.ok()) {
+ TF_DeleteGraph(graph);
+ return nullptr;
+ }
+
+ TFE_Context* ret = new TFE_Context(session);
+ ret->func_libs.resize(ret->devices().size());
+ for (int i = 0; i < ret->devices().size(); ++i) {
+ ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
+ ret->session->device_mgr, opts->options.env, ret->devices()[i],
+ TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
+ }
+
+ return ret;
+}
+
+void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
+ status->status = tensorflow::Status::OK();
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+ TF_Graph* graph = ctx->session->graph;
+ TF_DeleteSession(ctx->session, status);
+ TF_DeleteGraph(graph);
+ delete ctx;
+}
+
+TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
+ return TF_SessionListDevices(ctx->session, status);
+}
+
+TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
+ return new TFE_TensorHandle(
+ tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
+ nullptr);
+}
+
+void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
+
+TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
+ return static_cast<TF_DataType>(h->t.dtype());
+}
+
+int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
+
+int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
+ return h->t.dim_size(dim_index);
+}
+
+const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
+ // This might be a bit confusing as a tensor on CPU can sometimes return
+ // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0".
+ // TODO(ashankar): Figure out which one would be nicer.
+ return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str();
+}
+
+TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
+ if (!IsCPU(h->d)) {
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ tensorflow::strings::StrCat(
+ "TFE_TensorHandle can be resolved iff it is on CPU (this "
+ "handle is on ",
+ h->d->name(),
+ "). Consider using TFE_TensorHandleCopyToDevice to get a "
+ "copy of the tensor on CPU")
+ .c_str());
+ return nullptr;
+ }
+ return tensorflow::TF_TensorFromTensor(h->t, status);
+}
+
+TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ const char* device_name,
+ TF_Status* status) {
+ tensorflow::Device* dstd = nullptr;
+ status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
+ if (!status->status.ok()) return nullptr;
+
+ tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
+ const bool src_cpu = IsCPU(srcd);
+ const bool dst_cpu = IsCPU(dstd);
+ if (!src_cpu && !dst_cpu) {
+ TF_SetStatus(
+ status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "TFE_TensorHandleCopyToDevice requires either the source "
+ "TFE_TensorHandle be on or the destination device be CPU (they "
+ "are ",
+ DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
+ .c_str());
+ return nullptr;
+ }
+ tensorflow::Tensor* src = &(h->t);
+ if (src_cpu && dst_cpu) {
+ // There must be a better way, but for now redirect through proto to ensure
+ // that the underlying buffers are not shared.
+ tensorflow::TensorProto proto;
+ src->AsProtoTensorContent(&proto);
+ tensorflow::Tensor dst(src->dtype(), src->shape());
+ if (!dst.FromProto(proto)) {
+ TF_SetStatus(
+ status, TF_INTERNAL,
+ tensorflow::strings::StrCat(
+ "error copying between TFE_TensorHandles on CPU. Consider filing "
+ "a bug report at https://github.com/tensorflow/tensorflow/issues "
+ "mentioning version: ",
+ TF_Version(), " and ", __FILE__, ":", __LINE__)
+ .c_str());
+ return nullptr;
+ }
+ return new TFE_TensorHandle(dst, nullptr);
+ }
+ if (src_cpu) {
+ tensorflow::Tensor dst(
+ dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
+ src->shape());
+ tensorflow::Notification n;
+ dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
+ src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
+ status->status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
+ : nullptr;
+ }
+ CHECK(dst_cpu);
+ tensorflow::Tensor dst(src->dtype(), src->shape());
+ tensorflow::Notification n;
+ // TODO(ashankar): The Sync() call below may be more aggressive than
+ // necessary. It is based on knowledge of implementation details - that
+ // GPU devices are implemented using 3 streams - one for host->device copies,
+ // one for device->host copies and one for sending operations to the GPU.
+ // With that setup, Sync()ing across all 3 streams should be sufficient
+ // but more than necessary (since it waits for operations that might have
+ // nothing to do with this tensor to complete).
+ status->status = srcd->Sync();
+ if (!status->status.ok()) return nullptr;
+ srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
+ src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
+ [status, &n](const tensorflow::Status& s) {
+ status->status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
+ : nullptr;
+}
+
+TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
+ TF_Status* status) {
+ const char* name = op_or_function_name; // Shorthand
+ const tensorflow::AttrTypeMap* types;
+ status->status = tensorflow::AttrTypeMapForOp(name, &types);
+ if (status->status.ok()) return new TFE_Op(ctx, name, types);
+ if (TF_GetCode(status) == TF_NOT_FOUND) {
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ if (ctx->func_lib_def.Find(name) != nullptr) {
+ status->status = tensorflow::Status::OK();
+ return new TFE_Op(ctx, name, nullptr);
+ }
+ }
+ return nullptr;
+}
+
+void TFE_DeleteOp(TFE_Op* op) { delete op; }
+
+static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
+ TF_Status* status) {
+ // Questionable heuristic: Place the op on the same device as the first input
+ // placed outside of host memory?
+ if (IsCPU(op->device) && !IsCPU(device)) {
+ op->device = device;
+ }
+}
+
+void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name,
+ TF_Status* status) {
+ tensorflow::Device* d = nullptr;
+ status->status = ctx->session->device_mgr->LookupDevice(device_name, &d);
+ if (!status->status.ok()) return;
+ TFE_OpSetDeviceHelper(op, d, status);
+}
+
+void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
+ TFE_OpSetDeviceHelper(op, h->d, status);
+ if (!status->status.ok()) return;
+ op->inputs.push_back(h->t);
+ op->input_devices.push_back(h->d);
+ op->attrs.NumInputs(op->inputs.size());
+}
+
+TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
+ unsigned char* is_list, TF_Status* status) {
+ TF_AttrType ret;
+ if (op->is_function()) {
+ status->status = tensorflow::errors::Unimplemented(
+ "TODO(apassos): Support for attributes for TensorFlow functions is not "
+ "ready yet.");
+ return TF_ATTR_INT; // The compiler requires that we return something.
+ }
+ status->status =
+ tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
+ return ret;
+}
+
+void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
+ op->attrs.Set(attr_name, value);
+}
+
+void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
+ op->attrs.Set(attr_name, static_cast<int64>(value));
+}
+
+void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
+ op->attrs.Set(attr_name, value);
+}
+
+void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
+ op->attrs.Set(attr_name, (value == 0) ? false : true);
+}
+
+void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
+ op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
+}
+
+void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
+ const int num_dims, TF_Status* out_status) {
+ if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "Value specified for `", attr_name, "` has ", num_dims,
+ " dimensions which is over the limit of ",
+ tensorflow::TensorShape::MaxDimensions(), ".")
+ .c_str());
+ return;
+ }
+ tensorflow::TensorShapeProto proto;
+ if (num_dims < 0) {
+ proto.set_unknown_rank(true);
+ } else {
+ for (int d = 0; d < num_dims; ++d) {
+ proto.add_dim()->set_size(dims[d]);
+ }
+ }
+ op->attrs.Set(attr_name, proto);
+}
+
+#define TFE_OP_SET_ATTR_LIST(fn, type) \
+ void fn(TFE_Op* op, const char* attr_name, const type* values, \
+ int num_values) { \
+ op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
+ values, num_values)); \
+ }
+TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
+TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
+#undef TFE_OP_SET_ATTR_LIST
+
+void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
+ const int64_t* values, int num_values) {
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<const int64>(
+ reinterpret_cast<const int64*>(values), num_values));
+}
+
+void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
+ const TF_DataType* values, int num_values) {
+ op->attrs.Set(
+ attr_name,
+ tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
+ reinterpret_cast<const tensorflow::DataType*>(values), num_values));
+}
+
+void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
+ const unsigned char* values, int num_values) {
+ std::unique_ptr<bool[]> b(new bool[num_values]);
+ for (int i = 0; i < num_values; ++i) {
+ b[i] = values[i];
+ }
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
+}
+
+void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
+ const int64_t** dims, const int* num_dims,
+ int num_values, TF_Status* out_status) {
+ std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
+ new tensorflow::TensorShapeProto[num_values]);
+ for (int i = 0; i < num_values; ++i) {
+ const auto num_dims_i = num_dims[i];
+
+ if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat(
+ "Value specified for `", attr_name, "` has ", num_dims_i,
+ " dimensions which is over the limit of ",
+ tensorflow::TensorShape::MaxDimensions(), ".")
+ .c_str());
+ return;
+ }
+ if (num_dims_i < 0) {
+ proto[i].set_unknown_rank(true);
+ } else {
+ const int64_t* dims_i = dims[i];
+ auto proto_i = &proto[i];
+ for (int d = 0; d < num_dims_i; ++d) {
+ proto_i->add_dim()->set_size(dims_i[d]);
+ }
+ }
+ }
+ op->attrs.Set(attr_name,
+ tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
+ proto.get(), num_values));
+}
+
+namespace {
+
+tensorflow::Status ValidateInputTypeAndPlacement(
+ tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
+ const tensorflow::OpKernel* kernel) {
+ const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
+ if (memtypes.size() != op->inputs.size()) {
+ return tensorflow::errors::InvalidArgument(
+ "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
+ }
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ const tensorflow::Device* expected_device =
+ memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
+ const tensorflow::Device* actual_device =
+ op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
+ if (expected_device != actual_device) {
+ return tensorflow::errors::InvalidArgument(
+ "cannot compute ", op->name, " as input #", i,
+ " was expected to be on ", expected_device->name(),
+ " but is actually on ", actual_device->name(),
+ " (operation running on ", op_device->name(), ")");
+ }
+ if (op->inputs[i].dtype() != kernel->input_type(i)) {
+ return tensorflow::errors::InvalidArgument(
+ "cannot compute ", op->name, " as input #", i,
+ " was expected to be a ",
+ tensorflow::DataType_Name(kernel->input_type(i)), " tensor but is a ",
+ tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+} // namespace
+
+void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+ TF_Status* status) {
+ TFE_Context* ctx = op->ctx;
+ // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
+ tensorflow::Device* device =
+ (op->device == nullptr) ? ctx->devices()[0] : op->device;
+ std::vector<tensorflow::Tensor> outputs(1);
+ const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
+ tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
+ tensorflow::KernelAndDevice* kernel =
+ tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+ if (kernel == nullptr) {
+ const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ kernel = new tensorflow::KernelAndDevice();
+ if (!op->is_function()) {
+ status->status =
+ tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
+ } else {
+ // Knowledge of the implementation of InitFn (and, in turn,
+ // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
+ // will be accessed, so grab on to the lock.
+ // See WARNING comment below - would be nice to rework to avoid this
+ // subtlety.
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = tensorflow::KernelAndDevice::InitFn(
+ ndef, ctx->func_lib(device), kernel);
+ }
+ if (!status->status.ok()) {
+ return;
+ }
+ tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
+ }
+ status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
+ kernel->kernel());
+ output_memory_types = &kernel->kernel()->output_memory_types();
+ if (!status->status.ok()) {
+ return;
+ }
+ // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
+ // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
+ // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
+ // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
+ // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
+ // This is quite subtle. Re-work things to make this better? (Would it make
+ // sense for FunctionLibraryRuntime to ensure thread-safe access to
+ // FunctionLibraryDefinition?).
+ status->status = kernel->Run(&op->inputs, &outputs);
+ if (!status->status.ok()) return;
+ *num_retvals = std::min<int>(*num_retvals, outputs.size());
+ for (int i = 0; i < *num_retvals; ++i) {
+ tensorflow::Device* d = IsCPU(device) ? nullptr : device;
+ if (d != nullptr && output_memory_types != nullptr &&
+ (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
+ d = nullptr;
+ }
+ retvals[i] = new TFE_TensorHandle(outputs[i], d);
+ }
+}
+
+void TFE_ContextAddFunctionDef(TFE_Context* ctx,
+ const char* serialized_function_def, size_t size,
+ TF_Status* status) {
+ tensorflow::FunctionDef function_def;
+ if (!function_def.ParseFromArray(serialized_function_def, size)) {
+ status->status =
+ tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
+ return;
+ }
+ tensorflow::mutex_lock l(ctx->functions_mu);
+ status->status = ctx->func_lib_def.AddFunctionDef(function_def);
+}
+
+} // extern "C"
+
+TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
+ return new TFE_TensorHandle(t, nullptr);
+}
+
+const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
+ TFE_TensorHandle* h, TF_Status* status) {
+ if (h->d != nullptr) {
+ status->status = tensorflow::errors::FailedPrecondition(
+ "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
+ "a tensorflow::Tensor");
+ return nullptr;
+ }
+ return &h->t;
+}
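For orientation, here is how the pieces above compose end to end: create a context, wrap a tensor in a handle, stage an op, execute it on concrete values, and resolve the result. This is a minimal sketch rather than part of the commit; it assumes linkage against the :c_api target added here and collapses error handling into CHECKs.

    #include <string.h>

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/core/platform/logging.h"

    int main() {
      TF_Status* status = TF_NewStatus();
      TF_SessionOptions* opts = TF_NewSessionOptions();
      TFE_Context* ctx = TFE_NewContext(opts, status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TF_DeleteSessionOptions(opts);

      // Wrap a 2x2 float matrix in a handle (host memory, hence d == nullptr).
      int64_t dims[] = {2, 2};
      float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
      TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(vals));
      memcpy(TF_TensorData(t), vals, sizeof(vals));
      TFE_TensorHandle* m = TFE_NewTensorHandle(t);
      TF_DeleteTensor(t);

      // Stage MatMul(m, m) and run it immediately - no graph, no Session::Run.
      TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_OpAddInput(matmul, m, status);
      TFE_OpAddInput(matmul, m, status);
      TFE_OpSetAttrBool(matmul, "transpose_a", 0);
      TFE_OpSetAttrBool(matmul, "transpose_b", 0);
      TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));
      TFE_TensorHandle* retvals[1];
      int num_retvals = 1;
      TFE_Execute(matmul, retvals, &num_retvals, status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

      // Resolve the result back into host memory: {7, 10, 15, 22}.
      TF_Tensor* result = TFE_TensorHandleResolve(retvals[0], status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      float product[4];
      memcpy(product, TF_TensorData(result), sizeof(product));

      TF_DeleteTensor(result);
      TFE_DeleteTensorHandle(retvals[0]);
      TFE_DeleteTensorHandle(m);
      TFE_DeleteOp(matmul);
      TFE_DeleteContext(ctx, status);
      TF_DeleteStatus(status);
      return 0;
    }

Note that TFE_Execute may lower *num_retvals to the actual number of outputs; the header below documents this contract.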
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
new file mode 100644
index 0000000000..66a5e43bfc
--- /dev/null
+++ b/tensorflow/c/eager/c_api.h
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_C_API_H_
+#define TENSORFLOW_C_EAGER_C_API_H_
+
+// C API extensions to experiment with eager execution of kernels.
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// "Context" under which operations/functions are executed. It encapsulates
+// things like the available devices, resource manager etc.
+//
+// TODO(ashankar): Merge with TF_Session?
+typedef struct TFE_Context TFE_Context;
+
+extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
+ TF_Status* status);
+extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
+extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
+ TF_Status* status);
+
+// A handle to a tensor on a device.
+//
+// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
+// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
+// placed in memory of different devices or remote address spaces.
+typedef struct TFE_TensorHandle TFE_TensorHandle;
+
+extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
+extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
+extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
+extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
+extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
+extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
+extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
+ TF_Status* status);
+
+// Create a new TFE_TensorHandle with the same contents as 'h' but placed
+// in the memory of the device name 'device_name'.
+//
+// Currently requires at least one of the source or destination devices to
+// be CPU (i.e., for the source or destination tensor to be placed in
+// host memory).
+extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ const char* device_name,
+ TF_Status* status);
+
+// Description of the TensorFlow op to execute.
+//
+// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
+// TFE_DeleteOp() is called before TFE_DeleteContext().
+//
+// Very similar to TF_OperationDescription with some differences:
+// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput,
+// TF_AddInputList
+// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
+// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
+// the additional sanity checks there seem unnecessary.
+typedef struct TFE_Op TFE_Op;
+
+extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
+ TF_Status* status);
+extern void TFE_DeleteOp(TFE_Op* op);
+
+// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
+// parameter. Instead, the TFE_Context should be captured when creating the
+// TFE_Op.
+extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
+ const char* device_name, TF_Status* status);
+
+extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
+
+extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
+ unsigned char* is_list, TF_Status* status);
+
+extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
+ const char* value);
+extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
+extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
+extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
+ unsigned char value);
+extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
+ TF_DataType value);
+// If the number of dimensions is unknown, `num_dims` must be set to
+// -1 and `dims` can be null. If a dimension is unknown, the
+// corresponding entry in the `dims` array must be -1.
+extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
+ const int64_t* dims, const int num_dims,
+ TF_Status* out_status);
+
+extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
+ const char** value, int num_values);
+extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
+ const int64_t* values, int num_values);
+extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
+ const float* values, int num_values);
+extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
+ const unsigned char* values, int num_values);
+extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
+ const TF_DataType* values, int num_values);
+extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
+ const int64_t** dims, const int* num_dims,
+ int num_values, TF_Status* out_status);
+
+// Execute the operation defined by 'op' and return handles to computed
+// tensors in 'retvals'.
+//
+// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
+// and '*num_retvals' should be set to the size of this array.
+//
+// On return, 'num_retvals' will be set to the actual number of outputs
+// returned by the operation.
+extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
+ int* num_retvals, TF_Status* status);
+
+// Add a function (serialized FunctionDef protocol buffer) to ctx so
+// that it can be invoked using TFE_Execute.
+extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
+ const char* serialized_function_def,
+ size_t size, TF_Status* status);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#ifdef __cplusplus
+// A workaround to ease conversion between numpy objects and
+// TFE_TensorHandles.
+//
+// TODO(ashankar): Figure out an alternative scheme that precludes the need for
+// these API-boundary breaking methods.
+namespace tensorflow {
+class Tensor;
+} // namespace tensorflow
+
+const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
+ TFE_TensorHandle* h, TF_Status* status);
+TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
+#endif
+
+#endif // TENSORFLOW_C_EAGER_C_API_H_
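To make the shape conventions above concrete (num_dims == -1 for unknown rank, a -1 entry for an unknown dimension), a short sketch; op and status are assumed to be a valid TFE_Op* and TF_Status* already in scope, and the attr name "shape" is illustrative:

    // Fully known shape [2, 2].
    int64_t known[] = {2, 2};
    TFE_OpSetAttrShape(op, "shape", known, 2, status);

    // Known rank, unknown leading dimension: [?, 2].
    int64_t partial[] = {-1, 2};
    TFE_OpSetAttrShape(op, "shape", partial, 2, status);

    // Unknown rank: dims may be null and num_dims must be -1.
    TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);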
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
new file mode 100644
index 0000000000..797506422b
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -0,0 +1,463 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <string.h>
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using tensorflow::string;
+
+namespace {
+
+TFE_TensorHandle* TestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t);
+ TF_DeleteTensor(t);
+ return th;
+}
+
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, b, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrBool(op, "transpose_a", 0);
+ TFE_OpSetAttrBool(op, "transpose_b", 0);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_InitOp(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// for (auto _ : state) {
+// TFE_Op* matmul = MatMulOp(ctx, m, m);
+// TFE_DeleteOp(matmul);
+// }
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_InitOp);
+
+// void BM_Execute(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// TFE_Op* matmul = MatMulOp(ctx, m, m);
+// TFE_TensorHandle* retvals[1];
+// int num_retvals = 1;
+// for (auto _ : state) {
+// TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// }
+// TFE_DeleteOp(matmul);
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_Execute);
+
+TEST(CAPI, Context) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ TF_DeleteSessionOptions(opts);
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const int num_devices = TF_DeviceListCount(devices);
+ EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
+ for (int i = 0; i < num_devices; ++i) {
+ EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ }
+ TF_DeleteDeviceList(devices);
+ TF_DeleteStatus(status);
+}
+
+TEST(CAPI, TensorHandle) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle();
+ EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
+ ASSERT_EQ(16, TF_TensorByteSize(t));
+ float data[4] = {0};
+ memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+ EXPECT_EQ(1.0, data[0]);
+ EXPECT_EQ(2.0, data[1]);
+ EXPECT_EQ(3.0, data[2]);
+ EXPECT_EQ(4.0, data[3]);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(h);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevices) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TF_DeleteSessionOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ const int num_devices = TF_DeviceListCount(devices);
+
+ const char* kCPUDevice = "CPU:0";
+ for (int i = 0; i < num_devices; ++i) {
+ const string name(TF_DeviceListName(devices, i, status.get()));
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << i << " -- " << TF_Message(status.get());
+ continue;
+ }
+ auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
+ // Copy to device
+ TFE_TensorHandle* hdevice =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+ continue;
+ }
+ // Copy back to CPU
+ TFE_TensorHandle* hcopy =
+ TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
+ continue;
+ }
+ TFE_DeleteTensorHandle(hdevice);
+
+ // Ensure that the contents are the same!
+ TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
+ TFE_DeleteTensorHandle(hcopy);
+ if (TF_GetCode(status.get()) != TF_OK) {
+ ADD_FAILURE() << tag;
+ continue;
+ }
+ EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
+ EXPECT_EQ(
+ 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
+ << tag;
+ TF_DeleteTensor(tcopy);
+ }
+
+ TF_DeleteDeviceList(devices);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_DeleteContext(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
+TEST(CAPI, Execute) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteSessionOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_Op* matmul = MatMulOp(ctx, m, m);
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(m);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+ TF_DeleteStatus(status);
+}
+
+string MatMulFunction() {
+ tensorflow::FunctionDef def;
+ CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+ " signature {"
+ " name: 'MatMulFunction'"
+ " input_arg {"
+ " name: 'a'"
+ " type: DT_FLOAT"
+ " }"
+ " output_arg {"
+ " name: 'm'"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " node_def {"
+ " name: 'matmul'"
+ " op: 'MatMul'"
+ " input: 'a'"
+ " input: 'a'"
+ " attr {"
+ " key: 'T'"
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " }"
+ " ret {"
+ " key: 'm'"
+ " value: 'matmul:product'"
+ " }",
+ &def));
+ return def.SerializeAsString();
+}
+
+TEST(CAPI, FunctionDefAndExecute) {
+ TF_Status* status = TF_NewStatus();
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteSessionOptions(opts);
+
+ string function_def = MatMulFunction();
+ TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+ status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_TensorHandle* retval[1] = {nullptr};
+ int num_retvals = 1;
+ TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, m, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Execute(op, &retval[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+ TFE_DeleteOp(op);
+ TFE_DeleteTensorHandle(m);
+ TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
+ TFE_DeleteTensorHandle(retval[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_ExecuteFunction(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// string function_def = MatMulFunction();
+// TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+// status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_TensorHandle* m = TestMatrixTensorHandle();
+// TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpAddInput(matmul, m, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_TensorHandle* retval[1] = {nullptr};
+// int num_retvals = 1;
+// for (auto _ : state) {
+// TFE_Execute(matmul, &retval[0], &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// }
+// TFE_DeleteTensorHandle(m);
+// TFE_DeleteTensorHandle(retval[0]);
+// TFE_DeleteContext(ctx, status);
+// EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_ExecuteFunction);
+
+// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
+// TF_Status* status) {
+// // Create the variable handle.
+// TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpSetAttrShape(op, "shape", {}, 0, status);
+// TFE_OpSetAttrString(op, "container", "");
+// TFE_OpSetAttrString(op, "shared_name", "");
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_TensorHandle* var_handle = nullptr;
+// int num_retvals = 1;
+// TFE_Execute(op, &var_handle, &num_retvals, status);
+// TFE_DeleteOp(op);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// CHECK_EQ(1, num_retvals);
+
+// // Assign 'value' to it.
+// op = TFE_NewOp(ctx, "AssignVariableOp", status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+
+// // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+// std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+// TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)),
+// TF_DeleteTensor);
+// memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+// value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
+
+// TFE_OpAddInput(op, value_handle.get(), status);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+
+// num_retvals = 0;
+// TFE_Execute(op, nullptr, &num_retvals, status);
+// TFE_DeleteOp(op);
+// if (TF_GetCode(status) != TF_OK) return nullptr;
+// CHECK_EQ(0, num_retvals);
+
+// return var_handle;
+// }
+
+// TEST(CAPI, Variables) {
+// // Variables use resource handles, so this is really a test for resource
+// // tensor handling.
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// int num_retvals = 1;
+// TFE_TensorHandle* value_handle = nullptr;
+// TFE_Execute(op, &value_handle, &num_retvals, status);
+// TFE_DeleteOp(op);
+
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// ASSERT_EQ(1, num_retvals);
+// EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
+// EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
+// float value = 0.0f;
+// TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
+// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
+// memcpy(&value, TF_TensorData(t), sizeof(float));
+// TF_DeleteTensor(t);
+// EXPECT_EQ(12.0, value);
+
+// TFE_DeleteTensorHandle(var_handle);
+// TFE_DeleteTensorHandle(value_handle);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+
+// void BM_ReadVariable(benchmark::State& state) {
+// TF_Status* status = TF_NewStatus();
+// TF_SessionOptions* opts = TF_NewSessionOptions();
+// TFE_Context* ctx = TFE_NewContext(opts, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteSessionOptions(opts);
+
+// TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+// TFE_OpAddInput(op, var_handle, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+// int num_retvals = 1;
+// TFE_TensorHandle* h = nullptr;
+// for (auto _ : state) {
+// TFE_Execute(op, &h, &num_retvals, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// CHECK_EQ(1, num_retvals);
+// CHECK(h);
+// CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+// CHECK_EQ(0, TFE_TensorHandleNumDims(h));
+// h = nullptr;
+// }
+// TFE_DeleteOp(op);
+
+// TFE_DeleteTensorHandle(var_handle);
+// TFE_DeleteContext(ctx, status);
+// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+// TF_DeleteStatus(status);
+// }
+// BENCHMARK(BM_ReadVariable);
+
+} // namespace
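For reference, the "right benchmark API" in the TODOs above is presumably the iters-based harness from tensorflow/core/platform/test_benchmark.h, which this file already includes. A sketch of BM_InitOp in that style, reusing the helpers above:

    static void BM_InitOp(int iters) {
      tensorflow::testing::StopTiming();
      TF_Status* status = TF_NewStatus();
      TF_SessionOptions* opts = TF_NewSessionOptions();
      TFE_Context* ctx = TFE_NewContext(opts, status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TF_DeleteSessionOptions(opts);
      TFE_TensorHandle* m = TestMatrixTensorHandle();

      tensorflow::testing::StartTiming();
      for (int i = 0; i < iters; ++i) {
        TFE_Op* matmul = MatMulOp(ctx, m, m);
        TFE_DeleteOp(matmul);
      }
      tensorflow::testing::StopTiming();

      TFE_DeleteTensorHandle(m);
      TFE_DeleteContext(ctx, status);
      CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TF_DeleteStatus(status);
    }
    BENCHMARK(BM_InitOp);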
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
new file mode 100644
index 0000000000..87e9f5377f
--- /dev/null
+++ b/tensorflow/c/eager/runtime.cc
@@ -0,0 +1,289 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/runtime.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+namespace {
+
+mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
+
+std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
+ static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
+ return m;
+}
+
+const uint32 kIsList = 1U << 31;
+
+} // namespace
+
+Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
+ mutex_lock l(g_op_name_to_attr_type_map_lock);
+ *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
+ if (*out != nullptr) return Status::OK();
+ const OpRegistrationData* op_reg_data = nullptr;
+ Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+ if (!s.ok()) return s;
+ std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
+ // TODO(agarwal): Avoid having to create this "registry" at runtime,
+ // perhaps can be done at op registration time?
+ for (const auto& attr : op_reg_data->op_def.attr()) {
+ string type = attr.type();
+ const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
+ if (is_list) {
+ type = type.substr(5, type.length() - 6);
+ }
+ uint32 t = is_list ? kIsList : 0;
+ if (type == "string") {
+ t |= TF_ATTR_STRING;
+ } else if (type == "int") {
+ t |= TF_ATTR_INT;
+ } else if (type == "float") {
+ t |= TF_ATTR_FLOAT;
+ } else if (type == "bool") {
+ t |= TF_ATTR_BOOL;
+ } else if (type == "type") {
+ t |= TF_ATTR_TYPE;
+ } else if (type == "shape") {
+ t |= TF_ATTR_SHAPE;
+ } else if (type == "tensor") {
+ t |= TF_ATTR_TENSOR;
+ } else {
+ return errors::Unimplemented(
+ "TODO(agarwal): Enable support for ops with attributes of type '",
+ type, "'");
+ }
+ gtl::InsertIfNotPresent(m.get(), attr.name(), t);
+ }
+ *out = m.get();
+ (*OpNameToAttrTypeMap())[op_name] = m.release();
+ return Status::OK();
+}
+
+Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+ TF_AttrType* out, unsigned char* is_list) {
+ CHECK(m);
+ auto* t = gtl::FindOrNull(*m, attr_name);
+ if (t == nullptr) {
+ return errors::InvalidArgument("Attribute '", attr_name,
+ "' does not exist for this operation");
+ }
+ *out = static_cast<TF_AttrType>(*t & ~kIsList);
+ if (*t & kIsList) {
+ *is_list = 1;
+ } else {
+ *is_list = 0;
+ }
+ return Status::OK();
+}
+
+#define DEFINE_SET_ATTR(value_type, value_field) \
+ template <> \
+ AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
+ value_field.push_back(std::make_pair(attr_name, value)); \
+ return *this; \
+ }
+
+DEFINE_SET_ATTR(StringPiece, string_attrs_);
+DEFINE_SET_ATTR(float, float_attrs_);
+DEFINE_SET_ATTR(int, int_attrs_);
+DEFINE_SET_ATTR(bool, bool_attrs_);
+DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
+
+#undef DEFINE_SET_ATTR
+
+AttrBuilder& AttrBuilder::NumInputs(int n) {
+ DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
+ num_inputs_ = n;
+ return *this;
+}
+
+const NodeDef& AttrBuilder::BuildNodeDef() {
+ if (node_def_finalized_) return *node_def_;
+ MayBeInitializeNodeDef();
+ for (int i = 0; i < num_inputs_; ++i) {
+ node_def_->add_input("dummy_input");
+ }
+ for (const auto& p : string_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : int_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : float_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : bool_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ for (const auto& p : type_attrs_) {
+ SetInNodeDef(p.first, p.second);
+ }
+ node_def_finalized_ = true;
+ return *node_def_;
+}
+
+namespace {
+inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
+ const tensorflow::Fprint128& b) {
+ return {tensorflow::FingerprintCat64(a.low64, b.low64),
+ tensorflow::FingerprintCat64(a.high64, b.high64)};
+}
+
+void CombineUnordered(const tensorflow::Fprint128& a,
+ tensorflow::Fprint128* b) {
+ b->low64 += a.low64;
+ b->high64 += a.high64;
+}
+
+inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
+ const tensorflow::Fprint128& b) {
+ // TODO(agarwal): avoid ToString().
+ tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
+ return FingerprintCat128(a, b);
+}
+
+inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
+ return CacheKeyHelper(s, {b, b});
+}
+
+} // namespace
+
+tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
+ tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
+ f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
+ if (node_def_ != nullptr) {
+ // Some attributes are directly written to node_def_ instead of being
+ // stored explicitly.
+ string value;
+ for (const auto& attr : node_def_->attr()) {
+ attr.second.SerializeToString(&value);
+ CombineUnordered(
+ CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
+ }
+ // Note that node_def_ may be created but not finalized. This can happen
+ // when the creation was triggered by a call to Set, but BuildNodeDef has
+ // not been called.
+ if (node_def_finalized_) return f;
+ }
+ for (const auto& p : string_attrs_) {
+ // TODO(agarwal): avoid ToString().
+ CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
+ p.second.ToString())),
+ &f);
+ }
+ for (const auto& p : int_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
+ &f);
+ }
+ static std::hash<float> float_hasher;
+ for (const auto& p : float_attrs_) {
+ CombineUnordered(
+ CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
+ &f);
+ }
+ for (const auto& p : bool_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
+ }
+ for (const auto& p : type_attrs_) {
+ CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
+ &f);
+ }
+ return f;
+}
+
+void AttrBuilder::MayBeInitializeNodeDef() {
+ if (node_def_ == nullptr) {
+ node_def_.reset(new NodeDef());
+ node_def_->set_name(op_name_);
+ node_def_->set_op(op_name_);
+ }
+}
+
+// static
+Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
+ KernelAndDevice* out) {
+ OpKernel* k = nullptr;
+ Status s = CreateOpKernel(device->device_type().c_str(), device,
+ device->GetAllocator(AllocatorAttributes()),
+ nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
+ out->device_ = device;
+ out->kernel_.reset(k);
+ out->flib_ = nullptr;
+ return s;
+}
+
+// static
+Status KernelAndDevice::InitFn(const NodeDef& ndef,
+ FunctionLibraryRuntime* flib,
+ KernelAndDevice* out) {
+ OpKernel* k = nullptr;
+ Status s = flib->CreateKernel(ndef, &k);
+ out->device_ = flib->device();
+ out->kernel_.reset(k);
+ out->flib_ = flib;
+ return s;
+}
+
+Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
+ std::vector<Tensor>* output_tensors) {
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ for (Tensor& t : *input_tensors) {
+ inputs.push_back(TensorValue(&t));
+ }
+
+ std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
+ for (size_t i = 0; i < out_attrs.size(); ++i) {
+ out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
+ tensorflow::HOST_MEMORY);
+ }
+
+ OpKernelContext::Params params;
+ params.device = device_;
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = kernel_.get();
+ params.resource_manager = device_->resource_manager();
+ params.output_attr_array = gtl::vector_as_array(&out_attrs);
+ params.function_library = flib_;
+ params.slice_reader_cache = &slice_reader_cache_;
+ // TODO(apassos): use a thread pool.
+ std::function<void(std::function<void()>)> runner =
+ [](std::function<void()> f) { f(); };
+ params.runner = &runner;
+
+ OpKernelContext context(&params);
+ device_->Compute(kernel_.get(), &context);
+ if (!context.status().ok()) return context.status();
+
+ output_tensors->clear();
+ for (int i = 0; i < context.num_outputs(); ++i) {
+ output_tensors->push_back(Tensor(*context.mutable_output(i)));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
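A concrete query makes the AttrTypeMap encoding easier to follow: the low bits carry the TF_AttrType and the kIsList bit marks list-valued attributes. A minimal sketch, assuming linkage against the :runtime target from this commit:

    #include "tensorflow/c/eager/runtime.h"
    #include "tensorflow/core/platform/logging.h"

    void QueryMatMulAttrTypes() {
      const tensorflow::AttrTypeMap* types = nullptr;
      TF_CHECK_OK(tensorflow::AttrTypeMapForOp("MatMul", &types));

      TF_AttrType type;
      unsigned char is_list = 0;
      // "transpose_a" is a scalar bool attr, so the list bit comes back clear.
      TF_CHECK_OK(
          tensorflow::AttrTypeByName(types, "transpose_a", &type, &is_list));
      CHECK_EQ(TF_ATTR_BOOL, type);
      CHECK_EQ(0, is_list);
    }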
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
new file mode 100644
index 0000000000..5916302ff4
--- /dev/null
+++ b/tensorflow/c/eager/runtime.h
@@ -0,0 +1,193 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
+#define TENSORFLOW_C_EAGER_RUNTIME_H_
+
+// Support for eager execution of TensorFlow kernels.
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+
+// Maps an attribute name to an encoding of the type of its value. If the
+// attribute is not list-valued, the encoding is simply the TF_AttrType of the
+// value. Otherwise, the highest-order bit is set and the remaining bits
+// encode the TF_AttrType of the list elements.
+typedef std::unordered_map<string, uint32> AttrTypeMap;
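+
+// For example, a plain "int" attribute is encoded simply as TF_ATTR_INT,
+// while a "list(int)" attribute is encoded as TF_ATTR_INT with the top bit
+// set. A hypothetical decoding of one entry (a sketch, not part of this API):
+//
+//   uint32 encoded = ...;  // A value looked up in an AttrTypeMap.
+//   bool is_list = (encoded >> 31) != 0;
+//   TF_AttrType type = static_cast<TF_AttrType>(encoded & ~(1u << 31));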
+
+// Returns the AttrTypeMap for the TensorFlow operation named op_name.
+Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
+
+// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
+Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+ TF_AttrType* out, unsigned char* is_list);
+
+// KernelAndDevice::InitOp and InitFn need a NodeDef only to pass the
+// attribute map through.
+// An AttrBuilder is a convenience class to help with that - providing a smaller
+// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
+// checks (like number of inputs matching the OpDef - we only care about
+// attributes here).
+//
+// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
+// ones make sense to replicate.
+
+// This is a helper class for creating a NodeDef. Additionally, this class
+// allows computing a cache key based on fingerprinting the attributes of this
+// NodeDef.
+//
+// Example usage:
+// AttrBuilder a;
+// a.NumInputs(2);
+// a.Set("T", TF_FLOAT);
+// tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
+// const NodeDef& n = a.BuildNodeDef();
+//
+// Note that all calls to Set and NumInputs should happen before calling
+// BuildNodeDef. Also, calls to NumInputs or Set between two invocations of
+// CacheKey may cause CacheKey to return different values.
+//
+// For performance reasons, the class internally delays the actual construction
+// of the NodeDef till BuildNodeDef is called, or till Set is called with a
+// type that has no template specialization below: the specialized overloads of
+// Set buffer common attribute types without creating a NodeDef, while the
+// generic Set creates one immediately.
+class AttrBuilder {
+ public:
+ explicit AttrBuilder(const char* op)
+ : op_name_(op),
+ num_inputs_(0),
+ node_def_(nullptr),
+ node_def_finalized_(false) {}
+
+ // Needed to work around the call to ValidateNodeDef in CreateOpKernel.
+ AttrBuilder& NumInputs(int n);
+
+ template <class T>
+ AttrBuilder& Set(StringPiece attr_name, T&& value) {
+ MayBeInitializeNodeDef();
+ return SetInNodeDef(attr_name, std::forward<T>(value));
+ }
+
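+ // Fingerprints 'device' together with all attributes set so far. See the
+ // class comment above for caveats about calling Set or NumInputs between
+ // invocations of CacheKey.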
+ tensorflow::Fprint128 CacheKey(const string& device) const;
+
+ const NodeDef& BuildNodeDef();
+
+ private:
+ template <class T>
+ using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
+
+ void MayBeInitializeNodeDef();
+
+ template <class T>
+ AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) {
+ DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef.";
+ // Copied from NodeDefBuilder::Attr
+ const AttrValue* found = AttrSlice(*node_def_).Find(attr_name);
+ if (found == nullptr) {
+ AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get());
+ } else {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ // TODO(ashankar): Do what is done in
+ // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
+ }
+ return *this;
+ }
+
+ AttrVec<StringPiece> string_attrs_;
+ AttrVec<int> int_attrs_;
+ AttrVec<float> float_attrs_;
+ AttrVec<bool> bool_attrs_;
+ AttrVec<tensorflow::DataType> type_attrs_;
+ string op_name_;
+ int num_inputs_;
+ std::unique_ptr<NodeDef> node_def_;
+ bool node_def_finalized_;
+}; // class AttrBuilder
+
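+// Specializations of Set for common attribute types; these buffer the value
+// instead of creating a NodeDef (see the class comment).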
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
+template <>
+AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
+ tensorflow::DataType&& value);
+
+// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
+//
+// Also see:
+// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+// and
+// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
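+//
+// A minimal usage sketch, mirroring the unit test in runtime_test.cc
+// (assumes a valid Device* named "device"):
+//
+//   NodeDef ndef(AttrBuilder("MatMul")
+//                    .Set("T", DT_FLOAT)
+//                    .Set("transpose_a", false)
+//                    .Set("transpose_b", false)
+//                    .NumInputs(2)
+//                    .BuildNodeDef());
+//   KernelAndDevice kernel;
+//   TF_CHECK_OK(KernelAndDevice::InitOp(device, ndef, &kernel));
+//   std::vector<Tensor> inputs = ...;  // Two DT_FLOAT matrices.
+//   std::vector<Tensor> outputs;
+//   TF_CHECK_OK(kernel.Run(&inputs, &outputs));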
+class KernelAndDevice {
+ public:
+ // Populates 'out' with a kernel appropriate for 'ndef'.
+ //
+ // Assumes that 'ndef' refers to a primitive op (as opposed to a function).
+ static Status InitOp(Device* device, const NodeDef& ndef,
+ KernelAndDevice* out);
+
+ // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
+ // TensorFlow function in the FunctionLibraryRuntime).
+ //
+ // The provided FunctionLibraryRuntime MUST outlive all calls to
+ // Run() on the returned KernelAndDevice.
+ //
+ // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
+ // The implementation of InitFn should work for both because
+ // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
+ // appropriate. However, for now we keep them separate because I haven't
+ // figured out thread-safety concerns around FunctionLibraryRuntime (in
+ // particular, how the underlying FunctionLibraryDefinition might be mutated
+ // by another thread as new functions are registered with it).
+ // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
+ // on to the caller (see locking in c_api.cc) for now. But I really should
+ // dig into this so that both InitOp and InitFn can be collapsed to
+ // FunctionLibraryRuntime::CreateKernel.
+ static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+ KernelAndDevice* out);
+
+ KernelAndDevice() : device_(nullptr), flib_(nullptr) {}
+
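+ // Executes the kernel synchronously on the kernel's device, reading from
+ // 'inputs' and replacing the contents of 'outputs'.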
+ // TODO(ashankar): Handle list-valued inputs.
+ Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
+
+ const OpKernel* kernel() const { return kernel_.get(); }
+
+ private:
+ std::unique_ptr<OpKernel> kernel_;
+ tensorflow::Device* device_;
+ tensorflow::FunctionLibraryRuntime* flib_;
+ tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_C_EAGER_RUNTIME_H_
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
new file mode 100644
index 0000000000..3b38e24704
--- /dev/null
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/runtime.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Device* CPUDevice() {
+ return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+}
+
+TEST(AttrTypeMap, Lookup) {
+ const AttrTypeMap* m = nullptr;
+ Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m);
+ EXPECT_FALSE(s.ok());
+ s = AttrTypeMapForOp("MatMul", &m);
+ ASSERT_TRUE(s.ok()) << s;
+
+ TF_AttrType t;
+ unsigned char is_list = 1;
+ s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
+ EXPECT_FALSE(s.ok());
+ EXPECT_NE(is_list, 0);
+ s = AttrTypeByName(m, "transpose_a", &t, &is_list);
+ ASSERT_TRUE(s.ok()) << s;
+ EXPECT_EQ(TF_ATTR_BOOL, t);
+ EXPECT_EQ(is_list, 0);
+
+ s = AttrTypeMapForOp("Squeeze", &m);
+ ASSERT_TRUE(s.ok()) << s;
+ s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
+ ASSERT_TRUE(s.ok()) << s;
+ EXPECT_EQ(TF_ATTR_INT, t);
+ EXPECT_NE(is_list, 0);
+}
+
+TEST(KernelAndDevice, Run) {
+ Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
+ std::vector<Tensor> inputs;
+ inputs.push_back(t);
+ inputs.push_back(t);
+ NodeDef ndef(AttrBuilder("MatMul")
+ .Set("T", DT_FLOAT)
+ .Set("transpose_a", false)
+ .Set("transpose_b", false)
+ .NumInputs(inputs.size())
+ .BuildNodeDef());
+ std::unique_ptr<Device> device(CPUDevice());
+ KernelAndDevice kernel;
+ Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
+ ASSERT_TRUE(s.ok()) << s;
+ std::vector<Tensor> outputs;
+ s = kernel.Run(&inputs, &outputs);
+ ASSERT_TRUE(s.ok()) << s;
+ ASSERT_EQ(1, outputs.size());
+ const Tensor& out = outputs[0];
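+ // Expected product: [[1,2],[3,4]] * [[1,2],[3,4]] = [[7,10],[15,22]].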
+ EXPECT_EQ(7, out.matrix<float>()(0, 0));
+ EXPECT_EQ(10, out.matrix<float>()(0, 1));
+ EXPECT_EQ(15, out.matrix<float>()(1, 0));
+ EXPECT_EQ(22, out.matrix<float>()(1, 1));
+}
+
+// TODO(apassos) uncomment after rewriting to use the right benchmark API
+// void BM_CreateGraph(benchmark::State& state) {
+// for (auto _ : state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// TF_CHECK_OK(root.status());
+// }
+// }
+// BENCHMARK(BM_CreateGraph);
+
+// void BM_RunGraph(benchmark::State& state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// SessionOptions opts;
+// opts.config.set_inter_op_parallelism_threads(1);
+// opts.config.set_intra_op_parallelism_threads(1);
+// ClientSession sess(root, opts);
+// std::vector<Tensor> outputs;
+// for (auto _ : state) {
+// outputs.clear();
+// TF_CHECK_OK(sess.Run({M}, &outputs));
+// }
+// }
+// BENCHMARK(BM_RunGraph);
+
+// void BM_CreateAndDestroySession(benchmark::State& state) {
+// Scope root = Scope::NewRootScope();
+// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
+// auto M = ops::MatMul(root, C, C);
+// for (auto _ : state) {
+// ClientSession sess(root);
+// }
+// }
+// BENCHMARK(BM_CreateAndDestroySession);
+
+// void BM_KernelAndDeviceInit(benchmark::State& state) {
+// NodeDef ndef(AttrBuilder("MatMul")
+// .Set("T", DT_FLOAT)
+// .Set("transpose_a", false)
+// .Set("transpose_b", false)
+// .NumInputs(2)
+// .BuildNodeDef());
+// std::unique_ptr<Device> device(CPUDevice());
+// KernelAndDevice k;
+// for (auto _ : state) {
+// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
+// }
+// }
+// BENCHMARK(BM_KernelAndDeviceInit);
+
+// void BM_KernelAndDeviceRun(benchmark::State& state) {
+// Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
+// std::vector<Tensor> inputs;
+// inputs.push_back(t);
+// inputs.push_back(t);
+// std::vector<Tensor> outputs;
+// NodeDef ndef(AttrBuilder("MatMul")
+// .Set("T", DT_FLOAT)
+// .Set("transpose_a", false)
+// .Set("transpose_b", false)
+// .NumInputs(inputs.size())
+// .BuildNodeDef());
+// std::unique_ptr<Device> device(CPUDevice());
+// KernelAndDevice kernel;
+// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
+// for (auto _ : state) {
+// TF_CHECK_OK(kernel.Run(&inputs, &outputs));
+// }
+// }
+// BENCHMARK(BM_KernelAndDeviceRun);
+} // namespace
+} // namespace tensorflow