diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-06-23 12:55:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-23 12:59:28 -0700 |
commit | 6ada43366663210beb0159b8c1a67b26ebfe6cb7 (patch) | |
tree | 1f41bb5c9d2000cb4dd47645f57c181aef3fae3e | |
parent | 0eff699d3087171cf35671d9d0bd6f8e79441ab3 (diff) |
Prepare to not include node_def.proto.h in node_def_util.h
The goal is to make kernels mostly independent of proto headers, which will let
us lock down our .so imports. This CL makes a bunch of .cc files
either include node_def.proto.h themselves or not need the definition of
NodeDef; a second CL will make node_def_util.h not include node_def.proto.h.
RELNOTES: n/a
PiperOrigin-RevId: 159982117
58 files changed, 261 insertions, 169 deletions
diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc index c86e03118b..bd4eefbc0b 100644 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ b/tensorflow/compiler/jit/kernels/parallel_check_op.cc @@ -64,7 +64,7 @@ class ParallelCheckOp : public OpKernel { ok = (diff <= tolerance); } if (ok) continue; - LOG(ERROR) << "Op " << def().name() << " fails equality at output " + LOG(ERROR) << "Op " << name() << " fails equality at output " << input_idx << " type " << DataTypeString(dtype) << " element " << i << ": std_val=" << p0[i] << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); @@ -75,7 +75,7 @@ class ParallelCheckOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << def().name(); + VLOG(1) << "Compute " << name(); const int num_pairs = ctx->num_inputs() / 2; for (int i = 0; i < num_pairs; ++i) { CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); @@ -113,7 +113,7 @@ class ParallelCheckOp : public OpKernel { LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); } if (failed > 0) { - LOG(ERROR) << "check failed for " << def().name() << " output " << i + LOG(ERROR) << "check failed for " << name() << " output " << i << " num_elts: " << num_elts; legacy_flags::ParallelCheckOpFlags* flags = legacy_flags::GetParallelCheckOpFlags(); @@ -121,7 +121,7 @@ class ParallelCheckOp : public OpKernel { LOG(QFATAL) << "failfast on first parallel-check failure"; } } else { - VLOG(1) << "check passed for " << def().name() << " output " << i + VLOG(1) << "check passed for " << name() << " output " << i << " num_elts: " << num_elts; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index f1fef85f99..eb9c348f7a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index f4d3bc9635..997ecd7ebb 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -249,6 +249,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 8dacb6627b..af1085d5b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c34b8788d..3c6c9a91b6 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index aaef27f16d..20adf300ec 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc b/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc index 000b2e903a..a6b7a050c8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc @@ -86,7 +86,7 @@ class AddTreesToEnsembleOp : public OpKernel { OP_REQUIRES(context, fc_usage_counts_lhs_t.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", - def().input(kFeatureColumnUsageCountsHandleIdx))); + requested_input(kFeatureColumnUsageCountsHandleIdx))); Tensor fc_gains_lhs_t = context->mutable_input(kFeatureColumnGainsHandleIdx, true); @@ -95,7 +95,7 @@ class AddTreesToEnsembleOp : public OpKernel { OP_REQUIRES(context, fc_gains_lhs_t.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", - def().input(kFeatureColumnGainsHandleIdx))); + requested_input(kFeatureColumnGainsHandleIdx))); const Tensor fc_usage_counts_rhs_t = context->input(kFeatureColumnUsageCountsToAddIdx); diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 1c80ffcd7b..c390136d46 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -46,6 +46,10 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" + "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index a048194a19..406b656349 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -223,6 +223,10 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" + "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" "${tensorflow_source_dir}/tensorflow/core/util/*.h" "${tensorflow_source_dir}/tensorflow/core/util/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc index 1f079027ef..3e387129eb 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -16,6 +16,7 @@ limitations under the License. #include <unordered_set> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index b2dd4c6e52..21a20bcc4d 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 24b519fb07..ce6bbb0d65 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1743,7 +1743,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { return AttachDef(errors::FailedPrecondition( "Attempting to use uninitialized value ", - item.kernel->def().input(i)), + item.kernel->requested_input(i)), item.kernel->def()); } } diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index b0b834b66a..99389968ee 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 7eda5c90a1..f60d3a89c6 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -391,7 +391,7 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel, if (vlog_1) { VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op " - << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" + << op_kernel->type_string() << " on GPU" << gpu_id_ << " stream[" << stream_id << "]"; } @@ -465,7 +465,7 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel, const auto stream_id = gpu_device_context->stream_id(); VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op " - << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream[" + << op_kernel->type_string() << " on GPU" << gpu_id_ << " stream[" << stream_id << "]"; // When TraceMe profiling is off (which is the default), the diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index f7b12bdaee..a6204b9d0d 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -294,7 +294,7 @@ Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge, } InferenceContext* c = it->second.get(); - if (node->def().op() == "Shape") { + if (node->type_string() == "Shape") { // If input shapes to the shape op are fully defined, // we can infer the shape op's output tensor. bool fully_defined_inputs = c->FullyDefined(c->input(0)); @@ -324,7 +324,7 @@ Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge, *output = t; *success = true; } - } else if (node->def().op() == "Rank") { + } else if (node->type_string() == "Rank") { bool rank_known = c->RankKnown(c->input(0)); if (rank_known) { int32 input_rank = c->Rank(c->input(0)); @@ -333,7 +333,7 @@ Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge, *output = t; *success = true; } - } else if (node->def().op() == "Size") { + } else if (node->type_string() == "Size") { bool fully_defined_inputs = c->FullyDefined(c->input(0)); if (fully_defined_inputs) { int32 rank = c->Rank(c->input(0)); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 8206a678b4..c00eb3a2fc 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/subgraph.h" diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index f4bf9dcd3b..69f5b7d944 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index c79f68a068..2ea5194d9b 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph_partition.h" diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc index e4a3f26209..485cf6da4b 100644 --- a/tensorflow/core/example/example_parser_configuration.cc +++ b/tensorflow/core/example/example_parser_configuration.cc @@ -17,6 +17,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/example/feature.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 9026075a2f..fe6e9a6cd6 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/function.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 7caddf3cb8..4ee23226da 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 56b8f0aa1b..49e5b0c99d 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 8496774793..aeedf4b0ef 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/versions.pb_text.h" diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index c1dde1504a..6a2eed94b9 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -18,6 +18,7 @@ limitations under the License. #include <utility> #include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index c09d96bfa6..fd26d0ae64 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index dec987e1ed..3892320b7d 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/kernel_def.pb_text.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" @@ -77,7 +78,7 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ OpKernel::OpKernel(OpKernelConstruction* context) - : def_(context->def()), + : def_(new NodeDef(context->def())), input_types_(context->input_types().begin(), context->input_types().end()), input_memory_types_(context->input_memory_types().begin(), @@ -91,7 +92,7 @@ OpKernel::OpKernel(OpKernelConstruction* context) input_name_map_(context->num_inputs()), output_name_map_(context->num_outputs()) { OP_REQUIRES_OK(context, - NameRangesForNode(def_, *context->op_def_, &input_name_map_, + NameRangesForNode(*def_, *context->op_def_, &input_name_map_, &output_name_map_)); OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_, context->graph_def_version())); @@ -103,6 +104,11 @@ OpKernel::OpKernel(OpKernelConstruction* context) OpKernel::~OpKernel() {} +const string& OpKernel::name() const { return def_->name(); } +const string& OpKernel::type_string() const { return def_->op(); } +const string& OpKernel::requested_device() const { return def_->device(); } +const string& OpKernel::requested_input(int i) const { return def_->input(i); } + Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { const auto result = input_name_map_.find(input_name.ToString()); @@ -165,6 +171,26 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { // OpKernelConstruction ------------------------------------------------------ +OpKernelConstruction::OpKernelConstruction( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib, + const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types, + const DataTypeSlice& output_types, + const MemoryTypeSlice& output_memory_types, int graph_def_version, + Status* status) + : device_type_(std::move(device_type)), + device_(device), + allocator_(allocator), + def_(node_def), + op_def_(op_def), + flib_(flib), + input_types_(input_types), + input_memory_types_(input_memory_types), + output_types_(output_types), + output_memory_types_(output_memory_types), + graph_def_version_(graph_def_version), + status_(status) {} + void OpKernelConstruction::SetStatus(const Status& status) { status_->Update(status); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 465395d858..f8a5946116 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -109,9 +109,10 @@ class OpKernel { virtual bool IsExpensive() { return expensive_; } // Accessors. - const NodeDef& def() const { return def_; } - const string& name() const { return def_.name(); } - const string& type_string() const { return def_.op(); } + const NodeDef& def() const { return *def_; } + const string& name() const; // Same as def().name() + const string& type_string() const; // Same as def().op() + const string& requested_device() const; // Same as def().device() bool is_internal() const { return is_internal_; } int num_inputs() const { return input_types_.size(); } @@ -120,6 +121,7 @@ class OpKernel { const MemoryTypeVector& input_memory_types() const { return input_memory_types_; } + const string& requested_input(int i) const; // Same as def().input(i) int num_outputs() const { return output_types_.size(); } DataType output_type(int o) const { return output_types_[o]; } @@ -157,7 +159,7 @@ class OpKernel { Status MakeShape(const Tensor& shape, TensorShape* out) const; private: - const NodeDef def_; + const std::unique_ptr<const NodeDef> def_; const DataTypeVector input_types_; const MemoryTypeVector input_memory_types_; const DataTypeVector output_types_; @@ -227,19 +229,7 @@ class OpKernelConstruction { const MemoryTypeSlice& input_memory_types, const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, - int graph_def_version, Status* status) - : device_type_(std::move(device_type)), - device_(device), - allocator_(allocator), - def_(node_def), - op_def_(op_def), - flib_(flib), - input_types_(input_types), - input_memory_types_(input_memory_types), - output_types_(output_types), - output_memory_types_(output_memory_types), - graph_def_version_(graph_def_version), - status_(status) {} + int graph_def_version, Status* status); Env* env() const { return device_->env(); } diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 4365a861e5..3018e4f655 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index cc7613b97d..df4d8c3591 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 2c18ddd48f..62f85d2dac 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -97,6 +97,21 @@ InferenceContext::InferenceContext( InferenceContext::~InferenceContext() {} +Status InferenceContext::Run( + const std::function<Status(shape_inference::InferenceContext* c)>& fn) { + Status s = fn(this); + if (!s.ok()) { + return AttachContext(s); + } +#ifndef NDEBUG + for (int i = 0; i < num_outputs(); ++i) { + DCHECK(output(i).IsSet()) + << i << " for " << node_def_.name() << " of type " << node_def_.op(); + } +#endif // NDEBUG + return s; +} + Status InferenceContext::set_output(StringPiece output_name, const std::vector<ShapeHandle>& shapes) { const auto result = output_name_map_.find(output_name.ToString()); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 5668667659..460aefe29e 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -196,19 +196,7 @@ class InferenceContext { // // On error, additional context is provided in the error message. Status Run( - const std::function<Status(shape_inference::InferenceContext* c)>& fn) { - Status s = fn(this); - if (!s.ok()) { - return AttachContext(s); - } -#ifndef NDEBUG - for (int i = 0; i < num_outputs(); ++i) { - DCHECK(output(i).IsSet()) - << i << " for " << node_def_.name() << " of type " << node_def_.op(); - } -#endif // NDEBUG - return s; - } + const std::function<Status(shape_inference::InferenceContext* c)>& fn); // Merge the stored shape of the input in position idx with <shape> according // to the following rules: diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 6bd2cd4291..03c39e6dc1 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -17,6 +17,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc index 09c3d8d567..d3e7ff781c 100644 --- a/tensorflow/core/graph/gradients.cc +++ b/tensorflow/core/graph/gradients.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index dcb8520cf7..5d60e41d26 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,6 +29,26 @@ namespace tensorflow { const int Graph::kControlSlot = -1; +class NodeProperties : public core::RefCounted { + public: + NodeProperties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs) + : op_def_(op_def), + node_def_(node_def), + input_types_(inputs.begin(), inputs.end()), + output_types_(outputs.begin(), outputs.end()) {} + + const OpDef* op_def_; // not owned + NodeDef node_def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + + private: + // Destructor invoked when last reference goes away via Unref() + ~NodeProperties() override {} + TF_DISALLOW_COPY_AND_ASSIGN(NodeProperties); +}; + // Node #define REF_CLASS(key, value) \ @@ -99,7 +120,7 @@ Node::~Node() { } } -void Node::Initialize(int id, int cost_id, Properties* props) { +void Node::Initialize(int id, int cost_id, NodeProperties* props) { DCHECK_EQ(id_, -1); DCHECK(in_edges_.empty()); DCHECK(out_edges_.empty()); @@ -130,6 +151,29 @@ void Node::Clear() { assigned_device_name_index_ = 0; } +const string& Node::name() const { return props_->node_def_.name(); } +const string& Node::type_string() const { return props_->node_def_.op(); } +const NodeDef& Node::def() const { return props_->node_def_; } +const OpDef& Node::op_def() const { return *props_->op_def_; } + +int32 Node::num_inputs() const { return props_->input_types_.size(); } +DataType Node::input_type(int32 i) const { return props_->input_types_[i]; } +const DataTypeVector& Node::input_types() const { return props_->input_types_; } + +int32 Node::num_outputs() const { return props_->output_types_.size(); } +DataType Node::output_type(int32 o) const { return props_->output_types_[o]; } +const DataTypeVector& Node::output_types() const { + return props_->output_types_; +} + +AttrSlice Node::attrs() const { return AttrSlice(def()); } + +const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const { + return def().input(); +} + +const string& Node::requested_device() const { return def().device(); } + gtl::iterator_range<NeighborIter> Node::out_nodes() const { return gtl::make_range(NeighborIter(out_edges_.begin(), false), NeighborIter(out_edges_.end(), false)); @@ -141,16 +185,21 @@ gtl::iterator_range<NeighborIter> Node::in_nodes() const { } void Node::MaybeCopyOnWrite() { - // Properties may be shared between Nodes. Make a copy if so. + // NodeProperties may be shared between Nodes. Make a copy if so. if (!props_->RefCountIsOne()) { - Properties* new_props = - new Properties(props_->op_def_, props_->node_def_, props_->input_types_, - props_->output_types_); + NodeProperties* new_props = + new NodeProperties(props_->op_def_, props_->node_def_, + props_->input_types_, props_->output_types_); props_->Unref(); props_ = new_props; } } +AttrValue* Node::AddAttrHelper(const string& name) { + MaybeCopyOnWrite(); + return &((*props_->node_def_.mutable_attr())[name]); +} + void Node::ClearAttr(const string& name) { MaybeCopyOnWrite(); (*props_->node_def_.mutable_attr()).erase(name); @@ -225,17 +274,6 @@ Status Node::input_node(int idx, const Node** const_n) const { return Status::OK(); } -// Node::Properties - -Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, - const DataTypeSlice inputs, - const DataTypeSlice outputs) - : op_def_(op_def), - node_def_(node_def), - input_types_(inputs.begin(), inputs.end()), - output_types_(outputs.begin(), outputs.end()) {} - -Node::Properties::~Properties() {} // Graph @@ -300,14 +338,14 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) { } Node* node = AllocateNode( - new Node::Properties(op_def, node_def, inputs, outputs), nullptr); + new NodeProperties(op_def, node_def, inputs, outputs), nullptr); return node; } Node* Graph::CopyNode(Node* node) { DCHECK(!node->IsSource()); DCHECK(!node->IsSink()); - Node::Properties* props = node->properties(); + NodeProperties* props = node->properties(); props->Ref(); Node* copy = AllocateNode(props, node); copy->set_assigned_device_name(node->assigned_device_name()); @@ -502,7 +540,7 @@ bool Graph::IsValidNode(Node* node) const { return nodes_[id] == node; } -Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { +Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) { Node* node = nullptr; if (free_nodes_.empty()) { node = new (arena_.Alloc(sizeof(Node))) Node; // placement new diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 8cb270170e..08e2838d3c 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -63,14 +63,15 @@ class Node; class NeighborIter; // Declared below class NodeIter; // Declared below +class NodeProperties; // Defined in .cc class Node { public: string DebugString() const; int id() const { return id_; } int cost_id() const { return cost_id_; } - const string& name() const { return props_->node_def_.name(); } - const string& type_string() const { return props_->node_def_.op(); } + const string& name() const; + const string& type_string() const; // def() provides the NodeDef the user supplied, but the specifics // of this Node may have changed due to placement, optimization, etc. @@ -82,21 +83,21 @@ class Node { // the actual assigned device, see assigned_device_name() below; // * def().attr() is authoritative. // TODO(irving): Replace with NodeInfo. - const NodeDef& def() const { return props_->node_def_; } - const OpDef& op_def() const { return *props_->op_def_; } + const NodeDef& def() const; + const OpDef& op_def() const; // input and output types - int32 num_inputs() const { return props_->input_types_.size(); } - DataType input_type(int32 i) const { return props_->input_types_[i]; } - const DataTypeVector& input_types() const { return props_->input_types_; } + int32 num_inputs() const; + DataType input_type(int32 i) const; + const DataTypeVector& input_types() const; - int32 num_outputs() const { return props_->output_types_.size(); } - DataType output_type(int32 o) const { return props_->output_types_[o]; } - const DataTypeVector& output_types() const { return props_->output_types_; } + int32 num_outputs() const; + DataType output_type(int32 o) const; + const DataTypeVector& output_types() const; // The device requested by the user. For the actual assigned device, // use assigned_device_name() below. - const string& requested_device() const { return def().device(); } + const string& requested_device() const; // This gives the device the runtime has assigned this node to. If // you want the device the user requested, use def().device() instead. @@ -113,12 +114,10 @@ class Node { void set_assigned_device_name_index(int index); // Read only access to attributes - AttrSlice attrs() const { return AttrSlice(def()); } + AttrSlice attrs() const; // Inputs requested by the NodeDef. For the actual inputs, use in_edges. - const protobuf::RepeatedPtrField<string>& requested_inputs() const { - return def().input(); - } + const protobuf::RepeatedPtrField<string>& requested_inputs() const; // Get the neighboring nodes via edges either in or out of this node. gtl::iterator_range<NeighborIter> in_nodes() const; @@ -162,8 +161,7 @@ class Node { template <typename T> void AddAttr(const string& name, const T& val) { - MaybeCopyOnWrite(); - SetAttrValue(val, &((*props_->node_def_.mutable_attr())[name])); + SetAttrValue(val, AddAttrHelper(name)); } void ClearAttr(const string& name); @@ -185,36 +183,24 @@ class Node { Node(); ~Node(); - class Properties : public core::RefCounted { - public: - Properties(const OpDef* op_def, const NodeDef& node_def, - const DataTypeSlice inputs, const DataTypeSlice outputs); - - const OpDef* op_def_; // not owned - NodeDef node_def_; - const DataTypeVector input_types_; - const DataTypeVector output_types_; - - private: - // Destructor invoked when last reference goes away via Unref() - virtual ~Properties(); - TF_DISALLOW_COPY_AND_ASSIGN(Properties); - }; - - Properties* properties() const { return props_; } + NodeProperties* properties() const { return props_; } // Initialize() adopts a reference to props, and so is suitable if props was // just allocated or you call props->Ref() to increment the reference // count for a props being held by another Node. - void Initialize(int id, int cost_id, Properties* props); + void Initialize(int id, int cost_id, NodeProperties* props); + // Releases memory from props_, in addition to restoring *this to its // uninitialized state. void Clear(); + // Make a copy of the Node's props_ if props_ is shared with // other nodes. This must be called before mutating properties, // e.g. in AddAttr. void MaybeCopyOnWrite(); + AttrValue* AddAttrHelper(const string& name); + // A set of mutually exclusive classes for different kinds of nodes, // class_ is initialized in the Node::Initialize routine based on the // node's type_string(). @@ -252,7 +238,7 @@ class Node { EdgeSet in_edges_; EdgeSet out_edges_; - Properties* props_; + NodeProperties* props_; // Index within Graph::device_names_ of the name of device assigned // to perform this computation. @@ -519,7 +505,7 @@ class Graph { // If cost_node is non-null, then cost accounting (in CostModel) // will be associated with that node rather than the new one being // created. - Node* AllocateNode(Node::Properties* props, const Node* cost_node); + Node* AllocateNode(NodeProperties* props, const Node* cost_node); void ReleaseNode(Node* node); // Registry of all known ops, including functions. diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 10f110686f..38a780dfac 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.h" diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index 54cfd10cdf..47337ce8a2 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -42,6 +42,7 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index bfdc5cab0d..bd905651d2 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/versions.pb.h" diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index fc4cc2d83a..8b7c269a11 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -300,6 +300,7 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) @@ -322,6 +323,7 @@ cc_library( ":typed_queue", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) @@ -413,6 +415,7 @@ cc_library( hdrs = ["warn_about_ints.h"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", ], ) @@ -1413,7 +1416,7 @@ tf_kernel_library( tf_kernel_library( name = "random_shuffle_queue_op", prefix = "random_shuffle_queue_op", - deps = DATA_FLOW_DEPS, + deps = DATA_FLOW_DEPS + ["//tensorflow/core:protos_all_cc"], ) tf_kernel_library( @@ -1514,6 +1517,7 @@ cc_library( ":typed_queue", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) @@ -1528,6 +1532,7 @@ cc_library( ":typed_queue", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 33991fa1f9..ef34946d96 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -111,7 +111,7 @@ class DenseUpdateOp : public OpKernel { OP_REQUIRES(context, Tparams.IsInitialized(), errors::FailedPrecondition("Attempting to use uninitialized " "parameters: ", - def().input(0))); + requested_input(0))); OP_REQUIRES( context, Tparams.IsSameSize(Tupdate), errors::InvalidArgument("Parameters and update must be the same size")); diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index 030cf8a49d..ea86b04762 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -19,6 +19,7 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index b6825b4e95..f143d39816 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -168,7 +168,7 @@ class OneShotIteratorOp : public AsyncOpKernel { thread_pool_(new thread::ThreadPool( ctx->env(), ThreadOptions(), strings::StrCat("one_shot_iterator_initialization_thread_", - SanitizeThreadSuffix(def().name())), + SanitizeThreadSuffix(name())), 1 /* num_threads */, false /* low_latency_hint */)) { @@ -359,7 +359,7 @@ class IteratorGetNextOp : public AsyncOpKernel { thread_pool_(new thread::ThreadPool( ctx->env(), ThreadOptions(), strings::StrCat("iterator_get_next_thread_", - SanitizeThreadSuffix(def().name())), + SanitizeThreadSuffix(name())), 1 /* num_threads */, false /* low_latency_hint */)) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index 96de4094fe..2554359d13 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index f4626d4a5d..d0f7683f3d 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -18,6 +18,7 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/core/kernels/priority_queue.cc b/tensorflow/core/kernels/priority_queue.cc index 8884c0c4a0..894ad3c9a0 100644 --- a/tensorflow/core/kernels/priority_queue.cc +++ b/tensorflow/core/kernels/priority_queue.cc @@ -18,6 +18,7 @@ limitations under the License. #include <queue> #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index 56a7e161df..5658dcc069 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -211,7 +211,7 @@ class Im2ColConvFunctor { ++warning_count; LOG(WARNING) << "For kernel '" << context->op_kernel().name() << "' from input '" - << context->op_kernel().def().input(0) + << context->op_kernel().requested_input(0) << "': Zero is not representable in the quantized range used by the" << " input. This means QuantizedConv2d has to fall back to a slow" << " implementation, since the border of zero values can't be" diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 07ff70a875..8a9af39e1f 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/queue_base.h" #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index d9efb5fe7d..30bbbd4aed 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc index e2eb40677b..abd16de6a1 100644 --- a/tensorflow/core/kernels/reader_ops.cc +++ b/tensorflow/core/kernels/reader_ops.cc @@ -50,8 +50,7 @@ class ReaderVerbAsyncOpKernel : public AsyncOpKernel { : AsyncOpKernel(context), thread_pool_(new thread::ThreadPool( context->env(), ThreadOptions(), - strings::StrCat("reader_thread_", - SanitizeThreadSuffix(def().name())), + strings::StrCat("reader_thread_", SanitizeThreadSuffix(name())), 1 /* num_threads */, false /* low_latency_hint */)) {} void ComputeAsync(OpKernelContext* context, DoneCallback done) override { diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc index b24482f2d5..ad95b25cb1 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index 27ad2fcd87..185c5b248f 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -43,21 +43,21 @@ class GetSessionHandleOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& val = ctx->input(0); int64 id = ctx->session_state()->GetNewId(); - TensorStore::TensorAndKey tk{val, id, def().device()}; - OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(def().name(), tk)); + TensorStore::TensorAndKey tk{val, id, requested_device()}; + OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk)); Tensor* handle = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); if (ctx->expected_output_dtype(0) == DT_RESOURCE) { ResourceHandle resource_handle = MakeResourceHandle<Tensor>( ctx, SessionState::kTensorHandleResourceTypeName, - tk.GetHandle(def().name())); + tk.GetHandle(name())); resource_handle.set_maybe_type_name( SessionState::kTensorHandleResourceTypeName); handle->scalar<ResourceHandle>()() = resource_handle; } else { // Legacy behavior in V1. - handle->flat<string>().setConstant(tk.GetHandle(def().name())); + handle->flat<string>().setConstant(tk.GetHandle(name())); } } diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index f6b6194f0a..e798a7e4a4 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -352,7 +352,7 @@ class ApplyGradientDescentOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha = ctx->input(1); OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", @@ -391,7 +391,7 @@ class ApplyGradientDescentOp < SYCLDevice, T > : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha_dev = ctx->input(1); OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()), errors::InvalidArgument("alpha is not a scalar: ", @@ -480,7 +480,7 @@ class ApplyDelayCompensatedGradientDescentOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha = ctx->input(1); OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", @@ -575,15 +575,15 @@ class ApplyAdadeltaOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, accum_update.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); @@ -711,15 +711,15 @@ class SparseApplyAdadeltaOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum_grad.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, accum_update.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum_grad.shape()), errors::InvalidArgument("var and accum_grad do not have the same shape", @@ -851,7 +851,7 @@ class ApplyProximalGradientDescentOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha = ctx->input(1); OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", @@ -1066,11 +1066,11 @@ class ApplyAdagradOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", @@ -1159,11 +1159,11 @@ class ApplyProximalAdagradOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -1266,11 +1266,11 @@ class SparseApplyAdagradOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -1400,11 +1400,11 @@ class SparseApplyProximalAdagradOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -1575,15 +1575,15 @@ class ApplyAdagradDAOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, gradient_accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, gradient_squared_accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, var.shape().IsSameSize(gradient_accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -1677,15 +1677,15 @@ class SparseApplyAdagradDAOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, gradient_accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, gradient_squared_accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, var.shape().IsSameSize(gradient_accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -1874,15 +1874,15 @@ class ApplyFtrlOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, linear.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& grad = ctx->input(3); OP_REQUIRES( @@ -1988,15 +1988,15 @@ class SparseApplyFtrlOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, linear.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -2196,11 +2196,11 @@ class ApplyMomentumOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", @@ -2299,11 +2299,11 @@ class SparseApplyMomentumOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", @@ -2419,15 +2419,15 @@ class ApplyAdamOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, m.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, v.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& beta1_power = ctx->input(3); const Tensor& beta2_power = ctx->input(4); @@ -2505,15 +2505,15 @@ class ApplyAdamOp < SYCLDevice, T> : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, m.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, v.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& beta1_power_dev = ctx->input(3); const Tensor& beta2_power_dev = ctx->input(4); @@ -2679,15 +2679,15 @@ class ApplyRMSPropOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, ms.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, mom.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); @@ -2764,19 +2764,19 @@ class ApplyCenteredRMSPropOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, mg.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, ms.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, mom.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(3))); + "Attempting to use uninitialized variables: ", requested_input(3))); const Tensor& lr = ctx->input(4); const Tensor& rho = ctx->input(5); @@ -2922,15 +2922,15 @@ class SparseApplyRMSPropOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, ms.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(1))); + "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, mom.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); @@ -3054,15 +3054,15 @@ class SparseApplyCenteredRMSPropOp : public OpKernel { OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(0))); + "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, ms.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(2))); + "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, mom.IsInitialized(), errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", def().input(3))); + "Attempting to use uninitialized variables: ", requested_input(3))); const Tensor& lr = ctx->input(4); const Tensor& rho = ctx->input(5); diff --git a/tensorflow/core/kernels/warn_about_ints.cc b/tensorflow/core/kernels/warn_about_ints.cc index fd0a889c99..75ecdf2ae4 100644 --- a/tensorflow/core/kernels/warn_about_ints.cc +++ b/tensorflow/core/kernels/warn_about_ints.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/warn_about_ints.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index 2db026da56..45d6a6662a 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index d5e58c174b..8ebdbafb85 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/python/framework/cpp_shape_inference.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/python/util/kernel_registry.cc b/tensorflow/python/util/kernel_registry.cc index d451bbace2..7d47692f6b 100644 --- a/tensorflow/python/util/kernel_registry.cc +++ b/tensorflow/python/util/kernel_registry.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/python/util/kernel_registry.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index dfad11adf0..81892aef96 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/tools/graph_transforms/rename_attribute_test.cc b/tensorflow/tools/graph_transforms/rename_attribute_test.cc index a0a33e9fc0..31619d82ad 100644 --- a/tensorflow/tools/graph_transforms/rename_attribute_test.cc +++ b/tensorflow/tools/graph_transforms/rename_attribute_test.cc @@ -43,17 +43,17 @@ class RenameAttributeTest : public ::testing::Test { mul_node1->set_op("Mul"); mul_node1->add_input("add_node2"); mul_node1->add_input("add_node3"); - AddNodeAttr<int32>("foo", 23, mul_node1); - AddNodeAttr<string>("bar", "something", mul_node1); + AddNodeAttr("foo", 23, mul_node1); + AddNodeAttr("bar", "something", mul_node1); NodeDef* add_node2 = graph_def.add_node(); add_node2->set_name("add_node2"); add_node2->set_op("Add"); add_node2->add_input("const_node1"); add_node2->add_input("const_node2"); - AddNodeAttr<int32>("foo", 46, add_node2); - AddNodeAttr<int32>("bob", 23, add_node2); - AddNodeAttr<string>("bar", "something else", add_node2); + AddNodeAttr("foo", 46, add_node2); + AddNodeAttr("bob", 23, add_node2); + AddNodeAttr("bar", "something else", add_node2); NodeDef* add_node3 = graph_def.add_node(); add_node3->set_name("add_node3"); |