diff options
author | 2017-08-10 14:19:55 -0700 | |
---|---|---|
committer | 2017-08-10 14:22:58 -0700 | |
commit | 13eb3b90e9ed8778ffd2b1bf6401677938b1ec39 (patch) | |
tree | 40a2e7e926f3ed9fa0b99f88056bacc471547be7 /tensorflow/c | |
parent | 7dfabcc01c9c752747c473346bb3f8c1cd290ad1 (diff) |
Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.
PiperOrigin-RevId: 164902588
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/BUILD | 67 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 561 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 159 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 463 | ||||
-rw-r--r-- | tensorflow/c/eager/runtime.cc | 289 | ||||
-rw-r--r-- | tensorflow/c/eager/runtime.h | 193 | ||||
-rw-r--r-- | tensorflow/c/eager/runtime_test.cc | 160 |
7 files changed, 1892 insertions, 0 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD new file mode 100644 index 0000000000..18e1a68a93 --- /dev/null +++ b/tensorflow/c/eager/BUILD @@ -0,0 +1,67 @@ +# Experimental extensions to the C API for eager execution of kernels. +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + visibility = [ + "//tensorflow:internal", + "//tensorflow/python/eager:__pkg__", + ], + deps = [ + ":runtime", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "c_api_test", + srcs = ["c_api_test.cc"], + deps = [ + ":c_api", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "runtime", + srcs = ["runtime.cc"], + hdrs = ["runtime.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "runtime_test", + srcs = ["runtime_test.cc"], + deps = [ + ":runtime", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc new file mode 100644 index 0000000000..05e09ea120 --- /dev/null +++ b/tensorflow/c/eager/c_api.cc @@ -0,0 +1,561 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include <algorithm> +#include <cstddef> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +using tensorflow::int64; +using tensorflow::string; + +namespace { +bool IsCPU(tensorflow::Device* d) { + return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; +} + +string DeviceName(tensorflow::Device* d) { + return (d == nullptr) ? "cpu:0" : d->name(); +} +} // namespace + +struct TFE_Context { + explicit TFE_Context(TF_Session* s) : session(s) {} + + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. + TF_Session* session; + + tensorflow::mutex functions_mu; + tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ + tensorflow::OpRegistry::Global(), {}}; + + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime> > func_libs; + + std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*, + tensorflow::Fprint128Hasher> + kernel_cache; + + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + for (int i = 0; i < session->devices.size(); ++i) { + if (session->devices[i] == d) { + return func_libs[i].get(); + } + } + return nullptr; + } + + const std::vector<tensorflow::Device*>& devices() { return session->devices; } +}; + +struct TFE_TensorHandle { + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) + : t(t), d(d) {} + + tensorflow::Tensor t; + // TODO(ashankar): d == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('d' should always be a + // valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. + // + // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* d; +}; + +struct TFE_Op { + TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) + : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + + bool const is_function() const { return attr_types == nullptr; } + + TFE_Context* ctx; // Must outlive the TFE_Op. + const char* name; + tensorflow::AttrBuilder attrs; + const tensorflow::AttrTypeMap* attr_types; + std::vector<tensorflow::Tensor> inputs; + std::vector<tensorflow::Device*> input_devices; + tensorflow::Device* device; +}; + +extern "C" { + +TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { + TF_Graph* graph = TF_NewGraph(); + TF_Session* session = TF_NewSession(graph, opts, status); + if (status->status.ok()) { + if (session->device_mgr == nullptr || session->devices.empty()) { + status->status = tensorflow::errors::InvalidArgument( + "Provided TF_SessionOptions are not compatible with eager execution " + "(perhaps the TF_SessionOptions alluded to session execution in a " + "remote address space?)"); + } + } + if (!status->status.ok()) { + TF_DeleteGraph(graph); + return nullptr; + } + + TFE_Context* ret = new TFE_Context(session); + ret->func_libs.resize(ret->devices().size()); + for (int i = 0; i < ret->devices().size(); ++i) { + ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime( + ret->session->device_mgr, opts->options.env, ret->devices()[i], + TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}); + } + + return ret; +} + +void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { + status->status = tensorflow::Status::OK(); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + TF_Graph* graph = ctx->session->graph; + TF_DeleteSession(ctx->session, status); + TF_DeleteGraph(graph); + delete ctx; +} + +TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { + return TF_SessionListDevices(ctx->session, status); +} + +TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) { + return new TFE_TensorHandle( + tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer), + nullptr); +} + +void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } + +TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { + return static_cast<TF_DataType>(h->t.dtype()); +} + +int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } + +int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { + return h->t.dim_size(dim_index); +} + +const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { + // This might be a bit confusing as a tensor on CPU can sometimes return + // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0". + // TODO(ashankar): Figure out which one would be nicer. + return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str(); +} + +TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (!IsCPU(h->d)) { + TF_SetStatus(status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat( + "TFE_TensorHandle can be resolved iff it is on CPU (this " + "handle is on ", + h->d->name(), + "). Consider using TFE_TensorHandleCopyToDevice to get a " + "copy of the tensor on CPU") + .c_str()); + return nullptr; + } + return tensorflow::TF_TensorFromTensor(h->t, status); +} + +TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status) { + tensorflow::Device* dstd = nullptr; + status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd); + if (!status->status.ok()) return nullptr; + + tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d; + const bool src_cpu = IsCPU(srcd); + const bool dst_cpu = IsCPU(dstd); + if (!src_cpu && !dst_cpu) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat( + "TFE_TensorHandleCopyToDevice requires either the source " + "TFE_TensorHandle be on or the destination device be CPU (they " + "are ", + DeviceName(srcd), " and ", DeviceName(dstd), " in this call)") + .c_str()); + return nullptr; + } + tensorflow::Tensor* src = &(h->t); + if (src_cpu && dst_cpu) { + // There must be a better way, but for now redirect through proto to ensure + // that the underlying buffers are not shared. + tensorflow::TensorProto proto; + src->AsProtoTensorContent(&proto); + tensorflow::Tensor dst(src->dtype(), src->shape()); + if (!dst.FromProto(proto)) { + TF_SetStatus( + status, TF_INTERNAL, + tensorflow::strings::StrCat( + "error copying between TFE_TensorHandles on CPU. Consider filing " + "a bug report at https://github.com/tensorflow/tensorflow/issues " + "mentioning version: ", + TF_Version(), " and ", __FILE__, ":", __LINE__) + .c_str()); + return nullptr; + } + return new TFE_TensorHandle(dst, nullptr); + } + if (src_cpu) { + tensorflow::Tensor dst( + dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(), + src->shape()); + tensorflow::Notification n; + dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice( + src, dstd, &dst, [status, &n](const tensorflow::Status& s) { + status->status = s; + n.Notify(); + }); + n.WaitForNotification(); + return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd) + : nullptr; + } + CHECK(dst_cpu); + tensorflow::Tensor dst(src->dtype(), src->shape()); + tensorflow::Notification n; + // TODO(ashankar): The Sync() call below may be more aggressive than + // necessary. It is based on knowledge of implementation details - that + // GPU devices are implemented using 3 streams - one for host->device copies, + // one for device->host copies and one for sending operations to the GPU. + // With that setup, Sync()ing across all 3 streams should be sufficient + // but more than necessary (since it waits for operations that might have + // nothing to do with this tensor to complete). + status->status = srcd->Sync(); + if (!status->status.ok()) return nullptr; + srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU( + src, "IGNORE_MY_TENSOR_NAME", srcd, &dst, + [status, &n](const tensorflow::Status& s) { + status->status = s; + n.Notify(); + }); + n.WaitForNotification(); + return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr) + : nullptr; +} + +TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status) { + const char* name = op_or_function_name; // Shorthand + const tensorflow::AttrTypeMap* types; + status->status = tensorflow::AttrTypeMapForOp(name, &types); + if (status->status.ok()) return new TFE_Op(ctx, name, types); + if (TF_GetCode(status) == TF_NOT_FOUND) { + tensorflow::mutex_lock l(ctx->functions_mu); + if (ctx->func_lib_def.Find(name) != nullptr) { + status->status = tensorflow::Status::OK(); + return new TFE_Op(ctx, name, nullptr); + } + } + return nullptr; +} + +void TFE_DeleteOp(TFE_Op* op) { delete op; } + +static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device, + TF_Status* status) { + // Questionable heuristic: Place the op on the same device as the first input + // placed outside of host memory? + if (IsCPU(op->device) && !IsCPU(device)) { + op->device = device; + } +} + +void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, const char* device_name, + TF_Status* status) { + tensorflow::Device* d = nullptr; + status->status = ctx->session->device_mgr->LookupDevice(device_name, &d); + if (!status->status.ok()) return; + TFE_OpSetDeviceHelper(op, d, status); +} + +void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { + TFE_OpSetDeviceHelper(op, h->d, status); + if (!status->status.ok()) return; + op->inputs.push_back(h->t); + op->input_devices.push_back(h->d); + op->attrs.NumInputs(op->inputs.size()); +} + +TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, + unsigned char* is_list, TF_Status* status) { + TF_AttrType ret; + if (op->is_function()) { + status->status = tensorflow::errors::Unimplemented( + "TODO(apassos): Support for attributes for TensorFlow functions is not " + "ready yet."); + return TF_ATTR_INT; // The compiler requires that we return something. + } + status->status = + tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list); + return ret; +} + +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { + op->attrs.Set(attr_name, value); +} + +void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { + op->attrs.Set(attr_name, static_cast<int64>(value)); +} + +void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { + op->attrs.Set(attr_name, value); +} + +void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { + op->attrs.Set(attr_name, (value == 0) ? false : true); +} + +void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { + op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value)); +} + +void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, + const int num_dims, TF_Status* out_status) { + if (num_dims > tensorflow::TensorShape::MaxDimensions()) { + TF_SetStatus(out_status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat( + "Value specified for `", attr_name, "` has ", num_dims, + " dimensions which is over the limit of ", + tensorflow::TensorShape::MaxDimensions(), ".") + .c_str()); + return; + } + tensorflow::TensorShapeProto proto; + if (num_dims < 0) { + proto.set_unknown_rank(true); + } else { + for (int d = 0; d < num_dims; ++d) { + proto.add_dim()->set_size(dims[d]); + } + } + op->attrs.Set(attr_name, proto); +} + +#define TFE_OP_SET_ATTR_LIST(fn, type) \ + void fn(TFE_Op* op, const char* attr_name, const type* values, \ + int num_values) { \ + op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \ + values, num_values)); \ + } +TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) +TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) +#undef TFE_OP_SET_ATTR_LIST + +void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, + const int64_t* values, int num_values) { + op->attrs.Set(attr_name, + tensorflow::gtl::ArraySlice<const int64>( + reinterpret_cast<const int64*>(values), num_values)); +} + +void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, + const TF_DataType* values, int num_values) { + op->attrs.Set( + attr_name, + tensorflow::gtl::ArraySlice<const tensorflow::DataType>( + reinterpret_cast<const tensorflow::DataType*>(values), num_values)); +} + +void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, + const unsigned char* values, int num_values) { + std::unique_ptr<bool[]> b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + op->attrs.Set(attr_name, + tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values)); +} + +void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, TF_Status* out_status) { + std::unique_ptr<tensorflow::TensorShapeProto[]> proto( + new tensorflow::TensorShapeProto[num_values]); + for (int i = 0; i < num_values; ++i) { + const auto num_dims_i = num_dims[i]; + + if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) { + TF_SetStatus(out_status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat( + "Value specified for `", attr_name, "` has ", num_dims_i, + " dimensions which is over the limit of ", + tensorflow::TensorShape::MaxDimensions(), ".") + .c_str()); + return; + } + if (num_dims_i < 0) { + proto[i].set_unknown_rank(true); + } else { + const int64_t* dims_i = dims[i]; + auto proto_i = &proto[i]; + for (int d = 0; d < num_dims_i; ++d) { + proto_i->add_dim()->set_size(dims_i[d]); + } + } + } + op->attrs.Set(attr_name, + tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>( + proto.get(), num_values)); +} + +namespace { + +tensorflow::Status ValidateInputTypeAndPlacement( + tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel) { + const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); + if (memtypes.size() != op->inputs.size()) { + return tensorflow::errors::InvalidArgument( + "expected ", memtypes.size(), " inputs, got ", op->inputs.size()); + } + for (int i = 0; i < op->inputs.size(); ++i) { + const tensorflow::Device* expected_device = + memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; + const tensorflow::Device* actual_device = + op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; + if (expected_device != actual_device) { + return tensorflow::errors::InvalidArgument( + "cannot compute ", op->name, " as input #", i, + " was expected to be on ", expected_device->name(), + " but is actually on ", actual_device->name(), + " (operation running on ", op_device->name(), ")"); + } + if (op->inputs[i].dtype() != kernel->input_type(i)) { + return tensorflow::errors::InvalidArgument( + "cannot compute ", op->name, " as input #", i, + " was expected to be a ", + tensorflow::DataType_Name(kernel->input_type(i)), " tensor but is a ", + tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor"); + } + } + return tensorflow::Status::OK(); +} +} // namespace + +void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status) { + TFE_Context* ctx = op->ctx; + // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU + tensorflow::Device* device = + (op->device == nullptr) ? ctx->devices()[0] : op->device; + std::vector<tensorflow::Tensor> outputs(1); + const tensorflow::MemoryTypeVector* output_memory_types = nullptr; + tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); + tensorflow::KernelAndDevice* kernel = + tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + if (kernel == nullptr) { + const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); + kernel = new tensorflow::KernelAndDevice(); + if (!op->is_function()) { + status->status = + tensorflow::KernelAndDevice::InitOp(device, ndef, kernel); + } else { + // Knowledge of the implementation of InitFn (and in-turn + // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def + // will be accessed, so grab on to the lock. + // See WARNING comment below - would be nice to rework to avoid this + // subtlety. + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = tensorflow::KernelAndDevice::InitFn( + ndef, ctx->func_lib(device), kernel); + } + if (!status->status.ok()) { + return; + } + tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); + } + status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op, + kernel->kernel()); + output_memory_types = &kernel->kernel()->output_memory_types(); + if (!status->status.ok()) { + return; + } + // WARNING: kernel->Run utilizes the FunctionLibraryRuntime + // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, + // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation + // of FunctionLibraryRuntime tells use that func_lib_def is not accessed by + // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. + // This is quite subtle. Re-work things to make this better? (Would it make + // sense for FunctionLibraryRuntime to ensure thread-safe access to + // FunctionLibraryDefinition?). + status->status = kernel->Run(&op->inputs, &outputs); + if (!status->status.ok()) return; + *num_retvals = std::min<int>(*num_retvals, outputs.size()); + for (int i = 0; i < *num_retvals; ++i) { + tensorflow::Device* d = IsCPU(device) ? nullptr : device; + if (d != nullptr && output_memory_types != nullptr && + (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { + d = nullptr; + } + retvals[i] = new TFE_TensorHandle(outputs[i], d); + } +} + +void TFE_ContextAddFunctionDef(TFE_Context* ctx, + const char* serialized_function_def, size_t size, + TF_Status* status) { + tensorflow::FunctionDef function_def; + if (!function_def.ParseFromArray(serialized_function_def, size)) { + status->status = + tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); + return; + } + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function_def); +} + +} // extern "C" + +TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { + return new TFE_TensorHandle(t, nullptr); +} + +const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( + TFE_TensorHandle* h, TF_Status* status) { + if (h->d != nullptr) { + status->status = tensorflow::errors::FailedPrecondition( + "TFE_TensorHandle is placed in device (not host) memory. Cannot return " + "a tensorflow::Tensor"); + return nullptr; + } + return &h->t; +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h new file mode 100644 index 0000000000..66a5e43bfc --- /dev/null +++ b/tensorflow/c/eager/c_api.h @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_H_ +#define TENSORFLOW_C_EAGER_C_API_H_ + +// C API extensions to experiment with eager execution of kernels. + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// "Context" under which operations/functions are executed. It encapsulates +// things like the available devices, resource manager etc. +// +// TODO(ashankar): Merge with TF_Session? +typedef struct TFE_Context TFE_Context; + +extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, + TF_Status* status); +extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); +extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, + TF_Status* status); + +// A handle to a tensor on a device. +// +// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, +// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors +// placed in memory of different devices or remote address spaces. +typedef struct TFE_TensorHandle TFE_TensorHandle; + +extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t); +extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); +extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); +extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index); +extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h); +extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, + TF_Status* status); + +// Create a new TFE_TensorHandle with the same contents as 'h' but placed +// in the memory of the device name 'device_name'. +// +// Currently requires at least one of the source or destination devices to +// be CPU (i.e., for the source or destination tensor to be placed in +// host memory). +extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status); + +// Description of the TensorFlow op to execute. +// +// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., +// TFE_DeleteOp() is called before TFE_DeleteContext(). +// +// Very similar to TF_OperationDescription with some differences: +// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput, +// TF_AddInputList +// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense. +// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since +// the additional sanity checks there seem unnecessary; +typedef struct TFE_Op TFE_Op; + +extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status); +extern void TFE_DeleteOp(TFE_Op* op); + +// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context +// parameter. Instead, the TFE_Context should be captured when creating the +// TFE_Op. +extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, + const char* device_name, TF_Status* status); + +extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); + +extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, + unsigned char* is_list, TF_Status* status); + +extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, + const char* value); +extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); +extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value); +extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, + unsigned char value); +extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, + TF_DataType value); +// If the number of dimensions is unknown, `num_dims` must be set to +// -1 and `dims` can be null. If a dimension is unknown, the +// corresponding entry in the `dims` array must be -1. +extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, + const int64_t* dims, const int num_dims, + TF_Status* out_status); + +extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const char** value, int num_values); +extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, + const int64_t* values, int num_values); +extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values); +extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, + const unsigned char* values, int num_values); +extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, + const TF_DataType* values, int num_values); +extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, TF_Status* out_status); + +// Execute the operation defined by 'op' and return handles to computed +// tensors in 'retvals'. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* +// and '*num_retvals' should be set to the size of this array. +// +// On return, 'num_retvals' will be set to the actual number of outputs +// returned by the operation. +extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); + +// Add a function (serialized FunctionDef protocol buffer) to ctx so +// that it can be invoked using TFE_Execute. +extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, + const char* serialized_function_def, + size_t size, TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#ifdef __cplusplus +// A workaround to ease conversion to and from numpy objects and +// TFE_TensorHandle's. +// +// TODO(ashankar): Figure out an alternative scheme that precludes the need for +// these API-boundary breaking methods. +namespace tensorflow { +class Tensor; +} // namespace tensorflow + +const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( + TFE_TensorHandle* h, TF_Status* status); +TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t); +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc new file mode 100644 index 0000000000..797506422b --- /dev/null +++ b/tensorflow/c/eager/c_api_test.cc @@ -0,0 +1,463 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include <string.h> +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using tensorflow::string; + +namespace { + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandle(t); + TF_DeleteTensor(t); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +// TODO(apassos) uncomment after rewriting to use the right benchmark API +// void BM_InitOp(benchmark::State& state) { +// TF_Status* status = TF_NewStatus(); +// TF_SessionOptions* opts = TF_NewSessionOptions(); +// TFE_Context* ctx = TFE_NewContext(opts, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteSessionOptions(opts); + +// TFE_TensorHandle* m = TestMatrixTensorHandle(); +// for (auto _ : state) { +// TFE_Op* matmul = MatMulOp(ctx, m, m); +// TFE_DeleteOp(matmul); +// } +// TFE_DeleteTensorHandle(m); +// TFE_DeleteContext(ctx, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteStatus(status); +// } +// BENCHMARK(BM_InitOp); + +// void BM_Execute(benchmark::State& state) { +// TF_Status* status = TF_NewStatus(); +// TF_SessionOptions* opts = TF_NewSessionOptions(); +// TFE_Context* ctx = TFE_NewContext(opts, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteSessionOptions(opts); + +// TFE_TensorHandle* m = TestMatrixTensorHandle(); +// TFE_Op* matmul = MatMulOp(ctx, m, m); +// TFE_TensorHandle* retvals[1]; +// int num_retvals = 1; +// for (auto _ : state) { +// TFE_Execute(matmul, &retvals[0], &num_retvals, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// } +// TFE_DeleteOp(matmul); +// TFE_DeleteTensorHandle(m); +// TFE_DeleteContext(ctx, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteStatus(status); +// } +// BENCHMARK(BM_Execute); + +TEST(CAPI, Context) { + TF_Status* status = TF_NewStatus(); + TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + TF_DeleteSessionOptions(opts); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const int num_devices = TF_DeviceListCount(devices); + EXPECT_GE(num_devices, 1) << "At least one CPU device should exist"; + for (int i = 0; i < num_devices; ++i) { + EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i; + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + } + TF_DeleteDeviceList(devices); + TF_DeleteStatus(status); +} + +TEST(CAPI, TensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(h); +} + +TEST(CAPI, TensorHandleCopyBetweenDevices) { + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TF_DeleteSessionOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + const char* kCPUDevice = "CPU:0"; + for (int i = 0; i < num_devices; ++i) { + const string name(TF_DeviceListName(devices, i, status.get())); + if (TF_GetCode(status.get()) != TF_OK) { + ADD_FAILURE() << i << " -- " << TF_Message(status.get()); + continue; + } + auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")"); + // Copy to device + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + if (TF_GetCode(status.get()) != TF_OK) { + ADD_FAILURE() << tag << " -- " << TF_Message(status.get()); + continue; + } + // Copy back to CPU + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get()); + if (TF_GetCode(status.get()) != TF_OK) { + ADD_FAILURE() << tag << " -- " << TF_Message(status.get()); + continue; + } + TFE_DeleteTensorHandle(hdevice); + + // Ensure that the contents are the same! + TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get()); + TFE_DeleteTensorHandle(hcopy); + if (TF_GetCode(status.get()) != TF_OK) { + ADD_FAILURE() << tag; + continue; + } + EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag; + EXPECT_EQ( + 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t))) + << tag; + TF_DeleteTensor(tcopy); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + +TEST(CAPI, Execute) { + TF_Status* status = TF_NewStatus(); + TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteSessionOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} + +string MatMulFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'a'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }", + &def)); + return def.SerializeAsString(); +} + +TEST(CAPI, FunctionDefAndExecute) { + TF_Status* status = TF_NewStatus(); + TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteSessionOptions(opts); + + string function_def = MatMulFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_TensorHandle* retval[1] = {nullptr}; + int num_retvals = 1; + TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, m, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Execute(op, &retval[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TFE_DeleteOp(op); + TFE_DeleteTensorHandle(m); + TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status); + TFE_DeleteTensorHandle(retval[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +// TODO(apassos) uncomment after rewriting to use the right benchmark API +// void BM_ExecuteFunction(benchmark::State& state) { +// TF_Status* status = TF_NewStatus(); +// TF_SessionOptions* opts = TF_NewSessionOptions(); +// TFE_Context* ctx = TFE_NewContext(opts, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteSessionOptions(opts); + +// string function_def = MatMulFunction(); +// TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), +// status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + +// TFE_TensorHandle* m = TestMatrixTensorHandle(); +// TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TFE_OpAddInput(matmul, m, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TFE_TensorHandle* retval[1] = {nullptr}; +// int num_retvals = 1; +// for (auto _ : state) { +// TFE_Execute(matmul, &retval[0], &num_retvals, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// } +// TFE_DeleteTensorHandle(m); +// TFE_DeleteTensorHandle(retval[0]); +// TFE_DeleteContext(ctx, status); +// EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteStatus(status); +// } +// BENCHMARK(BM_ExecuteFunction); + +// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, +// TF_Status* status) { +// // Create the variable handle. +// TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); +// if (TF_GetCode(status) != TF_OK) return nullptr; +// TFE_OpSetAttrType(op, "dtype", TF_FLOAT); +// TFE_OpSetAttrShape(op, "shape", {}, 0, status); +// TFE_OpSetAttrString(op, "container", ""); +// TFE_OpSetAttrString(op, "shared_name", ""); +// if (TF_GetCode(status) != TF_OK) return nullptr; +// TFE_TensorHandle* var_handle = nullptr; +// int num_retvals = 1; +// TFE_Execute(op, &var_handle, &num_retvals, status); +// TFE_DeleteOp(op); +// if (TF_GetCode(status) != TF_OK) return nullptr; +// CHECK_EQ(1, num_retvals); + +// // Assign 'value' to it. +// op = TFE_NewOp(ctx, "AssignVariableOp", status); +// if (TF_GetCode(status) != TF_OK) return nullptr; +// TFE_OpSetAttrType(op, "dtype", TF_FLOAT); +// TFE_OpAddInput(op, var_handle, status); + +// // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. +// std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t( +// TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), +// TF_DeleteTensor); +// memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + +// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> +// value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle); + +// TFE_OpAddInput(op, value_handle.get(), status); +// if (TF_GetCode(status) != TF_OK) return nullptr; + +// num_retvals = 0; +// TFE_Execute(op, nullptr, &num_retvals, status); +// TFE_DeleteOp(op); +// if (TF_GetCode(status) != TF_OK) return nullptr; +// CHECK_EQ(0, num_retvals); + +// return var_handle; +// } + +// TEST(CAPI, Variables) { +// // Variables use resource handles, so this is really a test for resource +// // tensor handling. +// TF_Status* status = TF_NewStatus(); +// TF_SessionOptions* opts = TF_NewSessionOptions(); +// TFE_Context* ctx = TFE_NewContext(opts, status); +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteSessionOptions(opts); + +// TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + +// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TFE_OpSetAttrType(op, "dtype", TF_FLOAT); +// TFE_OpAddInput(op, var_handle, status); +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// int num_retvals = 1; +// TFE_TensorHandle* value_handle = nullptr; +// TFE_Execute(op, &value_handle, &num_retvals, status); +// TFE_DeleteOp(op); + +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// ASSERT_EQ(1, num_retvals); +// EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle)); +// EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle)); +// float value = 0.0f; +// TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status); +// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// ASSERT_EQ(sizeof(float), TF_TensorByteSize(t)); +// memcpy(&value, TF_TensorData(t), sizeof(float)); +// TF_DeleteTensor(t); +// EXPECT_EQ(12.0, value); + +// TFE_DeleteTensorHandle(var_handle); +// TFE_DeleteTensorHandle(value_handle); +// TFE_DeleteContext(ctx, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteStatus(status); +// } + +// void BM_ReadVariable(benchmark::State& state) { +// TF_Status* status = TF_NewStatus(); +// TF_SessionOptions* opts = TF_NewSessionOptions(); +// TFE_Context* ctx = TFE_NewContext(opts, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteSessionOptions(opts); + +// TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + +// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TFE_OpSetAttrType(op, "dtype", TF_FLOAT); +// TFE_OpAddInput(op, var_handle, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + +// int num_retvals = 1; +// TFE_TensorHandle* h = nullptr; +// for (auto _ : state) { +// TFE_Execute(op, &h, &num_retvals, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// CHECK_EQ(1, num_retvals); +// CHECK(h); +// CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); +// CHECK_EQ(0, TFE_TensorHandleNumDims(h)); +// h = nullptr; +// } +// TFE_DeleteOp(op); + +// TFE_DeleteTensorHandle(var_handle); +// TFE_DeleteContext(ctx, status); +// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +// TF_DeleteStatus(status); +// } +// BENCHMARK(BM_ReadVariable); + +} // namespace diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc new file mode 100644 index 0000000000..87e9f5377f --- /dev/null +++ b/tensorflow/c/eager/runtime.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/runtime.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { +namespace { + +mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED); + +std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() { + static auto* const m = new std::unordered_map<string, const AttrTypeMap*>; + return m; +} + +const uint32 kIsList = 1U << 31; + +} // namespace + +Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { + mutex_lock l(g_op_name_to_attr_type_map_lock); + *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); + if (*out != nullptr) return Status::OK(); + const OpRegistrationData* op_reg_data = nullptr; + Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (!s.ok()) return s; + std::unique_ptr<AttrTypeMap> m(new AttrTypeMap); + // TODO(agarwal): Avoid having to create this "registry" at runtime, + // perhaps can be done at op registration time? + for (const auto& attr : op_reg_data->op_def.attr()) { + string type = attr.type(); + const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); + if (is_list) { + type = type.substr(5, type.length() - 6); + } + uint32 t = is_list ? kIsList : 0; + if (type == "string") { + t |= TF_ATTR_STRING; + } else if (type == "int") { + t |= TF_ATTR_INT; + } else if (type == "float") { + t |= TF_ATTR_FLOAT; + } else if (type == "bool") { + t |= TF_ATTR_BOOL; + } else if (type == "type") { + t |= TF_ATTR_TYPE; + } else if (type == "shape") { + t |= TF_ATTR_SHAPE; + } else if (type == "tensor") { + t |= TF_ATTR_TENSOR; + } else { + return errors::Unimplemented( + "TODO(agarwal): Enable support for ops with attributes of type '", + type, "'"); + } + gtl::InsertIfNotPresent(m.get(), attr.name(), t); + } + *out = m.get(); + (*OpNameToAttrTypeMap())[op_name] = m.release(); + return Status::OK(); +} + +Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list) { + CHECK(m); + auto* t = gtl::FindOrNull(*m, attr_name); + if (t == nullptr) { + return errors::InvalidArgument("Attribute '", attr_name, + "' does not exist for this operation"); + } + *out = static_cast<TF_AttrType>(*t & ~kIsList); + if (*t & kIsList) { + *is_list = 1; + } else { + *is_list = 0; + } + return Status::OK(); +} + +#define DEFINE_SET_ATTR(value_type, value_field) \ + template <> \ + AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \ + value_field.push_back(std::make_pair(attr_name, value)); \ + return *this; \ + } + +DEFINE_SET_ATTR(StringPiece, string_attrs_); +DEFINE_SET_ATTR(float, float_attrs_); +DEFINE_SET_ATTR(int, int_attrs_); +DEFINE_SET_ATTR(bool, bool_attrs_); +DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_); + +#undef DEFINE_SET_ATTR + +AttrBuilder& AttrBuilder::NumInputs(int n) { + DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef."; + num_inputs_ = n; + return *this; +} + +const NodeDef& AttrBuilder::BuildNodeDef() { + if (node_def_finalized_) return *node_def_; + MayBeInitializeNodeDef(); + for (int i = 0; i < num_inputs_; ++i) { + node_def_->add_input("dummy_input"); + } + for (const auto& p : string_attrs_) { + SetInNodeDef(p.first, p.second); + } + for (const auto& p : int_attrs_) { + SetInNodeDef(p.first, p.second); + } + for (const auto& p : float_attrs_) { + SetInNodeDef(p.first, p.second); + } + for (const auto& p : bool_attrs_) { + SetInNodeDef(p.first, p.second); + } + for (const auto& p : type_attrs_) { + SetInNodeDef(p.first, p.second); + } + node_def_finalized_ = true; + return *node_def_; +} + +namespace { +inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, + const tensorflow::Fprint128& b) { + return {tensorflow::FingerprintCat64(a.low64, b.low64), + tensorflow::FingerprintCat64(a.low64, b.low64)}; +} + +void CombineUnordered(const tensorflow::Fprint128& a, + tensorflow::Fprint128* b) { + b->low64 += a.low64; + b->high64 += a.high64; +} + +inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, + const tensorflow::Fprint128& b) { + // TODO(agarwal): avoid ToString(). + tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); + return FingerprintCat128(a, b); +} + +inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) { + return CacheKeyHelper(s, {b, b}); +} + +} // namespace + +tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const { + tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_); + f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device)); + if (node_def_ != nullptr) { + // Some attributes are directly written to node_def_ instead of being + // stored explicitly. + string value; + for (const auto& attr : node_def_->attr()) { + attr.second.SerializeToString(&value); + CombineUnordered( + CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f); + } + // Note that node_def_ may be created but not finalized. This can happen + // when the creation was triggered by a call to Set, but BuildNodeDef has + // not been called. + if (node_def_finalized_) return f; + } + for (const auto& p : string_attrs_) { + // TODO(agarwal): avoid ToString(). + CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128( + p.second.ToString())), + &f); + } + for (const auto& p : int_attrs_) { + CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)), + &f); + } + static std::hash<float> float_hasher; + for (const auto& p : float_attrs_) { + CombineUnordered( + CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))), + &f); + } + for (const auto& p : bool_attrs_) { + CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f); + } + for (const auto& p : type_attrs_) { + CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)), + &f); + } + return f; +} + +void AttrBuilder::MayBeInitializeNodeDef() { + if (node_def_ == nullptr) { + node_def_.reset(new NodeDef()); + node_def_->set_name(op_name_); + node_def_->set_op(op_name_); + } +} + +// static +Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, + KernelAndDevice* out) { + OpKernel* k = nullptr; + Status s = CreateOpKernel(device->device_type().c_str(), device, + device->GetAllocator(AllocatorAttributes()), + nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); + out->device_ = device; + out->kernel_.reset(k); + out->flib_ = nullptr; + return s; +} + +// static +Status KernelAndDevice::InitFn(const NodeDef& ndef, + FunctionLibraryRuntime* flib, + KernelAndDevice* out) { + OpKernel* k = nullptr; + Status s = flib->CreateKernel(ndef, &k); + out->device_ = flib->device(); + out->kernel_.reset(k); + out->flib_ = flib; + return s; +} + +Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, + std::vector<Tensor>* output_tensors) { + gtl::InlinedVector<TensorValue, 4> inputs; + for (Tensor& t : *input_tensors) { + inputs.push_back(TensorValue(&t)); + } + + std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs()); + for (size_t i = 0; i < out_attrs.size(); ++i) { + out_attrs[i].set_on_host(kernel_->output_memory_types()[i] == + tensorflow::HOST_MEMORY); + } + + OpKernelContext::Params params; + params.device = device_; + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs; + params.op_kernel = kernel_.get(); + params.resource_manager = device_->resource_manager(); + params.output_attr_array = gtl::vector_as_array(&out_attrs); + params.function_library = flib_; + params.slice_reader_cache = &slice_reader_cache_; + // TODO(apassos): use a thread pool. + std::function<void(std::function<void()>)> runner = + [](std::function<void()> f) { f(); }; + params.runner = &runner; + + OpKernelContext context(¶ms); + device_->Compute(kernel_.get(), &context); + if (!context.status().ok()) return context.status(); + + output_tensors->clear(); + for (int i = 0; i < context.num_outputs(); ++i) { + output_tensors->push_back(Tensor(*context.mutable_output(i))); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h new file mode 100644 index 0000000000..5916302ff4 --- /dev/null +++ b/tensorflow/c/eager/runtime.h @@ -0,0 +1,193 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_ +#define TENSORFLOW_C_EAGER_RUNTIME_H_ + +// Support for eager execution of TensorFlow kernels. + +#include <memory> +#include <unordered_map> + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { + +// Maps attribute name to an encoding of the type of the attribute value. +// If the type is not a list type, the value is the same as the TF_AttrType type +// of the value. Else, the highest order bit is on, and the rest of the bits +// represent the TF_AttrType type of the values in the list. +typedef std::unordered_map<string, uint32> AttrTypeMap; + +// Returns the AttrTypeMap for the TensorFlow operation named op_name. +Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); + +// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. +Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); + +// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. +// An AttrBuilder is a convenience class to help with that - providing a smaller +// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity +// checks (like number of inputs matching the OpDef - we only care about +// attributes here). +// +// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which +// ones make sense to replicate. + +// This is a helper class for creating a NodeDef. Additionally, this class +// allows computing a cache key based on fingerprinting the attributes of this +// NodeDef. +// +// Example usage: +// AttrBuilder a; +// a.NumInputs(2); +// a.Set("T", TF_FLOAT); +// uint64 cache_key = a.CacheKey("cpu:0"); +// const NodeDef& n = a.BuildNodeDef(); +// +// Note that all calls to Set and NumInputs should happen before calling +// BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations +// to CacheKey may cause different values to be returned by CacheKey. +// +// For performance reasons, the class internally delays the actual construction +// of the NodeDef till BuildNodeDef is called, or Set is called with certain +// uncommon types (see template specializations of Set to see which types +// trigger a NodeDef creation). +class AttrBuilder { + public: + explicit AttrBuilder(const char* op) + : op_name_(op), + num_inputs_(0), + node_def_(nullptr), + node_def_finalized_(false) {} + + // Needed to work around call to ValidateNodeDef in CreateOpKernel. + AttrBuilder& NumInputs(int n); + + template <class T> + AttrBuilder& Set(StringPiece attr_name, T&& value) { + MayBeInitializeNodeDef(); + return SetInNodeDef(attr_name, value); + } + + tensorflow::Fprint128 CacheKey(const string& device) const; + + const NodeDef& BuildNodeDef(); + + private: + template <class T> + using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>; + + void MayBeInitializeNodeDef(); + + template <class T> + AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) { + DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef."; + // Copied from NodeDefBuilder::Attr + const AttrValue* found = AttrSlice(*node_def_).Find(attr_name); + if (found == nullptr) { + AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get()); + } else { + AttrValue attr_value; + SetAttrValue(std::forward<T>(value), &attr_value); + // TODO(ashankar): Do what is done in + // NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value); + } + return *this; + } + + AttrVec<StringPiece> string_attrs_; + AttrVec<int> int_attrs_; + AttrVec<float> float_attrs_; + AttrVec<bool> bool_attrs_; + AttrVec<tensorflow::DataType> type_attrs_; + string op_name_; + int num_inputs_; + std::unique_ptr<NodeDef> node_def_; + bool node_def_finalized_; +}; // namespace tensorflow + +template <> +AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value); +template <> +AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value); +template <> +AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value); +template <> +AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value); +template <> +AttrBuilder& AttrBuilder::Set(StringPiece attr_name, + tensorflow::DataType&& value); + +// KernelAndDevice encapsulates an instantiated kernel and the device it is on. +// +// Also see: +// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +// and +// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h +class KernelAndDevice { + public: + // Populates 'out' with a kernel appropriate for 'ndef'. + // + // Assumes that 'ndef' refers to a primitive op (as opposed to a function). + static Status InitOp(Device* device, const NodeDef& ndef, + KernelAndDevice* out); + + // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a + // TensorFlow function in the FunctionLibraryRuntime). + // + // The provided FunctionLibraryRuntime MUST outlive all calls to + // Run() on the returned KernelAndDevice. + // + // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn. + // The implementation of InitFn should work for both because + // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if + // appropriate. However, for now we keep them separate because I haven't + // figured out thread-safety concerns around FunctionLibraryRuntime (in + // particular, how the underlying FunctionLibraryDefinition might be mutated + // by another thread as new functions are registered with it). + // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed + // on to the caller (see locking in c_api.cc) for now. But I really should + // dig into this so that both InitOp and InitFn can be collapsed to + // FunctionLibraryRuntime::CreateKernel. + static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib, + KernelAndDevice* out); + + KernelAndDevice() : device_(nullptr), flib_(nullptr) {} + + // TODO(ashankar): Handle list-valued inputs. + Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs); + + const OpKernel* kernel() const { return kernel_.get(); } + + private: + std::unique_ptr<OpKernel> kernel_; + tensorflow::Device* device_; + tensorflow::FunctionLibraryRuntime* flib_; + tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_RUNTIME_H_ diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc new file mode 100644 index 0000000000..3b38e24704 --- /dev/null +++ b/tensorflow/c/eager/runtime_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/runtime.h" + +#include <memory> +#include <vector> + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +Device* CPUDevice() { + return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); +} + +TEST(AttrTypeMap, Lookup) { + const AttrTypeMap* m = nullptr; + Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m); + EXPECT_FALSE(s.ok()); + s = AttrTypeMapForOp("MatMul", &m); + ASSERT_TRUE(s.ok()) << s; + + TF_AttrType t; + unsigned char is_list = 1; + s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); + EXPECT_FALSE(s.ok()); + EXPECT_NE(is_list, 0); + s = AttrTypeByName(m, "transpose_a", &t, &is_list); + ASSERT_TRUE(s.ok()) << s; + EXPECT_EQ(TF_ATTR_BOOL, t); + EXPECT_EQ(is_list, 0); + + s = AttrTypeMapForOp("Squeeze", &m); + ASSERT_TRUE(s.ok()) << s; + s = AttrTypeByName(m, "squeeze_dims", &t, &is_list); + ASSERT_TRUE(s.ok()) << s; + EXPECT_EQ(TF_ATTR_INT, t); + EXPECT_NE(is_list, 0); +} + +TEST(KernelAndDevice, Run) { + Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); + std::vector<Tensor> inputs; + inputs.push_back(t); + inputs.push_back(t); + NodeDef ndef(AttrBuilder("MatMul") + .Set("T", DT_FLOAT) + .Set("transpose_a", false) + .Set("transpose_b", false) + .NumInputs(inputs.size()) + .BuildNodeDef()); + std::unique_ptr<Device> device(CPUDevice()); + KernelAndDevice kernel; + Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel); + ASSERT_TRUE(s.ok()) << s; + std::vector<Tensor> outputs; + s = kernel.Run(&inputs, &outputs); + ASSERT_TRUE(s.ok()) << s; + ASSERT_EQ(1, outputs.size()); + const Tensor& out = outputs[0]; + EXPECT_EQ(7, out.matrix<float>()(0, 0)); + EXPECT_EQ(10, out.matrix<float>()(0, 1)); + EXPECT_EQ(15, out.matrix<float>()(1, 0)); + EXPECT_EQ(22, out.matrix<float>()(1, 1)); +} + +// TODO(apassos) uncomment after rewriting to use the right benchmark API +// void BM_CreateGraph(benchmark::State& state) { +// for (auto _ : state) { +// Scope root = Scope::NewRootScope(); +// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); +// auto M = ops::MatMul(root, C, C); +// TF_CHECK_OK(root.status()); +// } +// } +// BENCHMARK(BM_CreateGraph); + +// void BM_RunGraph(benchmark::State& state) { +// Scope root = Scope::NewRootScope(); +// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); +// auto M = ops::MatMul(root, C, C); +// SessionOptions opts; +// opts.config.set_inter_op_parallelism_threads(1); +// opts.config.set_intra_op_parallelism_threads(1); +// ClientSession sess(root, opts); +// std::vector<Tensor> outputs; +// for (auto _ : state) { +// outputs.clear(); +// TF_CHECK_OK(sess.Run({M}, &outputs)); +// } +// } +// BENCHMARK(BM_RunGraph); + +// void BM_CreateAndDestroySession(benchmark::State& state) { +// Scope root = Scope::NewRootScope(); +// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); +// auto M = ops::MatMul(root, C, C); +// for (auto _ : state) { +// ClientSession sess(root); +// } +// } +// BENCHMARK(BM_CreateAndDestroySession); + +// void BM_KernelAndDeviceInit(benchmark::State& state) { +// NodeDef ndef(AttrBuilder("MatMul") +// .Set("T", DT_FLOAT) +// .Set("transpose_a", false) +// .Set("transpose_b", false) +// .NumInputs(2) +// .BuildNodeDef()); +// std::unique_ptr<Device> device(CPUDevice()); +// KernelAndDevice k; +// for (auto _ : state) { +// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k)); +// } +// } +// BENCHMARK(BM_KernelAndDeviceInit); + +// void BM_KernelAndDeviceRun(benchmark::State& state) { +// Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); +// std::vector<Tensor> inputs; +// inputs.push_back(t); +// inputs.push_back(t); +// std::vector<Tensor> outputs; +// NodeDef ndef(AttrBuilder("MatMul") +// .Set("T", DT_FLOAT) +// .Set("transpose_a", false) +// .Set("transpose_b", false) +// .NumInputs(inputs.size()) +// .BuildNodeDef()); +// std::unique_ptr<Device> device(CPUDevice()); +// KernelAndDevice kernel; +// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel)); +// for (auto _ : state) { +// TF_CHECK_OK(kernel.Run(&inputs, &outputs)); +// } +// } +// BENCHMARK(BM_KernelAndDeviceRun); +} // namespace +} // namespace tensorflow |