aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-05-16 16:08:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-16 16:12:05 -0700
commit749e5cc18381f7a5ec174673f76e20aead8529c6 (patch)
tree4b92d36c9e1d8e59e34fd8d08e7f11fbda1315d9 /tensorflow/core/framework
parented5d05d8b53425ef98aad129a60143a5011a4288 (diff)
Reduce direct references to NodeDef in favor of Node and AttrSlice
This is one step towards replacing in-memory use of NodeDef with a customized NodeInfo class. There are still quite a few Node::def() references, but far fewer than before. Those remaining require more work, either because they are part of kernel registration (which is a bunch of functions), copy and modify the NodeDef, etc. Follow-on CLs will remove more. RELNOTES: n/a PiperOrigin-RevId: 156244933
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/function.cc102
-rw-r--r--tensorflow/core/framework/function.h49
-rw-r--r--tensorflow/core/framework/function_test.cc110
-rw-r--r--tensorflow/core/framework/node_def_util.cc92
-rw-r--r--tensorflow/core/framework/node_def_util.h39
-rw-r--r--tensorflow/core/framework/op_kernel.cc23
-rw-r--r--tensorflow/core/framework/shape_inference.h9
7 files changed, 254 insertions, 170 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index f46bb6e2ed..186095201d 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -44,12 +45,11 @@ namespace {
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
-Status ArgNumType(const InstantiateAttrValueMap& attrs,
- const OpDef::ArgDef& arg_def, bool* is_type_list,
- DataTypeVector* dtypes) {
+Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
+ bool* is_type_list, DataTypeVector* dtypes) {
dtypes->clear();
if (!arg_def.type_list_attr().empty()) {
- const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_list_attr());
+ const AttrValue* v = attrs.Find(arg_def.type_list_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ",
arg_def.type_list_attr());
@@ -64,7 +64,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs,
*is_type_list = false;
int num = 1;
if (!arg_def.number_attr().empty()) {
- const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr());
+ const AttrValue* v = attrs.Find(arg_def.number_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ", arg_def.type_attr());
}
@@ -77,7 +77,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs,
} else if (arg_def.type_attr().empty()) {
dtype = DT_INVALID;
} else {
- const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr());
+ const AttrValue* v = attrs.Find(arg_def.type_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ", arg_def.type_attr());
}
@@ -92,18 +92,17 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) {
SetAttrValue(val, &((*ndef->mutable_attr())[name]));
}
-Status ValidateSignatureWithAttrs(const OpDef& sig,
- const InstantiateAttrValueMap& attr_values) {
+Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
// attr_values should specify all attrs defined in fdef.
for (const auto& a : sig.attr()) {
- auto const iter = attr_values.find(a.name());
- if (iter == attr_values.end()) {
+ const AttrValue* v = attr_values.Find(a.name());
+ if (!v) {
return errors::NotFound("Attr ", a.name(), " is not found from ",
SummarizeOpDef(sig));
}
- Status status = AttrValueHasType(iter->second, a.type());
+ Status status = AttrValueHasType(*v, a.type());
if (!status.ok()) {
- errors::AppendToMessage(&status, "for attr '", iter->first, "'");
+ errors::AppendToMessage(&status, "for attr '", a.name(), "'");
return status;
}
}
@@ -146,7 +145,7 @@ class FunctionInstantiationHelper {
// Builds index for nodes that can be used as node's input arguments.
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
- const InstantiateAttrValueMap& attr_values) {
+ AttrSlice attr_values) {
bool is_type_list;
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(
@@ -175,8 +174,7 @@ class FunctionInstantiationHelper {
return Status::OK();
}
- Status BuildNodeOutputIndex(const NodeDef& node,
- const InstantiateAttrValueMap& attrs,
+ Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
const int arg_index) {
const OpDef* node_sig = nullptr;
TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
@@ -206,8 +204,7 @@ class FunctionInstantiationHelper {
return Status::OK();
}
- Status InstantiateNode(const NodeDef& fnode,
- const InstantiateAttrValueMap& attrs) {
+ Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
const OpDef* fnode_sig = nullptr;
TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
NodeDef* gnode = AddNode(fnode.name());
@@ -295,7 +292,7 @@ class FunctionInstantiationHelper {
}
Status AddReturnNode(
- const OpDef::ArgDef& ret_def, const InstantiateAttrValueMap& attrs,
+ const OpDef::ArgDef& ret_def, AttrSlice attrs,
const ::tensorflow::protobuf::Map<string, string>& ret_map,
int* ret_index) {
auto ret_iter = ret_map.find(ret_def.name());
@@ -604,7 +601,7 @@ string Print(const GraphDef& gdef) {
Status AddDefaultAttrs(const string& op,
const GetFunctionSignature& get_function,
- InstantiateAttrValueMap* attrs) {
+ AttrValueMap* attrs) {
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(get_function(op, &op_def));
AttrSlice attr_slice(attrs);
@@ -620,8 +617,7 @@ Status AddDefaultAttrs(const string& op,
} // end namespace
-Status InstantiateFunction(const FunctionDef& fdef,
- const InstantiateAttrValueMap& attr_values,
+Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
GetFunctionSignature get_function,
InstantiationResult* result) {
VLOG(3) << "Instantiation Function: " << Print(fdef);
@@ -639,19 +635,17 @@ Status InstantiateFunction(const FunctionDef& fdef,
}
}
- auto substitute = [&attr_values](const string& name, AttrValue* val) {
- auto iter = attr_values.find(name);
- if (iter == attr_values.end()) {
- return false;
- } else {
- *val = iter->second;
+ auto substitute = [attr_values](StringPiece name, AttrValue* val) {
+ if (const AttrValue* v = attr_values.Find(name)) {
+ *val = *v;
return true;
}
+ return false;
};
// Makes a copy of all attrs in fdef and substitutes placeholders.
// After this step, every attr is bound to a concrete value.
- std::vector<InstantiateAttrValueMap> node_attrs;
+ std::vector<AttrValueMap> node_attrs;
node_attrs.resize(fdef.node_def_size());
for (int i = 0; i < fdef.node_def_size(); ++i) {
for (auto attr : fdef.node_def(i).attr()) {
@@ -668,7 +662,7 @@ Status InstantiateFunction(const FunctionDef& fdef,
}
for (int i = 0; i < fdef.node_def_size(); ++i) {
- s = helper.BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i],
+ s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
result->gdef.node_size() + i);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
@@ -677,7 +671,7 @@ Status InstantiateFunction(const FunctionDef& fdef,
}
// Emits one gdef.node for each fdef.node_def.
for (int i = 0; i < fdef.node_def_size(); ++i) {
- s = helper.InstantiateNode(fdef.node_def(i), node_attrs[i]);
+ s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
return s;
@@ -748,8 +742,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
return true;
}
-string Canonicalize(const string& funcname,
- const InstantiateAttrValueMap& attrs) {
+string Canonicalize(const string& funcname, AttrSlice attrs) {
std::vector<string> entries;
entries.reserve(attrs.size());
for (auto p : attrs) {
@@ -953,8 +946,7 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
// If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
// Foo's attributes.
const NameAttrList* forward_func_attrs;
- if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs)
- .ok()) {
+ if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
return nullptr;
}
const string& func_name = forward_func_attrs->name();
@@ -981,34 +973,30 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
return lib;
}
-Status InstantiateFunction(const FunctionDef& fdef,
- InstantiateAttrValueSlice attr_values,
- GetFunctionSignature get_function,
- InstantiationResult* result) {
- InstantiateAttrValueMap m;
- for (const auto& aval : attr_values) {
- m.insert({aval.first, aval.second.proto});
+template <typename T>
+Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
+ const string& attr, T* value) const {
+ const FunctionDef* fdef = GetAttrImpl(ndef);
+ if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
+ return Status::OK();
}
- return InstantiateFunction(fdef, m, std::move(get_function), result);
+ return errors::InvalidArgument("Attr ", attr, " is not defined.");
}
-string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) {
- InstantiateAttrValueMap m;
- for (const auto& aval : attrs) {
- m.insert({aval.first, aval.second.proto});
- }
- return Canonicalize(funcname, m);
+template <typename T>
+Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
+ T* value) const {
+ return GetAttr(node.def(), attr, value);
}
-Status FunctionLibraryRuntime::Instantiate(const string& function_name,
- InstantiateAttrValueSlice attrs,
- Handle* handle) {
- InstantiateAttrValueMap m;
- for (const auto& aval : attrs) {
- m.insert({aval.first, aval.second.proto});
- }
- return Instantiate(function_name, m, handle);
-}
+#define GET_ATTR(T) \
+ template Status FunctionLibraryDefinition::GetAttr(const Node&, \
+ const string&, T*) const; \
+ template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
+ const string&, T*) const;
+GET_ATTR(string)
+GET_ATTR(bool)
+#undef GET_ATTR
void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
if (val.size() >= 2 && val[0] == '$') {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 210e5b949a..188c3855c6 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -36,6 +36,7 @@ class CancellationManager;
class OpKernel;
class ResourceMgr;
class ScopedStepContainer;
+class Node;
// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
@@ -190,11 +191,6 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.
-// Placeholders in "fdef" is substituted based on "attr_values" here.
-typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
-typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
- InstantiateAttrValueSlice;
-
// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (a OpDef proto). Otherwise, returns an error.
@@ -206,12 +202,7 @@ struct InstantiationResult {
DataTypeVector ret_types;
GraphDef gdef;
};
-Status InstantiateFunction(const FunctionDef& fdef,
- const InstantiateAttrValueMap& attr_values,
- GetFunctionSignature get_function,
- InstantiationResult* result);
-Status InstantiateFunction(const FunctionDef& fdef,
- InstantiateAttrValueSlice attr_values,
+Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
GetFunctionSignature get_function,
InstantiationResult* result);
@@ -241,9 +232,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
// space. But it may be change as the implementation
// evolves. Therefore, it should not be persisted or compared across
// address spaces.
-string Canonicalize(const string& funcname,
- const InstantiateAttrValueMap& attrs);
-string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs);
+string Canonicalize(const string& funcname, AttrSlice attrs);
// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
@@ -330,9 +319,16 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// Given a node def 'ndef', inspects attributes of the callee
// function to derive the attribute 'value' for 'attr'. Returns OK
// iff the attribute is given by the function's definition.
+ // TODO(irving): Remove; keep only the const Node& version.
template <typename T>
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
+ // Given a node, inspects attributes of the callee function to derive the
+ // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
+ // function's definition.
+ template <typename T>
+ Status GetAttr(const Node& node, const string& attr, T* value) const;
+
// Returns a proto representation of the state of this function library.
FunctionDefLibrary ToProto() const;
@@ -375,11 +371,8 @@ class FunctionLibraryRuntime {
// Returns OK and fills in "handle" if the instantiation succeeds.
// Otherwise returns an error and "handle" is undefined.
typedef uint64 Handle;
- virtual Status Instantiate(const string& function_name,
- const InstantiateAttrValueMap& attrs,
+ virtual Status Instantiate(const string& function_name, AttrSlice attrs,
Handle* handle) = 0;
- Status Instantiate(const string& function_name,
- InstantiateAttrValueSlice attrs, Handle* handle);
// Returns the function body for the instantiated function given its
// handle 'h'. Returns nullptr if "h" is not found.
@@ -506,17 +499,15 @@ bool RegisterOp(const string& op, Creator func);
Status GetOpGradientCreator(const string& op, Creator* creator);
};
-// Implementation details.
-
-template <typename T>
-Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
- const string& attr, T* value) const {
- const FunctionDef* fdef = GetAttrImpl(ndef);
- if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
- return Status::OK();
- }
- return errors::InvalidArgument("Attr ", attr, " is not defined.");
-}
+// Declare explicit instantiations of GetAttr
+#define GET_ATTR(T) \
+ extern template Status FunctionLibraryDefinition::GetAttr( \
+ const Node&, const string&, T*) const; \
+ extern template Status FunctionLibraryDefinition::GetAttr( \
+ const NodeDef&, const string&, T*) const;
+GET_ATTR(string)
+GET_ATTR(bool)
+#undef GET_ATTR
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 07462a575e..c83ecf4e5e 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -29,6 +29,24 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+namespace {
+
+// A helper class to make AttrSlice from initializer lists
+class Attrs {
+ public:
+ Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
+ std::pair<string, FunctionDefHelper::AttrValueWrapper>>
+ attrs) {
+ for (const auto& aval : attrs) {
+ map_.insert({aval.first, aval.second.proto});
+ }
+ }
+
+ operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
+
+ private:
+ AttrValueMap map_;
+};
typedef FunctionDefHelper FDH;
@@ -46,8 +64,6 @@ y: A scalar in type T.
)doc");
-static InstantiateAttrValueMap kNoAttrs;
-
TEST(TFunc, SquarePlusOne) {
auto fdef = FDH::Create(
// Name
@@ -81,7 +97,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
// Instantiate one with T=float
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
+ TF_ASSERT_OK(
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> (y:float) {
a = Square[T=float](x)
@@ -126,7 +143,8 @@ ControlDep(x:int32) -> (y:int32) {
// Instantiate one with T=float
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
+ TF_ASSERT_OK(
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
const char* e2 = R"P(
(x:int32) -> (y:int32) {
a = Identity[T=int32](x)
@@ -171,8 +189,7 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
- TF_ASSERT_OK(
- InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
// Should get T=float from Op's default.
const char* e2 = R"P(
() -> (a:float) {
@@ -209,7 +226,7 @@ NTimesT(x:float, y:float) -> (z:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
const char* e2 = R"P(
(x:float, y:float) -> (a:float) {
a = AddN[N=2, T=float](x, y)
@@ -272,8 +289,8 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
// Instantiate one with T=float
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig,
- &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
+ GetOpSig, &result));
const char* e2 = R"P(
(x_0:float, x_1:float, x_2:float) -> (y:float) {
a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
@@ -315,7 +332,7 @@ ControlDeps(x:float) -> () {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> () {
a = One[T=float]() @ x
@@ -395,7 +412,7 @@ Test(i:float) -> (o:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
const char* e2 = R"P(
(i:float) -> (o:float) {
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
@@ -467,7 +484,7 @@ MySelect(x:float) -> (z:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
- TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> (z:float) {
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
@@ -488,8 +505,9 @@ TEST(InstantiateErrors, Not_Sufficient_Attrs) {
auto fdef =
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result),
- "Attr T is not found from ");
+ HasError(
+ InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
+ "Attr T is not found from ");
}
#if 0 // TODO(josh11b): Enable this test once having an extra attr is an error.
@@ -497,7 +515,7 @@ TEST(InstantiateErrors, Too_Many_Attrs) {
auto fdef =
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, {{"T", DT_INT32}, {"U", DT_FLOAT}},
+ HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
GetOpSig, &result),
"Attr U is not found in ");
}
@@ -508,7 +526,7 @@ TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
HasError(
- InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result),
+ InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
"AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
}
@@ -518,14 +536,15 @@ TEST(InstantiateErrors, Unbounded_Attr) {
{{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result),
- "Failed to bind all placeholders");
+ HasError(
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
+ "Failed to bind all placeholders");
}
TEST(InstantiateErrors, DupArgs) {
auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Duplicated arg name");
}
@@ -536,7 +555,7 @@ TEST(InstantiateErrors, Dup_Node_Names) {
{{"y"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Duplicated ret name");
}
@@ -547,7 +566,7 @@ TEST(InstantiateErrors, Node_Arg_Notfound) {
},
{});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"input z is not found");
}
@@ -557,7 +576,7 @@ TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"input x[0] expected type int32 != float, the type of x[0]");
}
@@ -568,7 +587,7 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
{{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"input[2] == '^z', is not found.");
}
@@ -579,7 +598,7 @@ TEST(InstantiateErrors, FuncRet_Missing) {
},
{});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Return y missing");
}
@@ -590,7 +609,7 @@ TEST(InstantiateErrors, FuncRet_NotFound) {
},
{{"y", "z"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Return y -> z is not found");
}
@@ -601,7 +620,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
},
{{"z", "x:y:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Return y missing");
}
@@ -613,7 +632,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
// },
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
// InstantiationResult result;
-// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
// "ret is not found");
// }
@@ -623,7 +642,7 @@ TEST(InstantiateErrors, FuncRet_TypeMismatch) {
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Invalid ret types y : float vs. double\n\tIn function output y");
}
@@ -649,7 +668,7 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
},
{{"y", "y:output"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"type attr not found: out_types");
}
@@ -676,7 +695,7 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
},
{{"y", "y:output"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Invalid ret types");
}
@@ -703,7 +722,7 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
},
{{"y", "y:output"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"input unknown is not found");
}
@@ -724,7 +743,7 @@ TEST(InstantiateErrors, TooManyInputs) {
{{"z", "a:sum:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Expected input[2] == 'x' to be a control input.");
}
@@ -745,7 +764,7 @@ TEST(InstantiateErrors, TooFewInputs) {
{{"z", "a:sum:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Attempt to access beyond input size: 2 >= 2");
}
@@ -773,7 +792,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray1) {
{{"z", "a:sum:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Expected input[1] == 'y' to be a control input.");
}
@@ -801,7 +820,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray2) {
{{"z", "a:sum:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"Input a:output too long for inputs");
}
@@ -822,7 +841,7 @@ TEST(InstantiateErrors, TypeMismatch) {
{{"z", "a:sum:0"}});
InstantiationResult result;
- HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
"input inputs[1] expected type float != int32, the type of y[0]");
}
@@ -874,17 +893,17 @@ TEST(FunctionCallFrame, Float_Float_Float) {
}
TEST(Canonicalize, Basic) {
- EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
- {"transpose_a", false},
- {"transpose_b", false}}),
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
+ {"transpose_a", false},
+ {"transpose_b", false}})),
"MatMul[T=float,transpose_a=false,transpose_b=false]");
- EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
- {"transpose_b", false},
- {"transpose_a", false}}),
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
+ {"transpose_b", false},
+ {"transpose_a", false}})),
"MatMul[T=float,transpose_a=false,transpose_b=false]");
- EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE},
- {"transpose_b", true},
- {"transpose_a", false}}),
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
+ {"transpose_b", true},
+ {"transpose_a", false}})),
"MatMul[T=double,transpose_a=false,transpose_b=true]");
}
@@ -1148,4 +1167,5 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
}
+} // end namespace
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 36c0842924..9b737e1f72 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
@@ -36,18 +37,23 @@ namespace tensorflow {
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
+AttrSlice::AttrSlice() : ndef_(nullptr) {
+ static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
+ attrs_ = kEmptyAttrValueMap;
+}
+
AttrSlice::AttrSlice(const NodeDef& node_def)
: ndef_(&node_def), attrs_(&ndef_->attr()) {}
AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
-string SummarizeNodeDef(const NodeDef& node_def) {
- string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
+ string ret;
// We sort the attrs so the output is deterministic.
std::vector<string> attr_names;
- attr_names.reserve(node_def.attr().size());
- for (const auto& attr : node_def.attr()) {
+ attr_names.reserve(attrs.size());
+ for (const auto& attr : attrs) {
attr_names.push_back(attr.first);
}
std::sort(attr_names.begin(), attr_names.end());
@@ -55,20 +61,34 @@ string SummarizeNodeDef(const NodeDef& node_def) {
for (const string& attr_name : attr_names) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
- auto iter = node_def.attr().find(attr_name);
- strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
+ strings::StrAppend(&ret, attr_name, "=",
+ SummarizeAttrValue(*attrs.Find(attr_name)));
}
// Consider the device to be a final attr with name "_device".
- if (!node_def.device().empty()) {
+ if (!device.empty()) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
- strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
+ strings::StrAppend(&ret, "_device=\"", device, "\"");
}
+ return ret;
+}
+
+string AttrSlice::SummarizeNode() const {
+ return ndef_ ? SummarizeNodeDef(*ndef_)
+ : strings::StrCat(
+ "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
+}
+
+string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
+
+string SummarizeNodeDef(const NodeDef& node_def) {
+ string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+ strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
strings::StrAppend(&ret, "](");
// Output inputs, including control inputs, verbatim.
- first = true;
+ bool first = true;
for (const string& input : node_def.input()) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
@@ -109,12 +129,28 @@ Status AttrSlice::Find(StringPiece attr_name,
// Skip AttachDef for internal attrs since it is a little bit
// expensive and it is common for them to correctly not be included
// in a NodeDef.
- if (!StringPiece(attr_name).starts_with("_") && ndef_) {
+ if (!attr_name.starts_with("_") && ndef_ != nullptr) {
s = AttachDef(s, *ndef_);
}
return s;
}
+bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
+ if (size() != other.size()) return false;
+
+ for (const auto& attr : *other.attrs_) {
+ auto iter = attrs_->find(attr.first);
+ if (iter == attrs_->end()) return false;
+ // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
+ // TensorProto is a nonunique representation of Tensor. This bug will go
+ // away once AttrSlice switches over to NodeInfo.
+ iter->second.SerializeToString(&scratch->a);
+ attr.second.SerializeToString(&scratch->b);
+ if (scratch->a != scratch->b) return false;
+ }
+ return true;
+}
+
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
@@ -341,14 +377,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (StringPiece(input).starts_with("^")) {
seen_control = true;
if (input.find(':') != string::npos) {
- return errors::InvalidArgument("Control input '", input,
- "' must not have ':' in NodeDef: ",
- SummarizeNodeDef(node_def));
+ return errors::InvalidArgument(
+ "Control input '", input,
+ "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def));
}
} else if (seen_control) {
- return errors::InvalidArgument("Non-control input '", input,
- "' after control input in NodeDef: ",
- SummarizeNodeDef(node_def));
+ return errors::InvalidArgument(
+ "Non-control input '", input,
+ "' after control input in NodeDef: ", SummarizeNodeDef(node_def));
} else {
++num_inputs;
}
@@ -358,8 +394,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
for (const auto& attr : op_def.attr()) {
if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
return errors::InvalidArgument("OpDef has duplicate attr name '",
- attr.name(), "': ",
- SummarizeOpDef(op_def));
+ attr.name(),
+ "': ", SummarizeOpDef(op_def));
}
}
for (const auto& attr : node_def.attr()) {
@@ -383,8 +419,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
"with your GraphDef-generating binary.).");
}
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
- SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
+ ValidateAttrValue(attr.second, *iter->second),
+ "; NodeDef: ", SummarizeNodeDef(node_def), "; ",
+ SummarizeOpDef(op_def));
// Keep track of which attr names have (not) been found in the NodeDef.
op_attrs.erase(iter);
}
@@ -431,9 +468,9 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
} else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
*num = 1;
} else {
- return errors::InvalidArgument("Argument '", arg_def.name(),
- "' incorrectly specified in op definition: ",
- SummarizeOpDef(op_def));
+ return errors::InvalidArgument(
+ "Argument '", arg_def.name(),
+ "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
}
return Status::OK();
}
@@ -465,6 +502,11 @@ Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
return Status::OK();
}
+Status NameRangesForNode(const Node& node, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs) {
+ return NameRangesForNode(node.def(), op_def, inputs, outputs);
+}
+
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
for (const auto& attr_def : op_def.attr()) {
AttrSlice attrs(*node_def);
@@ -565,4 +607,8 @@ Status AttachDef(const Status& status, const NodeDef& node_def) {
return ret;
}
+Status AttachDef(const Status& status, const Node& node) {
+ return AttachDef(status, node.def());
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 018e4d15f2..1438abdec6 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -29,6 +29,8 @@ limitations under the License.
namespace tensorflow {
+class Node;
+
// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. Desire for explicit co-location
@@ -39,8 +41,9 @@ extern const char* const kColocationAttrName;
// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;
-// Produce a human-readable version of a NodeDef that is more concise
+// Produce a human-readable version of a Node or NodeDef that is more concise
// than a text-format proto.
+string SummarizeNode(const Node& node);
string SummarizeNodeDef(const NodeDef& node_def);
typedef protobuf::Map<string, AttrValue> AttrValueMap;
@@ -78,8 +81,11 @@ class AttrSlice {
public:
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
+ AttrSlice(); // Empty
explicit AttrSlice(const AttrValueMap* a);
+ int size() const { return attrs_->size(); }
+
// Returns the attr with attr_name if found. Otherwise, returns
// nullptr.
const AttrValue* Find(StringPiece attr_name) const;
@@ -88,6 +94,33 @@ class AttrSlice {
// NotFound status.
Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
+ // Helper class to avoid allocations in EqualAttrs.
+ // TODO(irving): Will go away once NodeInfo is used.
+ struct Scratch {
+ string a;
+ string b;
+ };
+
+ // Check if all attrs and attr values match. Does not take defaults into
+ // account.
+ //
+ // TODO(irving): There is a bug in this routine inherited from its
+ // OptimizerCSE::EqualAttrs precedecessor. The same tensor attr can be
+ // represented in more than one way as an AttrValue, since TensorProto is
+ // not 1-1. This bug will go away once I replace everything with NodeInfo,
+ // which stores a Tensor object directly. The Scratch object will also go
+ // away.
+ bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
+
+ // If this AttrSlice has an attached NodeDef, summarize it. This is for
+ // error messages only: we intentionally do not provide direct access to the
+ // NodeDef, since it is not always there.
+ string SummarizeNode() const;
+
+ // Iteration over all attrs
+ AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
+ AttrValueMap::const_iterator end() const { return attrs_->end(); }
+
private:
const NodeDef* ndef_;
const AttrValueMap* attrs_;
@@ -183,9 +216,12 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
// corresponding input/output index range. For example,
// input "foo" corresponds to input indices
// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
+// TODO(irving): Remove the NodeDef version; keep only the Node version.
typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
+Status NameRangesForNode(const Node& node, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs);
// Adds default values to *node_def for unspecified attrs from op_def.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
@@ -206,6 +242,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
// Returns "status" with kernel's NodeDef attached as additional text
// in the error message.
Status AttachDef(const Status& status, const NodeDef& node_def);
+Status AttachDef(const Status& status, const Node& node);
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 422ee80720..6c3917c686 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -842,13 +843,10 @@ bool InTypeList(DataType dt, const AttrValue& type_list) {
return false;
}
-// Returns whether the attrs in the NodeDef satisfy the constraints in
-// the kernel_def. Returns an error if attrs in kernel_def are not
-// found, or have a mismatching type.
-Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
- bool* match) {
+// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
+// an error if attrs in kernel_def are not found, or have a mismatching type.
+Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
*match = false;
- AttrSlice attrs(node_def);
for (const auto& constraint : kernel_def.constraint()) {
if (constraint.allowed_values().list().type_size() == 0) {
return errors::Unimplemented(
@@ -872,7 +870,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
"' that has value '", SummarizeAttrValue(*found),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
- SummarizeNodeDef(node_def), "'");
+ attrs.SummarizeNode(), "'");
}
for (int t : found->list().type()) {
@@ -885,7 +883,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
} else {
return errors::InvalidArgument(
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
- constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def),
+ constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
}
}
@@ -895,6 +893,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
static const StringPiece kKernelAttr("_kernel");
+// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
const NodeDef& node_def,
const KernelRegistration** reg,
@@ -927,8 +926,16 @@ Status FindKernelRegistration(const DeviceType& device_type,
return Status::OK();
}
+Status FindKernelRegistration(const DeviceType& device_type, const Node& node,
+ const KernelRegistration** reg,
+ bool* was_attr_mismatch) {
+ return FindKernelRegistration(device_type, node.def(), reg,
+ was_attr_mismatch);
+}
+
} // namespace
+// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index cebadcc5b4..d064a8ec4d 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -184,8 +184,8 @@ class InferenceContext {
}
#ifndef NDEBUG
for (int i = 0; i < num_outputs(); ++i) {
- DCHECK(output(i).IsSet()) << i << " for " << node_def().name()
- << " of type " << node_def().op();
+ DCHECK(output(i).IsSet())
+ << i << " for " << node_def_.name() << " of type " << node_def_.op();
}
#endif // NDEBUG
return s;
@@ -394,11 +394,6 @@ class InferenceContext {
// the value.
Status MakeDimForScalarInput(int idx, DimensionHandle* out);
- // Returns the NodeDef. The returned reference does not outlive the
- // InferenceContext, and it should not be used after InferenceContext is
- // destroyed.
- const NodeDef& node_def() { return node_def_; }
-
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
// the attr does not have a matching type, a non-ok status will be returned.