author    Alexandre Passos <apassos@google.com>  2018-03-27 15:08:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-27 15:11:59 -0700
commit    9a0b91023d8444cd4691be10b36ce469ca08058d (patch)
tree      78ecca3c6db5e520471295a7f05fe4de87b96242 /tensorflow/c/eager
parent    736e055a756cf0f99d59b67284aade01baec9799 (diff)
Moves Execute() from c_api.cc
PiperOrigin-RevId: 190681610
Diffstat (limited to 'tensorflow/c/eager')
 -rw-r--r--  tensorflow/c/eager/BUILD          |  1
 -rw-r--r--  tensorflow/c/eager/c_api.cc       | 90
 -rw-r--r--  tensorflow/c/eager/c_api_test.cc  |  4
 3 files changed, 10 insertions(+), 85 deletions(-)
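
For context, this change deletes the file-local Execute() helper in c_api.cc and routes its two call sites (ExecuteNode::Run() and TFE_Execute()) through tensorflow::EagerExecute from common_runtime/eager/execute.h, passing the wrapped EagerContext directly rather than the TFE_Context. The fragment below is an annotated sketch of the new TFE_Execute call site, with argument roles inferred from the parameter list of the removed helper; it is an excerpt for illustration, not a standalone compilable unit.

#include "tensorflow/core/common_runtime/eager/execute.h"

// Sketch of the updated call inside TFE_Execute (see the c_api.cc hunk below).
// Argument roles follow the signature of the removed local Execute():
status->status = tensorflow::EagerExecute(
    &op->ctx->context,      // EagerContext held by the TFE_Context wrapper
    op->device,             // requested device; may be nullptr, resolved inside
    op->inputs,             // InlinedVector<TensorHandle*, 4> of input handles
    kernel,                 // KernelAndDevice prepared for this op
    maybe_stats.get(),      // optional NodeExecStats for run metadata
    handle_retvals.data(),  // out: TensorHandle* results
    *num_retvals);          // number of expected return values

The BUILD target gains a matching dependency on //tensorflow/core/common_runtime/eager:execute, and the test is reordered so the returned handle is resolved before the context is deleted.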
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 8df7b56623..e57011a08b 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -30,6 +30,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index eaeb2fd07a..ac7114f71e 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -574,83 +575,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
return nullptr;
}
-tensorflow::Status Execute(
- TFE_Context* ctx, tensorflow::Device* device,
- const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>&
- op_inputs,
- tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
- tensorflow::TensorHandle** retvals, int num_retvals) {
- if (!ctx->context.SoftPlacement() && device == nullptr) {
- device = ctx->context.HostCPU();
- }
-
- if (device == nullptr) {
- // TODO(apassos) debug how the assignment below might return a different
- // device from the one requested above.
- device = kernel->device();
- }
-
- std::vector<tensorflow::Tensor> outputs(1);
- const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
- output_memory_types = &kernel->kernel()->output_memory_types();
- std::vector<tensorflow::Tensor> inputs(op_inputs.size());
- for (int i = 0; i < op_inputs.size(); ++i) {
- const tensorflow::Tensor* input_tensor = nullptr;
- TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
- inputs[i] = *input_tensor;
- }
- // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
- // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
- // But knowledge of the implementation
- // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
- // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
- // This is quite subtle. Re-work things to make this better? (Would it make
- // sense for FunctionLibraryRuntime to ensure thread-safe access to
- // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
- // for ops which are a part of functions.
- // TODO(agarwal): change Run to take vector of handles ?
- TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
- if (maybe_stats != nullptr) {
- maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
- maybe_stats->all_start_micros());
- tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
- if (ctx->context.ShouldStoreMetadata()) {
- auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats();
- // Lazily initialize the RunMetadata with information about all devices if
- // this is the first call.
- while (step_stats->dev_stats_size() < ctx->context.devices()->size()) {
- step_stats->add_dev_stats();
- }
- // Find the current device's index.
- int device_idx = 0;
- for (int i = 0; i < ctx->context.devices()->size(); ++i) {
- if (ctx->context.devices()->at(i) == device) {
- device_idx = i;
- break;
- }
- }
- // Populate the device stats for this device.
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
- dev_stats->set_device(device->name());
- *dev_stats->add_node_stats() = *maybe_stats;
- }
- }
- DCHECK_EQ(num_retvals, outputs.size());
- tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
- for (int i = 0; i < num_retvals; ++i) {
- tensorflow::Device* d = op_device;
- if (d != nullptr && output_memory_types != nullptr &&
- (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
- d = nullptr;
- }
- if (retvals[i] == nullptr) {
- retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device);
- } else {
- retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
- }
- }
- return tensorflow::Status::OK();
-}
// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate
// file.
@@ -690,9 +614,9 @@ class ExecuteNode : public tensorflow::EagerNode {
}
tensorflow::Status Run() override {
- const tensorflow::Status status =
- Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
- retvals_.begin(), retvals_.size());
+ const tensorflow::Status status = tensorflow::EagerExecute(
+ &ctx_->context, op_device_, inputs_, kernel_, maybe_stats_.get(),
+ retvals_.begin(), retvals_.size());
if (status.ok()) {
return status;
} else {
@@ -1062,9 +986,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// allocate it.
std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
nullptr);
- status->status =
- Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(),
- handle_retvals.data(), *num_retvals);
+ status->status = tensorflow::EagerExecute(
+ &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
+ handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 2268aba90d..d88a6c1dda 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -688,12 +688,12 @@ TEST(CAPI, Execute_Min_CPU) {
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
- TFE_DeleteContext(ctx, status);
- ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float output[2] = {0};
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));