author    Akshay Modi <nareshmodi@google.com>    2018-04-20 11:34:55 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-20 11:37:39 -0700
commit    76ea66f24d4370e6e7848b83fc0b571ba7edfa2d (patch)
tree      db34453119f995aceed382d7779791e55874f3b0 /tensorflow/c/eager
parent    712bbc5d7babd523951445f361f0e339061cd259 (diff)
Move the guts of TFE_Op into EagerOperation
PiperOrigin-RevId: 193698320
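This refactor is intended to be behavior-preserving at the C API boundary: callers build and run a TFE_Op exactly as before, and each TFE_* entry point now forwards to the wrapped tensorflow::EagerOperation. A minimal sketch, not part of this commit, with per-call error checks elided and an illustrative device name:

#include "tensorflow/c/eager/c_api.h"

void RunIdentitySketch(TFE_Context* ctx, TFE_TensorHandle* input) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, "Identity", status);  // constructs op->operation
  TFE_OpAddInput(op, input, status);                // -> operation.AddInput()
  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
  TFE_OpSetDevice(op, "/device:CPU:0", status);     // -> operation.SetDevice()
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);   // runs via EagerOperation
  TFE_DeleteOp(op);
  if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(retvals[0]);
  TF_DeleteStatus(status);
}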
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/BUILD            |   2
-rw-r--r--  tensorflow/c/eager/c_api.cc         | 230
-rw-r--r--  tensorflow/c/eager/c_api_internal.h |  16
3 files changed, 119 insertions(+), 129 deletions(-)
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 3e14c10727..d66386acbd 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -51,6 +51,7 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
+ "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core:gpu_runtime",
],
)
@@ -73,6 +74,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 369342b142..b7a3097208 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -241,21 +241,18 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
- tensorflow::Device* d = nullptr;
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status = op->ctx->context.FindDeviceByName(device_name, &d);
- }
- op->device = d;
+ status->status = op->operation.SetDevice(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
- tensorflow::Device* device =
- (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
+ tensorflow::Device* device = (op->operation.Device() == nullptr)
+ ? op->operation.EagerContext()->HostCPU()
+ : op->operation.Device();
return device->name().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
- op->use_xla = enable;
+ op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
@@ -263,22 +260,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
- h->handle->Ref();
- op->inputs.push_back(h->handle);
- op->attrs.NumInputs(op->inputs.size());
+ op->operation.AddInput(h->handle);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
- if (op->is_function()) {
+ if (op->operation.is_function()) {
status->status = tensorflow::errors::Unimplemented(
"TODO(apassos): Support for attributes for TensorFlow functions is not "
"ready yet.");
return TF_ATTR_INT; // The compiler requires that we return something.
}
- status->status =
- tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
+ status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
+ attr_name, &ret, is_list);
return ret;
}
@@ -297,23 +292,24 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
}
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
- op->attrs.Set(attr_name, value);
+ op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
- op->attrs.Set(attr_name, static_cast<int64>(value));
+ op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
- op->attrs.Set(attr_name, value);
+ op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
- op->attrs.Set(attr_name, (value == 0) ? false : true);
+ op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
- op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
+ op->operation.MutableAttrs()->Set(attr_name,
+ static_cast<tensorflow::DataType>(value));
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
@@ -335,23 +331,24 @@ void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
proto.add_dim()->set_size(dims[d]);
}
}
- op->attrs.Set(attr_name, proto);
+ op->operation.MutableAttrs()->Set(attr_name, proto);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
- func->set_name(value->name);
- value->attrs.FillAttrValueMap(func->mutable_attr());
- op->attrs.Set(attr_name, attr_value);
+ func->set_name(value->operation.Name());
+ value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
+ op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
#define TFE_OP_SET_ATTR_LIST(fn, type) \
void fn(TFE_Op* op, const char* attr_name, const type* values, \
int num_values) { \
- op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
- values, num_values)); \
+ op->operation.MutableAttrs()->Set( \
+ attr_name, \
+ tensorflow::gtl::ArraySlice<const type>(values, num_values)); \
}
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
@@ -359,14 +356,14 @@ TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const int64>(
- reinterpret_cast<const int64*>(values), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const int64>(
+ reinterpret_cast<const int64*>(values), num_values));
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
- op->attrs.Set(
+ op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
@@ -378,8 +375,8 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
@@ -409,9 +406,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
}
}
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
- proto.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
+ proto.get(), num_values));
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
@@ -419,12 +416,12 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
- funcs[i].set_name(value[i]->name);
- value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
+ funcs[i].set_name(value[i]->operation.Name());
+ value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
- funcs.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
+ funcs.get(), num_values));
}
} // extern "C"
@@ -460,18 +457,19 @@ int StepStatsDeviceIndex(tensorflow::StepStats* step_stats,
}
tensorflow::Status ValidateInputTypeAndPlacement(
- tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op,
- const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) {
+ tensorflow::EagerContext* ctx, tensorflow::Device* op_device,
+ tensorflow::EagerOperation* op, const tensorflow::OpKernel* kernel,
+ tensorflow::RunMetadata* run_metadata) {
tensorflow::Device* host_device = ctx->HostCPU();
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
- if (memtypes.size() != op->inputs.size()) {
+ if (memtypes.size() != op->Inputs().size()) {
return tensorflow::errors::InvalidArgument(
- "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
+ "expected ", memtypes.size(), " inputs, got ", op->Inputs().size());
}
- for (int i = 0; i < op->inputs.size(); ++i) {
+ for (int i = 0; i < op->Inputs().size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
- tensorflow::TensorHandle* handle = op->inputs[i];
+ tensorflow::TensorHandle* handle = op->Inputs()[i];
tensorflow::Device* handle_device = nullptr;
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const tensorflow::Device* actual_device =
@@ -491,7 +489,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
return tensorflow::errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
- op->name, " as input #", i, " was expected to be on ",
+ op->Name(), " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
actual_device->name(), " (operation running on ",
op_device->name(), ")",
@@ -502,7 +500,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
"between devices"
" may slow down your model");
case tensorflow::DEVICE_PLACEMENT_WARN:
- LOG(WARNING) << "before computing " << op->name << " input #" << i
+ LOG(WARNING) << "before computing " << op->Name() << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op_device->name()
@@ -534,16 +532,16 @@ tensorflow::Status ValidateInputTypeAndPlacement(
if (copied_tensor != nullptr) copied_tensor->Unref();
return tensorflow::errors::Internal(
"Failed copying input tensor from ", actual_device->name(), " to ",
- expected_device->name(), " in order to run ", op->name, ": ",
+ expected_device->name(), " in order to run ", op->Name(), ": ",
status.error_message());
}
handle->Unref();
handle = copied_tensor;
- op->inputs[i] = copied_tensor;
+ (*op->MutableInputs())[i] = copied_tensor;
}
if (handle->dtype != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
- "cannot compute ", op->name, " as input #", i,
+ "cannot compute ", op->Name(), " as input #", i,
" was expected to be a ",
tensorflow::DataTypeString(kernel->input_type(i)),
" tensor but is a ", tensorflow::DataTypeString(handle->dtype),
@@ -554,9 +552,10 @@ tensorflow::Status ValidateInputTypeAndPlacement(
}
tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
- TFE_Context* ctx, TF_Status* status) {
+ tensorflow::EagerContext* ctx,
+ TF_Status* status) {
tensorflow::DeviceSet ds;
- for (tensorflow::Device* d : *ctx->context.devices()) {
+ for (tensorflow::Device* d : *ctx->devices()) {
ds.AddDevice(d);
}
tensorflow::DeviceTypeVector final_devices;
@@ -570,7 +569,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
"Could not find valid device for node ", ndef.DebugString());
return nullptr;
}
- for (tensorflow::Device* d : *ctx->context.devices()) {
+ for (tensorflow::Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) {
return d;
}
@@ -599,15 +598,16 @@ const tensorflow::FunctionDef* OpToFunction(
std::vector<TF_DataType>* arg_input_types,
tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
TF_Status* status) {
- DCHECK(!op->is_function());
+ DCHECK(!op->operation.is_function());
tensorflow::FunctionDef fdef;
// Get the OpDef of the op we are trying to encapsulate.
- TFE_Context* ctx = op->ctx;
+ TFE_Context* ctx = op->operation.ctx;
const tensorflow::OpRegistrationData* op_data;
{
- status->status = ctx->context.FindFunctionOpData(op->name, &op_data);
+ status->status =
+ ctx->context.FindFunctionOpData(op->operation.Name(), &op_data);
if (!status->status.ok()) {
return nullptr;
}
@@ -618,7 +618,8 @@ const tensorflow::FunctionDef* OpToFunction(
// Handle constant inputs.
const std::unordered_set<string> const_inputs(
- *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));
+ *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
+ op->operation.Name()));
// First add place holders for the input args, so that we can refer to them by
// position in the next loop. Also tally up the resource inputs.
@@ -644,7 +645,7 @@ const tensorflow::FunctionDef* OpToFunction(
(*op_input_to_func_input)[i] = const_index;
func_input_arg = signature->mutable_input_arg(const_index++);
const_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i]->dtype));
+ static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
} else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
VLOG(1) << "For resource input, mapping op input " << i
<< " to func input " << resource_index;
@@ -656,11 +657,11 @@ const tensorflow::FunctionDef* OpToFunction(
(*op_input_to_func_input)[i] = arg_index;
func_input_arg = signature->mutable_input_arg(arg_index++);
arg_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i]->dtype));
+ static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
}
func_input_arg->set_name(op_input_arg.name());
- func_input_arg->set_type(op->inputs[i]->dtype);
+ func_input_arg->set_type(op->operation.Inputs()[i]->dtype);
}
VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
@@ -673,7 +674,8 @@ const tensorflow::FunctionDef* OpToFunction(
op_def.name(), func_id_generator.fetch_add(1)));
// Add the node def and set its input names to match op_def's names.
- const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ const tensorflow::NodeDef& ndef =
+ op->operation.MutableAttrs()->BuildNodeDef();
DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
*fdef.add_node_def() = ndef;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
@@ -713,17 +715,18 @@ const tensorflow::FunctionDef* OpToFunction(
// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
- VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
- auto launch_op =
- std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
+ VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->operation.Name();
+ auto launch_op = std::unique_ptr<TFE_Op>(
+ TFE_NewOp(op->operation.ctx, "_XlaLaunch", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
- if (op->device) {
- TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
+ if (op->operation.device) {
+ TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
+ status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
const tensorflow::FunctionDef* fdef;
- { fdef = op->ctx->context.FindFunctionDef(op->name); }
+ { fdef = op->operation.ctx->FindFunctionDef(op->operation.Name()); }
std::vector<TF_DataType> const_input_types;
std::vector<TF_DataType> arg_input_types;
tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
@@ -748,20 +751,21 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
// Copy inputs and their devices.
// Since input param reordering may have occurred between `op` and `launch_op`
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
- launch_op->inputs = op->inputs;
- for (tensorflow::TensorHandle* h : launch_op->inputs) {
+ *launch_op->operation.MutableInputs() = op->operation.Inputs();
+ for (tensorflow::TensorHandle* h : launch_op->operation.Inputs()) {
h->Ref();
}
if (!op_input_to_func_input.empty()) {
- DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
+ DCHECK_EQ(op->operation.Inputs().size(), op_input_to_func_input.size());
for (int i = 0; i < op_input_to_func_input.size(); ++i) {
VLOG(1) << "mapping op input " << i << " to func input "
<< op_input_to_func_input[i];
- launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
+ (*launch_op->operation.MutableInputs())[op_input_to_func_input[i]] =
+ op->operation.Inputs()[i];
}
}
- launch_op->attrs.NumInputs(op->inputs.size());
+ launch_op->operation.MutableAttrs()->NumInputs(op->operation.Inputs().size());
TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
const_input_types.size());
@@ -796,16 +800,17 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
extern "C" {
-void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+void TFE_Execute(TFE_Op* tfe_op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
- TFE_Context* ctx = op->ctx;
- status->status = ctx->context.GetStatus();
+ tensorflow::EagerOperation* op = &tfe_op->operation;
+ tensorflow::EagerContext* ctx = op->EagerContext();
+ status->status = ctx->GetStatus();
if (!status->status.ok()) {
return;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
std::unique_ptr<TFE_Op> xla_launch_op;
- if (op->use_xla && op->name != "_XlaLaunch") {
+ if (op->UseXla() && op->Name() != "_XlaLaunch") {
xla_launch_op = BuildXlaLaunch(op, status);
if (!status->status.ok()) {
return;
@@ -816,31 +821,31 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Ensure all resource-touching ops run in the device the resource is,
// regardless of anything else that has been specified. This is identical to
// the graph mode behavior.
- for (int i = 0; i < op->inputs.size(); ++i) {
+ for (int i = 0; i < op->Inputs().size(); ++i) {
tensorflow::Device* input_op_device = nullptr;
- status->status = op->inputs[i]->OpDevice(&input_op_device);
+ status->status = op->Inputs()[i]->OpDevice(&input_op_device);
if (!status->status.ok()) return;
- VLOG(2) << "for op " << op->name << " input " << i << " "
- << tensorflow::DataTypeString(op->inputs[i]->dtype) << " "
+ VLOG(2) << "for op " << op->Name() << " input " << i << " "
+ << tensorflow::DataTypeString(op->Inputs()[i]->dtype) << " "
<< (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->device == nullptr ? "cpu" : op->device->name());
- if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
- (input_op_device != op->device || input_op_device == nullptr)) {
+ << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+ if (op->Inputs()[i]->dtype == tensorflow::DT_RESOURCE &&
+ (input_op_device != op->Device() || input_op_device == nullptr)) {
tensorflow::Device* d =
- input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->name << " to "
+ input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+ VLOG(1) << "Changing device of operation " << op->Name() << " to "
<< d->name() << " because input #" << i
<< " is a resource in this device.";
- op->device = d;
+ op->SetDevice(d);
}
}
- tensorflow::Device* device = op->device;
+ tensorflow::Device* device = op->Device();
- tensorflow::Fprint128 cache_key =
- op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
- tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
+ tensorflow::Fprint128 cache_key = op->MutableAttrs()->CacheKey(
+ device == nullptr ? "unspecified" : device->name());
+ tensorflow::KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
if (kernel == nullptr) {
- const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ const tensorflow::NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
if (device == nullptr) {
device = SelectDevice(ndef, ctx, status);
if (!status->status.ok()) {
@@ -848,19 +853,19 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
}
CHECK(device != nullptr);
- if (ctx->context.LogDevicePlacement()) {
+ if (ctx->LogDevicePlacement()) {
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
+ kernel = new tensorflow::KernelAndDevice(ctx->GetRendezvous());
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
- tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
- status->status = tensorflow::KernelAndDevice::Init(
- ndef, ctx->context.func_lib(device), kernel);
+ tensorflow::tf_shared_lock l(*ctx->FunctionsMu());
+ status->status =
+ tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
@@ -868,7 +873,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Update output_dtypes inside `kernel`.
const tensorflow::OpDef* op_def = nullptr;
const tensorflow::FunctionDef* function_def =
- ctx->context.FuncLibDef()->Find(ndef.op());
+ ctx->FuncLibDef()->Find(ndef.op());
if (function_def != nullptr) {
op_def = &(function_def->signature());
}
@@ -884,7 +889,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (!status->status.ok()) {
return;
}
- ctx->context.AddKernelToCache(cache_key, kernel);
+ ctx->AddKernelToCache(cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
const int output_dtypes_size = output_dtypes.size();
@@ -903,43 +908,42 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
device = kernel->device();
}
status->status = ValidateInputTypeAndPlacement(
- &ctx->context, device, op, kernel->kernel(),
- ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto()
- : nullptr);
+ ctx, device, op, kernel->kernel(),
+ ctx->ShouldStoreMetadata() ? ctx->RunMetadataProto() : nullptr);
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
- if (ctx->context.ShouldStoreMetadata()) {
+ if (ctx->ShouldStoreMetadata()) {
maybe_stats.reset(new tensorflow::NodeExecStats);
- maybe_stats->set_node_name(op->name);
+ maybe_stats->set_node_name(op->Name());
maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
maybe_stats->set_op_start_rel_micros(0);
maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
// TODO(apassos) track referenced tensors
}
- if (ctx->context.Async()) {
+ if (ctx->Async()) {
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
*num_retvals);
- tensorflow::uint64 id = op->ctx->context.NextId();
+ tensorflow::uint64 id = ctx->NextId();
for (int i = 0; i < *num_retvals; ++i) {
tensorflow::TensorHandle* h =
- new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
+ new tensorflow::TensorHandle(id, output_dtypes[i], ctx);
retvals[i] = new TFE_TensorHandle(h);
handle_retvals[i] = h;
}
tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
- id, &op->ctx->context, op->device, op->inputs, kernel,
- maybe_stats.release(), output_dtypes, handle_retvals);
- ctx->context.ExecutorAdd(node);
+ id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(),
+ output_dtypes, handle_retvals);
+ ctx->ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
*num_retvals);
status->status = tensorflow::EagerExecute(
- &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
+ ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(),
handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
@@ -1142,9 +1146,3 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
-
-TFE_Op::~TFE_Op() {
- for (tensorflow::TensorHandle* h : inputs) {
- h->Unref();
- }
-}
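For readability, here is a condensed, hypothetical helper (not in the commit) restating the kernel-cache logic from the TFE_Execute hunks above. Every call it makes (CacheKey, GetCachedKernel, BuildNodeDef, KernelAndDevice::Init, AddKernelToCache, FunctionsMu) appears in the diff; the helper's name and signature are illustrative:

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"

tensorflow::KernelAndDevice* GetOrBuildKernel(tensorflow::EagerContext* ctx,
                                              tensorflow::EagerOperation* op,
                                              tensorflow::Device* device,
                                              tensorflow::Status* status) {
  // The cache key mixes the op's attributes with the resolved device name.
  tensorflow::Fprint128 cache_key = op->MutableAttrs()->CacheKey(
      device == nullptr ? "unspecified" : device->name());
  tensorflow::KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
    // (TFE_Execute additionally calls SelectDevice(ndef, ctx, status) here
    // when the device is still unset.)
    kernel = new tensorflow::KernelAndDevice(ctx->GetRendezvous());
    // Init may access the function library, so hold the shared lock.
    tensorflow::tf_shared_lock l(*ctx->FunctionsMu());
    *status =
        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
    if (!status->ok()) {
      delete kernel;
      return nullptr;
    }
    ctx->AddKernelToCache(cache_key, kernel);
  }
  return kernel;
}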
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 05dc64f521..49e1aab1ce 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -45,7 +46,6 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
-
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
@@ -85,19 +85,9 @@ struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
- : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
-
- ~TFE_Op();
-
- bool const is_function() const { return attr_types == nullptr; }
+ : operation(&ctx->context, op, t) {}
- TFE_Context* ctx; // Must outlive the TFE_Op.
- const tensorflow::string name;
- tensorflow::AttrBuilder attrs;
- const tensorflow::AttrTypeMap* attr_types;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs;
- tensorflow::Device* device;
- bool use_xla = false;
+ tensorflow::EagerOperation operation;
};
namespace tensorflow {
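For reference, the tensorflow::EagerOperation surface implied by the call sites in this diff, shown as a usage sketch. The authoritative declaration lives in tensorflow/core/common_runtime/eager/eager_operation.h; the attr and device name below are illustrative, and the comments map each call back to the TFE_Op member it replaces:

#include "tensorflow/core/common_runtime/eager/eager_operation.h"

void EagerOperationSketch(tensorflow::EagerContext* ctx, const char* name,
                          const tensorflow::AttrTypeMap* types,
                          tensorflow::TensorHandle* h) {
  // was: TFE_Op's ctx/name/attrs/attr_types fields, built in its constructor
  tensorflow::EagerOperation op(ctx, name, types);
  op.AddInput(h);                         // was: inputs.push_back + attrs.NumInputs
  op.MutableAttrs()->Set("T", h->dtype);  // was: attrs.Set(...)
  op.SetUseXla(false);                    // was: use_xla = ...
  tensorflow::Status s = op.SetDevice("/device:CPU:0");  // was: FindDeviceByName
  if (s.ok() && !op.is_function() && op.Device() != nullptr) {
    // Read accessors: op.Name(), op.Inputs(), op.Attrs(), op.AttrTypes(),
    // op.EagerContext(). Input handles are now unreffed by EagerOperation's
    // destructor, which replaces the deleted TFE_Op::~TFE_Op above.
  }
}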