diff options
author    Alexandre Passos <apassos@google.com>            2018-03-27 15:08:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-27 15:11:59 -0700
commit    9a0b91023d8444cd4691be10b36ce469ca08058d (patch)
tree      78ecca3c6db5e520471295a7f05fe4de87b96242 /tensorflow/c/eager
parent    736e055a756cf0f99d59b67284aade01baec9799 (diff)
Moves Execute() from c_api.cc
PiperOrigin-RevId: 190681610
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/BUILD          |  1
-rw-r--r--  tensorflow/c/eager/c_api.cc       | 90
-rw-r--r--  tensorflow/c/eager/c_api_test.cc  |  4
3 files changed, 10 insertions, 85 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 8df7b56623..e57011a08b 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -30,6 +30,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:copy_to_device_node", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index eaeb2fd07a..ac7114f71e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" +#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/node_def_util.h" @@ -574,83 +575,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, return nullptr; } -tensorflow::Status Execute( - TFE_Context* ctx, tensorflow::Device* device, - const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& - op_inputs, - tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, - tensorflow::TensorHandle** retvals, int num_retvals) { - if (!ctx->context.SoftPlacement() && device == nullptr) { - device = ctx->context.HostCPU(); - } - - if (device == nullptr) { - // TODO(apassos) debug how the assignment below might return a different - // device from the one requested above. 
- device = kernel->device(); - } - - std::vector<tensorflow::Tensor> outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; - output_memory_types = &kernel->kernel()->output_memory_types(); - std::vector<tensorflow::Tensor> inputs(op_inputs.size()); - for (int i = 0; i < op_inputs.size(); ++i) { - const tensorflow::Tensor* input_tensor = nullptr; - TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor)); - inputs[i] = *input_tensor; - } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def. - // But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - // TODO(agarwal): change Run to take vector of handles ? - TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - if (ctx->context.ShouldStoreMetadata()) { - auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->context.devices()->size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->context.devices()->size(); ++i) { - if (ctx->context.devices()->at(i) == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. 
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; - } - } - DCHECK_EQ(num_retvals, outputs.size()); - tensorflow::Device* op_device = IsCPU(device) ? nullptr : device; - for (int i = 0; i < num_retvals; ++i) { - tensorflow::Device* d = op_device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; - } - if (retvals[i] == nullptr) { - retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device); - } else { - retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); - } - } - return tensorflow::Status::OK(); -} // TODO(agarwal): move EagerExecutor and EagerNode related code to a separate // file. @@ -690,9 +614,9 @@ class ExecuteNode : public tensorflow::EagerNode { } tensorflow::Status Run() override { - const tensorflow::Status status = - Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), - retvals_.begin(), retvals_.size()); + const tensorflow::Status status = tensorflow::EagerExecute( + &ctx_->context, op_device_, inputs_, kernel_, maybe_stats_.get(), + retvals_.begin(), retvals_.size()); if (status.ok()) { return status; } else { @@ -1062,9 +986,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // allocate it. 
std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals, nullptr); - status->status = - Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(), - handle_retvals.data(), *num_retvals); + status->status = tensorflow::EagerExecute( + &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); for (int i = 0; i < *num_retvals; ++i) { retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 2268aba90d..d88a6c1dda 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -688,12 +688,12 @@ TEST(CAPI, Execute_Min_CPU) { TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float output[2] = {0}; EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); |