 tensorflow/core/BUILD                                      |  17
 tensorflow/core/common_runtime/function.cc                 |  10
 tensorflow/core/common_runtime/function_test.cc            |   4
 tensorflow/core/framework/function.cc                      |  35
 tensorflow/core/framework/function.h                       |  17
 tensorflow/core/framework/function_test.cc                 |  35
 tensorflow/core/framework/graph_def_util.cc                |  29
 tensorflow/core/framework/graph_def_util_test.cc           |  42
 tensorflow/core/framework/memory_types.cc                  |   7
 tensorflow/core/framework/node_def_builder.cc              |   9
 tensorflow/core/framework/node_def_builder_test.cc         |   4
 tensorflow/core/framework/node_def_util_test.cc            |   6
 tensorflow/core/framework/op.cc                            | 100
 tensorflow/core/framework/op.h                             |  62
 tensorflow/core/framework/op_compatibility_test.cc         | 365
 tensorflow/core/framework/op_def_builder.cc                |  33
 tensorflow/core/framework/op_def_builder.h                 |  33
 tensorflow/core/framework/op_def_builder_test.cc           |  43
 tensorflow/core/framework/op_def_util_test.cc              |   6
 tensorflow/core/framework/op_kernel.cc                     |  25
 tensorflow/core/framework/shape_inference.cc               |   9
 tensorflow/core/framework/shape_inference.h                |   3
 tensorflow/core/framework/shape_inference_test.cc          |  14
 tensorflow/core/framework/shape_inference_testutil.cc      | 153
 tensorflow/core/framework/shape_inference_testutil.h       |  56
 tensorflow/core/framework/shape_inference_testutil_test.cc | 154
 tensorflow/core/graph/graph.cc                             |   5
 tensorflow/core/graph/validate.cc                          |   4
 tensorflow/core/graph/validate_test.cc                     |   8
 tensorflow/core/ops/math_ops.cc                            |  15
 tensorflow/core/ops/math_ops_test.cc                       |  44
 31 files changed, 962 insertions(+), 385 deletions(-)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 37d9062612..dc0a2fc157 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -543,6 +543,7 @@ cc_library(
"common_runtime/kernel_benchmark_testlib.h",
"framework/fake_input.h",
"framework/function_testlib.h",
+ "framework/shape_inference_testutil.h",
"framework/tensor_testutil.h",
"graph/testlib.h",
# TODO(josh11b): Drop this once users are depending on
@@ -559,6 +560,7 @@ cc_library(
":lib",
":proto_text",
":protos_all_cc",
+ ":shape_inference_testutil",
":tensor_testutil",
":test",
"//tensorflow/core/kernels:constant_op",
@@ -748,6 +750,8 @@ filegroup(
srcs = [
"//tensorflow/core:framework/fake_input.cc",
"//tensorflow/core:framework/fake_input.h",
+ "//tensorflow/core:framework/shape_inference_testutil.cc",
+ "//tensorflow/core:framework/shape_inference_testutil.h",
"//tensorflow/core:framework/tensor_testutil.cc",
"//tensorflow/core:framework/tensor_testutil.h",
"//tensorflow/core:platform/test.h",
@@ -1212,6 +1216,19 @@ cc_library(
],
)
+cc_library(
+ name = "shape_inference_testutil",
+ testonly = 1,
+ srcs = ["framework/shape_inference_testutil.cc"],
+ hdrs = ["framework/shape_inference_testutil.h"],
+ copts = tf_copts(),
+ deps = [
+ ":framework",
+ ":lib",
+ ":test",
+ ],
+)
+
# Main program for tests
cc_library(
name = "test_main",
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index c243b7fc18..1ccf66ed34 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -301,9 +301,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
lib_def_(lib_def),
optimizer_(optimizer_options) {
get_func_sig_ = [this](const string& op, const OpDef** sig) {
- Status s;
- *sig = lib_def_->LookUp(op, &s);
- return s;
+ return lib_def_->LookUpOpDef(op, sig);
};
create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
return CreateKernel(ndef, kernel);
@@ -689,9 +687,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
- Status s;
- auto sig = lib_def_->LookUp(func, &s);
- return s.ok() && sig->is_stateful();
+ const OpDef* op_def;
+ const Status s = lib_def_->LookUpOpDef(func, &op_def);
+ return s.ok() && op_def->is_stateful();
}
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
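The two hunks above show the lookup-API migration this change applies at every call site: instead of returning an OpDef* and reporting failure through a Status out-parameter, lookups now return Status and hand the definition back through a pointer argument. A minimal before/after sketch of the calling convention:

    // Before: pointer result, Status out-parameter.
    Status s;
    const OpDef* sig = lib_def->LookUp(op, &s);
    if (!s.ok()) return s;

    // After: Status result, pointer out-parameter. Call sites compose with
    // TF_RETURN_IF_ERROR and never read an unset pointer.
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(op, &op_def));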
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 13c8de8384..d5fc71195f 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -38,9 +38,7 @@ namespace tensorflow {
typedef FunctionDefHelper FDH;
Status GetOpSig(const string& op, const OpDef** sig) {
- Status s;
- *sig = OpRegistry::Global()->LookUp(op, &s);
- return s;
+ return OpRegistry::Global()->LookUpOpDef(op, sig);
}
void FunctionTestSchedClosure(std::function<void()> fn) {
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index be6dfd600c..0d437346a4 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -708,14 +708,19 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionLibraryDefinition& other)
- : function_defs_(other.function_defs_), func_grad_(other.func_grad_) {}
+ : func_grad_(other.func_grad_) {
+ for (const auto& it : other.function_defs_) {
+ TF_CHECK_OK(AddFunctionDef(it.second->fdef));
+ }
+}
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionDefLibrary& def_lib)
: function_defs_(def_lib.function_size()) {
for (const auto& fdef : def_lib.function()) {
// The latter function definition wins.
- function_defs_[fdef.signature().name()] = fdef;
+ auto& ptr = function_defs_[fdef.signature().name()];
+ ptr.reset(new FunctionDefAndOpRegistration(fdef));
}
for (const auto& grad : def_lib.gradient()) {
func_grad_[grad.function_name()] = grad.gradient_func();
@@ -729,16 +734,18 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
if (iter == function_defs_.end()) {
return nullptr;
} else {
- return &iter->second;
+ return &iter->second->fdef;
}
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
- if (!function_defs_.insert({fdef.signature().name(), fdef}).second) {
+ auto& ptr = function_defs_[fdef.signature().name()];
+ if (ptr != nullptr) {
return errors::InvalidArgument("Function with name: ",
fdef.signature().name(),
" already exists in function library.");
}
+ ptr.reset(new FunctionDefAndOpRegistration(fdef));
return Status::OK();
}
@@ -746,19 +753,20 @@ string FunctionLibraryDefinition::FindGradient(const string& func) const {
return gtl::FindWithDefault(func_grad_, func, "");
}
-const OpDef* FunctionLibraryDefinition::LookUp(const string& op,
- Status* status) const {
- auto fdef = Find(op);
- if (fdef != nullptr) {
- return &(fdef->signature());
+Status FunctionLibraryDefinition::LookUp(
+ const string& op, const OpRegistrationData** op_reg_data) const {
+ auto iter = function_defs_.find(op);
+ if (iter != function_defs_.end()) {
+ *op_reg_data = &iter->second->op_registration_data;
+ return Status::OK();
}
- return OpRegistry::Global()->LookUp(op, status);
+ return OpRegistry::Global()->LookUp(op, op_reg_data);
}
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
FunctionDefLibrary lib;
for (const auto& f : function_defs_) {
- *lib.add_function() = f.second;
+ *lib.add_function() = f.second->fdef;
}
for (const auto& g : func_grad_) {
GradientDef* gd = lib.add_gradient();
@@ -845,7 +853,10 @@ FunctionDef FunctionDefHelper::Define(const string& name,
for (const auto& a : arg_def) b.Input(a);
for (const auto& r : ret_def) b.Output(r);
for (const auto& a : attr_def) b.Attr(a);
- TF_CHECK_OK(b.Finalize(fdef.mutable_signature()));
+
+ OpRegistrationData op_reg_data;
+ TF_CHECK_OK(b.Finalize(&op_reg_data));
+ fdef.mutable_signature()->Swap(&op_reg_data.op_def);
for (const auto& n : node_def) {
*(fdef.add_node()) = n.ToProto();
}
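In Define() above, OpDefBuilder::Finalize now fills an OpRegistrationData rather than a bare OpDef, so the signature is moved out of the struct with a proto Swap(). A sketch of that extraction pattern for callers that only need the OpDef:

    OpRegistrationData op_reg_data;
    TF_CHECK_OK(b.Finalize(&op_reg_data));
    OpDef op_def;
    op_def.Swap(&op_reg_data.op_def);  // take the OpDef without a proto copy

A plain copy (op_def = op_reg_data.op_def) works too when the extra copy is acceptable, as the test helpers elsewhere in this change do.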
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 0b05aa4a8d..708861fb9d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -277,14 +277,25 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
//
// If "op" is defined in the library, returns its signature.
// Otherwise, assume "op" is a primitive op and returns its op
- // signature.
- const OpDef* LookUp(const string& op, Status* status) const override;
+ // signature and shape inference function.
+ Status LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const override;
// Returns a proto representation of the state of this function library.
FunctionDefLibrary ToProto() const;
private:
- std::unordered_map<string, FunctionDef> function_defs_;
+ // TODO(cwhipkey): support shape functions in FunctionDefLibrary.
+ struct FunctionDefAndOpRegistration {
+ FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
+ : fdef(fdef_in), op_registration_data(fdef.signature()) {}
+
+ FunctionDef fdef;
+ OpRegistrationData op_registration_data;
+ };
+
+ std::unordered_map<string, std::unique_ptr<FunctionDefAndOpRegistration>>
+ function_defs_;
std::unordered_map<string, string> func_grad_;
};
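Wrapping the map values in std::unique_ptr is what keeps the pointers returned by Find() and LookUp() stable: when the unordered_map rehashes, only the smart pointers move, not the FunctionDefAndOpRegistration objects they own. A sketch of the caller-visible guarantee (the second AddFunctionDef call is illustrative):

    FunctionLibraryDefinition lib_def(proto);
    const FunctionDef* f = lib_def.Find("XTimesTwo");  // points into a map value
    TF_CHECK_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
    // Even if that insertion rehashed the map, f still points at valid
    // memory, because each entry lives behind a unique_ptr on the heap.

The copy constructor change in function.cc follows from the same design: unique_ptr values cannot be copied member-wise, so the copy re-adds each FunctionDef through AddFunctionDef.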
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 5849051062..d7e967ec3d 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -33,9 +33,7 @@ namespace tensorflow {
typedef FunctionDefHelper FDH;
Status GetOpSig(const string& op, const OpDef** sig) {
- Status s;
- *sig = OpRegistry::Global()->LookUp(op, &s);
- return s;
+ return OpRegistry::Global()->LookUpOpDef(op, sig);
}
REGISTER_OP("One")
@@ -643,12 +641,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) {
*proto.add_function() = test::function::XTimesTwo();
FunctionLibraryDefinition lib_def(proto);
- Status s;
- EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr);
+ const OpDef* op_def;
+ EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok());
- auto found = lib_def.LookUp("XTimesTwo", &s);
- ASSERT_NE(found, nullptr);
- EXPECT_EQ(found->DebugString(),
+ TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
+ ASSERT_NE(op_def, nullptr);
+ EXPECT_EQ(op_def->DebugString(),
test::function::XTimesTwo().signature().DebugString());
}
@@ -662,14 +660,15 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
// Test lookup of first function.
- Status s;
- auto first = lib_def.LookUp("XTimesTwo", &s);
+ const OpDef* first;
+ TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first));
ASSERT_NE(first, nullptr);
EXPECT_EQ(first->DebugString(),
test::function::XTimesTwo().signature().DebugString());
// Test lookup of second function.
- auto second = lib_def.LookUp("WXPlusB", &s);
+ const OpDef* second;
+ TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second));
ASSERT_NE(second, nullptr);
EXPECT_EQ(second->DebugString(),
test::function::WXPlusB().signature().DebugString());
@@ -689,18 +688,14 @@ TEST(FunctionLibraryDefinitionTest, ToProto) {
FunctionLibraryDefinition lib_def2(proto2);
// Test that the first function exists in both libraries.
- Status s;
- auto f1 = lib_def1.LookUp("XTimesTwo", &s);
- TF_EXPECT_OK(s);
- auto f2 = lib_def1.LookUp("XTimesTwo", &s);
- TF_EXPECT_OK(s);
+ const OpDef *f1, *f2, *f3, *f4;
+ TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1));
+ TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2));
EXPECT_EQ(f1->DebugString(), f2->DebugString());
// Test that the second function exists in both libraries.
- auto f3 = lib_def1.LookUp("WXPlusB", &s);
- TF_EXPECT_OK(s);
- auto f4 = lib_def1.LookUp("WXPlusB", &s);
- TF_EXPECT_OK(s);
+ TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3));
+ TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4));
EXPECT_EQ(f3->DebugString(), f4->DebugString());
}
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index e1e0729e8d..35b765cc51 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -56,17 +56,14 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
node_offset, " with total nodes in graph: ", graph_def->node_size());
}
- Status s;
for (int i = node_offset; i < graph_def->node_size(); ++i) {
NodeDef* node_def = graph_def->mutable_node(i);
- const OpDef* op_def = op_registry.LookUp(node_def->op(), &s);
- if (!s.ok()) {
- return s;
- }
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
AddDefaultsToNodeDef(*op_def, node_def);
}
- return s;
+ return Status::OK();
}
Status RemoveNewDefaultAttrsFromGraphDef(
@@ -77,12 +74,13 @@ Status RemoveNewDefaultAttrsFromGraphDef(
std::vector<string> to_remove;
for (int n = 0; n < graph_def->node_size(); ++n) {
NodeDef* node_def = graph_def->mutable_node(n);
- const OpDef* producer_op_def =
- producer_op_registry.LookUp(node_def->op(), &s);
- if (!s.ok()) return s;
- const OpDef* consumer_op_def =
- consumer_op_registry.LookUp(node_def->op(), &s);
- if (!s.ok()) return s;
+ const OpDef* producer_op_def;
+ const OpDef* consumer_op_def;
+
+ TF_RETURN_IF_ERROR(
+ producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
+ TF_RETURN_IF_ERROR(
+ consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
for (const auto& attr : node_def->attr()) {
// If the attr is not in consumer_op_def and doesn't start with '_'...
@@ -172,13 +170,12 @@ Status StrippedOpListForGraph(const GraphDef& graph_def,
OpsUsedByGraph(graph_def, &used_ops);
// Build the stripped op list in sorted order, ignoring functions.
- Status status;
stripped_op_list->clear_op();
for (const string& op_name : used_ops) {
- const OpDef* op = op_registry.LookUp(op_name, &status);
- if (!op) return status;
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
OpDef* stripped_op = stripped_op_list->add_op();
- stripped_op->CopyFrom(*op);
+ stripped_op->CopyFrom(*op_def);
RemoveDescriptionsFromOpDef(stripped_op);
}
return Status::OK();
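The rewrites in this file all lean on TF_RETURN_IF_ERROR to replace the hand-rolled "Status s; ... if (!s.ok()) return s;" plumbing. Roughly, the macro expands to the following (a simplified sketch, not the literal definition):

    do {
      const ::tensorflow::Status _status = (expr);
      if (!_status.ok()) return _status;
    } while (0)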
diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc
index b4df868193..98f6e9b89b 100644
--- a/tensorflow/core/framework/graph_def_util_test.cc
+++ b/tensorflow/core/framework/graph_def_util_test.cc
@@ -27,12 +27,19 @@ limitations under the License.
namespace tensorflow {
namespace {
+Status FinalizeOpDef(OpDefBuilder b, OpDef* op_def) {
+ OpRegistrationData op_reg_data;
+ const Status s = b.Finalize(&op_reg_data);
+ *op_def = op_reg_data.op_def;
+ return s;
+}
+
// Producer and consumer have default for an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
OpList op_list;
- TF_ASSERT_OK(OpDefBuilder("NoChangeWithDefault")
- .Attr("a: int = 12")
- .Finalize(op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
+ op_list.add_op()));
OpListOpRegistry registry(&op_list);
GraphDef graph_def;
@@ -51,9 +58,8 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
// Producer and consumer both have an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
OpList op_list;
- TF_ASSERT_OK(OpDefBuilder("NoChangeNoDefault")
- .Attr("a: int")
- .Finalize(op_list.add_op()));
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
+ op_list.add_op()));
OpListOpRegistry registry(&op_list);
GraphDef graph_def;
@@ -75,13 +81,13 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
// attr removed from graph (and so able to be consumed).
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
OpList consumer_op_list;
- TF_ASSERT_OK(OpDefBuilder("UsesDefault").Finalize(consumer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
OpListOpRegistry consumer_registry(&consumer_op_list);
OpList producer_op_list;
- TF_ASSERT_OK(OpDefBuilder("UsesDefault")
- .Attr("a: int = 17")
- .Finalize(producer_op_list.add_op()));
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
OpListOpRegistry producer_registry(&producer_op_list);
GraphDef produced_graph_def;
@@ -107,14 +113,14 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
// graph unchanged (but not able to be consumed by consumer).
TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
OpList consumer_op_list;
- TF_ASSERT_OK(
- OpDefBuilder("ChangedFromDefault").Finalize(consumer_op_list.add_op()));
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
+ consumer_op_list.add_op()));
OpListOpRegistry consumer_registry(&consumer_op_list);
OpList producer_op_list;
- TF_ASSERT_OK(OpDefBuilder("ChangedFromDefault")
- .Attr("a: int = 17")
- .Finalize(producer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
OpListOpRegistry producer_registry(&producer_op_list);
GraphDef produced_graph_def;
@@ -136,11 +142,13 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
// Attrs starting with underscores should not be removed.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
OpList consumer_op_list;
- TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(consumer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
OpListOpRegistry consumer_registry(&consumer_op_list);
OpList producer_op_list;
- TF_ASSERT_OK(OpDefBuilder("Underscore").Finalize(producer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
// Add the _underscore attr manually since OpDefBuilder would complain
OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
attr->set_name("_underscore");
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 776bd5e2fc..d538228494 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -83,13 +83,12 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
MemoryTypeVector* inp_mtypes,
MemoryTypeVector* out_mtypes) {
// Look up the Op registered for this op name.
- Status status;
- const OpDef* op_def = op_registry->LookUp(ndef.op(), &status);
- if (op_def == nullptr) return status;
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def));
// Look up the Kernel registered for this node def.
const KernelDef* kdef = nullptr;
- status =
+ Status status =
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
if (!status.ok() || HasTypeList(*op_def)) {
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index 47b9fa5b9e..9385d1266a 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -38,13 +38,12 @@ void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
node_def_.set_name(name.ToString());
- Status status;
- op_def_ = op_registry->LookUp(op_name.ToString(), &status);
- if (op_def_ == nullptr) {
+ const Status status = op_registry->LookUpOpDef(op_name.ToString(), &op_def_);
+ if (status.ok()) {
+ Initialize();
+ } else {
errors_.push_back(status.error_message());
inputs_specified_ = 0;
- } else {
- Initialize();
}
}
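Reordering the constructor branches puts the success path first; on failure, the builder records the message in errors_ and lets finalization surface it. A sketch of how a caller observes a bad op name (assuming Finalize() reports accumulated errors; "NoSuchOp" is a hypothetical name):

    NodeDefBuilder builder("my_node", "NoSuchOp");  // lookup fails here...
    NodeDef node_def;
    const Status s = builder.Finalize(&node_def);   // ...and is reported here
    CHECK(!s.ok());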
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
index d2b29397cd..861ad9d7d0 100644
--- a/tensorflow/core/framework/node_def_builder_test.cc
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -32,7 +32,9 @@ class NodeDefBuilderTest : public ::testing::Test {
protected:
// Specify an OpDef via an OpDefBuilder.
void Op(const OpDefBuilder& op_def_builder) {
- TF_EXPECT_OK(op_def_builder.Finalize(&op_def_));
+ OpRegistrationData op_reg_data;
+ TF_EXPECT_OK(op_def_builder.Finalize(&op_reg_data));
+ op_def_ = op_reg_data.op_def;
}
// Resets builder_ with a new NodeDefBuilder using the Op from the last call
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index bc02f39d30..b60709c5a4 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -28,9 +28,9 @@ namespace tensorflow {
namespace {
OpDef ToOpDef(const OpDefBuilder& builder) {
- OpDef op_def;
- TF_EXPECT_OK(builder.Finalize(&op_def));
- return op_def;
+ OpRegistrationData op_reg_data;
+ TF_EXPECT_OK(builder.Finalize(&op_reg_data));
+ return op_reg_data.op_def;
}
NodeDef ToNodeDef(const string& text) {
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index bbe0513814..2ea735a790 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -32,38 +32,51 @@ namespace tensorflow {
OpRegistryInterface::~OpRegistryInterface() {}
+Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
+ const OpDef** op_def) const {
+ *op_def = nullptr;
+ const OpRegistrationData* op_reg_data = nullptr;
+ TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
+ *op_def = &op_reg_data->op_def;
+ return Status::OK();
+}
+
OpRegistry::OpRegistry() : initialized_(false) {}
-void OpRegistry::Register(const OpDef& op_def) {
+OpRegistry::~OpRegistry() {
+ for (const auto& e : registry_) delete e.second;
+}
+
+void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) {
+ OpRegistrationData* raw_ptr = op_reg_data.get();
+
mutex_lock lock(mu_);
if (initialized_) {
- TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: "
- << SummarizeOpDef(op_def);
+ TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data)));
} else {
- deferred_.push_back(op_def);
+ deferred_.push_back(std::move(op_reg_data));
}
if (watcher_) {
- watcher_(op_def);
+ watcher_(raw_ptr->op_def);
}
}
-const OpDef* OpRegistry::LookUp(const string& op_type_name,
- Status* status) const {
- const OpDef* op_def = nullptr;
+Status OpRegistry::LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const {
+ *op_reg_data = nullptr;
+ const OpRegistrationData* res = nullptr;
+
bool first_call = false;
{ // Scope for lock.
mutex_lock lock(mu_);
first_call = CallDeferred();
- op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr);
+ res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
// Note: Can't hold mu_ while calling Export() below.
}
if (first_call) {
TF_QCHECK_OK(ValidateKernelRegistrations(*this));
}
- if (op_def == nullptr) {
- status->Update(
- errors::NotFound("Op type not registered '", op_type_name, "'"));
- VLOG(1) << status->ToString();
+ if (res == nullptr) {
static bool first_unregistered = true;
if (first_unregistered) {
OpList op_list;
@@ -74,15 +87,20 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name,
}
first_unregistered = false;
}
+ Status status =
+ errors::NotFound("Op type not registered '", op_type_name, "'");
+ VLOG(1) << status.ToString();
+ return status;
}
- return op_def;
+ *op_reg_data = res;
+ return Status::OK();
}
void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
mutex_lock lock(mu_);
CallDeferred();
- for (auto p : registry_) {
- op_defs->push_back(*p.second);
+ for (const auto& p : registry_) {
+ op_defs->push_back(p.second->op_def);
}
}
@@ -100,8 +118,8 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
mutex_lock lock(mu_);
CallDeferred();
- std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(),
- registry_.end());
+ std::vector<std::pair<string, const OpRegistrationData*>> sorted(
+ registry_.begin(), registry_.end());
std::sort(sorted.begin(), sorted.end());
auto out = ops->mutable_op();
@@ -110,7 +128,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
for (const auto& item : sorted) {
if (include_internal || !StringPiece(item.first).starts_with("_")) {
- *out->Add() = *item.second;
+ *out->Add() = item.second->op_def;
}
}
}
@@ -128,23 +146,23 @@ string OpRegistry::DebugString(bool include_internal) const {
bool OpRegistry::CallDeferred() const {
if (initialized_) return false;
initialized_ = true;
- for (const auto& op_def : deferred_) {
- TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: "
- << SummarizeOpDef(op_def);
+ for (int i = 0; i < deferred_.size(); ++i) {
+ TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i])));
}
deferred_.clear();
return true;
}
-Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const {
- TF_RETURN_IF_ERROR(ValidateOpDef(def));
+Status OpRegistry::RegisterAlreadyLocked(
+ std::unique_ptr<OpRegistrationData> op_reg_data) const {
+ TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def));
- std::unique_ptr<OpDef> copy(new OpDef(def));
- if (gtl::InsertIfNotPresent(&registry_, def.name(), copy.get())) {
- copy.release(); // Ownership transferred to op_registry
+ if (gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
+ op_reg_data.get())) {
+ op_reg_data.release(); // Ownership transferred to op_registry
return Status::OK();
} else {
- return errors::AlreadyExists("Op with name ", def.name());
+ return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
}
}
@@ -158,19 +176,25 @@ OpRegistry* OpRegistry::Global() {
OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
for (const OpDef& op_def : op_list->op()) {
- index_[op_def.name()] = &op_def;
+ auto* op_reg_data = new OpRegistrationData();
+ op_reg_data->op_def = op_def;
+ index_[op_def.name()] = op_reg_data;
}
}
-const OpDef* OpListOpRegistry::LookUp(const string& op_type_name,
- Status* status) const {
+OpListOpRegistry::~OpListOpRegistry() {
+ for (const auto& e : index_) delete e.second;
+}
+
+Status OpListOpRegistry::LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const {
auto iter = index_.find(op_type_name);
if (iter == index_.end()) {
- status->Update(
- errors::NotFound("Op type not registered '", op_type_name, "'"));
- return nullptr;
+ *op_reg_data = nullptr;
+ return errors::NotFound("Op type not registered '", op_type_name, "'");
}
- return iter->second;
+ *op_reg_data = iter->second;
+ return Status::OK();
}
// Other registration ---------------------------------------------------------
@@ -178,9 +202,9 @@ const OpDef* OpListOpRegistry::LookUp(const string& op_type_name,
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
const OpDefBuilderWrapper<true>& wrapper) {
- OpDef op_def;
- wrapper.builder().Finalize(&op_def);
- OpRegistry::Global()->Register(op_def);
+ std::unique_ptr<OpRegistrationData> data(new OpRegistrationData);
+ wrapper.builder().Finalize(data.get());
+ OpRegistry::Global()->Register(std::move(data));
}
} // namespace register_op
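Both registries now own heap-allocated OpRegistrationData values and free them in their destructors. The struct itself is defined in op_def_builder.h, whose diff is not shown in this section; judging only from the members touched here (op_def, a constructor from an OpDef, and the SetShapeFn hook in op.h), an approximate sketch of its shape is:

    // Hypothetical sketch, inferred from usage in this diff only.
    struct OpRegistrationData {
      OpRegistrationData() {}
      OpRegistrationData(const OpDef& def) : op_def(def) {}

      OpDef op_def;
      OpShapeInferenceFn shape_inference_fn;  // installed via SetShapeFn
    };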
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 2a99597721..cd484dacd4 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/selective_registration.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -42,11 +43,14 @@ class OpRegistryInterface {
public:
virtual ~OpRegistryInterface();
- // Returns nullptr and sets *status if no OpDef is registered under that
- // name, otherwise returns the registered OpDef.
+ // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
+ // registered under that name, otherwise returns the registered OpDef.
// Caller must not delete the returned pointer.
- virtual const OpDef* LookUp(const string& op_type_name,
- Status* status) const = 0;
+ virtual Status LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const = 0;
+
+ // Shorthand for calling LookUp to get the OpDef.
+ Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
};
// The standard implementation of OpRegistryInterface, along with a
@@ -62,17 +66,13 @@ class OpRegistryInterface {
class OpRegistry : public OpRegistryInterface {
public:
OpRegistry();
- ~OpRegistry() override {}
+ ~OpRegistry() override;
- // Calls func() and registers the returned OpDef. Since Register()
- // is normally called during program initialization (before main()),
- // we defer calling func() until the first call to LookUp() or
- // Export() (if one of those has already been called, func() is
- // called immediately).
- void Register(const OpDef& op_def);
+ // Calls watcher and registers the passed OpDef.
+ void Register(std::unique_ptr<OpRegistrationData> op_data);
- const OpDef* LookUp(const string& op_type_name,
- Status* status) const override;
+ Status LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const override;
// Fills *ops with all registered OpDefs (except those with names
// starting with '_' if include_internal == false).
@@ -111,15 +111,19 @@ class OpRegistry : public OpRegistryInterface {
// time it is called.
bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
- // Add 'def' to the registry. On failure, or if there is already an
- // OpDef with that name registered, returns a non-okay status.
- Status RegisterAlreadyLocked(const OpDef& def) const
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Add 'def' to the registry with additional data 'data'. On failure, or if
+ // there is already an OpDef with that name registered, returns a non-okay
+ // status.
+ Status RegisterAlreadyLocked(std::unique_ptr<OpRegistrationData> op_data)
+ const EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
- mutable std::vector<OpDef> deferred_ GUARDED_BY(mu_);
- mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_);
+ mutable std::vector<std::unique_ptr<OpRegistrationData>> deferred_
+ GUARDED_BY(mu_);
+ // Values are owned.
+ mutable std::unordered_map<string, const OpRegistrationData*> registry_
+ GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);
// Registry watcher.
@@ -127,16 +131,21 @@ class OpRegistry : public OpRegistryInterface {
};
// An adapter to allow an OpList to be used as an OpRegistryInterface.
+//
+// Note that shape inference functions are not passed in to OpListOpRegistry, so
+// it will return an unusable shape inference function for every op it supports;
+// therefore, it should only be used in contexts where this is okay.
class OpListOpRegistry : public OpRegistryInterface {
public:
// Does not take ownership of op_list, *op_list must outlive *this.
OpListOpRegistry(const OpList* op_list);
- ~OpListOpRegistry() override {}
- const OpDef* LookUp(const string& op_type_name,
- Status* status) const override;
+ ~OpListOpRegistry() override;
+ Status LookUp(const string& op_type_name,
+ const OpRegistrationData** op_reg_data) const override;
private:
- std::unordered_map<string, const OpDef*> index_;
+ // Values are owned.
+ std::unordered_map<string, const OpRegistrationData*> index_;
};
// Treats 'registry_ptr' as a pointer to OpRegistry, and calls
@@ -217,6 +226,10 @@ class OpDefBuilderWrapper<true> {
builder_.Doc(text);
return *this;
}
+ OpDefBuilderWrapper<true>& SetShapeFn(const OpShapeInferenceFn& fn) {
+ builder_.SetShapeFn(fn);
+ return *this;
+ }
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
private:
@@ -237,6 +250,9 @@ class OpDefBuilderWrapper<false> {
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
+ OpDefBuilderWrapper<false>& SetShapeFn(const OpShapeInferenceFn& fn) {
+ return *this;
+ }
};
struct OpDefBuilderReceiver {
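With SetShapeFn forwarded by the registration wrapper, an op can declare its shape function inline at REGISTER_OP time. A hedged sketch of such a registration (the op "ExampleIdentity" is invented for illustration, and the InferenceContext calls are assumed from shape_inference.h, which this change also touches):

    REGISTER_OP("ExampleIdentity")
        .Input("x: float")
        .Output("y: float")
        .SetShapeFn([](shape_inference::InferenceContext* c) {
          c->set_output(0, c->input(0));  // output shape follows input shape
          return Status::OK();
        });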
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index 9cb1ead451..fd6a988563 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -40,11 +40,9 @@ class TestKernel : public OpKernel {
class OpCompatibilityTest : public OpsTestBase {
protected:
const OpDef* RegisteredOpDef() {
- Status status;
- const OpDef* new_op_def =
- OpRegistry::Global()->LookUp(node_def()->op(), &status);
- TF_CHECK_OK(status);
- return new_op_def;
+ const OpDef* op_def;
+ TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef(node_def()->op(), &op_def));
+ return op_def;
}
void ExpectSuccess(const OpDef& old_op_def) {
@@ -170,11 +168,11 @@ REGISTER_OP("AddAttr").Output("ndef: string").Attr("a: int = 42");
REGISTER_KERNEL_BUILDER(Name("AddAttr").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AddAttr) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op_def).Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op.op_def).Finalize(node_def()));
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_attr = AddAttr[a=42]()", Result());
}
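Every test that follows repeats the same mechanical migration as AddAttr; condensed, the pattern is ("SomeOp" and "n" are illustrative):

    // Before: Finalize() wrote directly into an OpDef.
    OpDef old_op_def;
    TF_ASSERT_OK(OpDefBuilder("SomeOp").Finalize(&old_op_def));
    TF_ASSERT_OK(NodeDefBuilder("n", &old_op_def).Finalize(node_def()));

    // After: Finalize() targets an OpRegistrationData; the OpDef is a member.
    OpRegistrationData old_op;
    TF_ASSERT_OK(OpDefBuilder("SomeOp").Finalize(&old_op));
    TF_ASSERT_OK(NodeDefBuilder("n", &old_op.op_def).Finalize(node_def()));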
@@ -183,15 +181,15 @@ REGISTER_OP("LessStrict").Output("ndef: string").Attr("a: {'A', 'B', 'C'}");
REGISTER_KERNEL_BUILDER(Name("LessStrict").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, LessStrict) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("LessStrict")
.Output("ndef: string")
.Attr("a: {'A', 'B'}")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("less_strict", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("less_strict", &old_op.op_def)
.Attr("a", "B")
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("less_strict = LessStrict[a=\"B\"]()", Result());
}
@@ -201,15 +199,15 @@ REGISTER_KERNEL_BUILDER(Name("RemoveRestriction").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, RemoveRestriction) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("RemoveRestriction")
.Output("ndef: string")
.Attr("a: {int32, bool}")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def)
.Attr("a", DT_INT32)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("remove_restriction = RemoveRestriction[a=DT_INT32]()", Result());
}
@@ -218,17 +216,17 @@ REGISTER_OP("AttrOrder").Output("ndef: string").Attr("a: int").Attr("b: bool");
REGISTER_KERNEL_BUILDER(Name("AttrOrder").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AttrOrder) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrOrder")
.Output("ndef: string")
.Attr("b: bool")
.Attr("a: int")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("attr_order", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("attr_order", &old_op.op_def)
.Attr("b", true)
.Attr("a", 7)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result());
}
@@ -237,15 +235,15 @@ REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234");
REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AddDefault) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AddDefault")
.Output("ndef: string")
.Attr("a: int")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def)
.Attr("a", 765)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_default = AddDefault[a=765]()", Result());
}
@@ -255,14 +253,14 @@ REGISTER_OP("RemoveDefault").Output("ndef: string").Attr("a: int");
REGISTER_KERNEL_BUILDER(Name("RemoveDefault").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, RemoveDefault) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("RemoveDefault")
.Output("ndef: string")
.Attr("a: int = 91")
- .Finalize(&old_op_def));
+ .Finalize(&old_op));
TF_ASSERT_OK(
- NodeDefBuilder("remove_default", &old_op_def).Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def()));
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("remove_default = RemoveDefault[a=91]()", Result());
}
@@ -275,15 +273,15 @@ REGISTER_OP("TypePolymorphic")
REGISTER_KERNEL_BUILDER(Name("TypePolymorphic").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, TypePolymorphic) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("TypePolymorphic")
.Input("a: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("type_polymorphic = TypePolymorphic[T=DT_INT32](a)", Result());
}
@@ -296,15 +294,15 @@ REGISTER_OP("MakeList")
REGISTER_KERNEL_BUILDER(Name("MakeList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, MakeList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("MakeList")
.Input("a: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("make_list = MakeList[N=1](a)", Result());
}
@@ -319,15 +317,15 @@ REGISTER_OP("MakePolyList")
REGISTER_KERNEL_BUILDER(Name("MakePolyList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, MakePolyList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("MakePolyList")
.Input("a: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("make_poly_list = MakePolyList[N=1, T=DT_INT32](a)", Result());
}
@@ -340,15 +338,15 @@ REGISTER_OP("MakeAnyList")
REGISTER_KERNEL_BUILDER(Name("MakeAnyList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, MakeAnyList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("MakeAnyList")
.Input("a: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("make_any_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("make_any_list", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("make_any_list = MakeAnyList[T=[DT_INT32]](a)", Result());
}
@@ -362,16 +360,16 @@ REGISTER_OP("PolyIntoList")
REGISTER_KERNEL_BUILDER(Name("PolyIntoList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, PolyIntoList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("PolyIntoList")
.Input("a: T")
.Output("ndef: string")
.Attr("T: type")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op.op_def)
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result());
}
@@ -387,17 +385,17 @@ REGISTER_KERNEL_BUILDER(Name("MakeMultipleSameList").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, MakeMultipleSameList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("MakeMultipleSameList")
.Input("a: int32")
.Input("b: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def)
.Input(FakeInput())
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("make_list = MakeMultipleSameList[N=2](a, b)", Result());
}
@@ -411,17 +409,17 @@ REGISTER_KERNEL_BUILDER(Name("MakeMultipleAnyList").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, MakeMultipleAnyList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("MakeMultipleAnyList")
.Input("a: int32")
.Input("b: float")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def)
.Input(FakeInput())
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("make_list = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)",
Result());
}
@@ -431,15 +429,15 @@ REGISTER_OP("ChangeName").Input("y: int32").Output("ndef: string");
REGISTER_KERNEL_BUILDER(Name("ChangeName").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, ChangeName) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("ChangeName")
.Input("x: int32")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("change_name", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("change_name", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("change_name = ChangeName[](a)", Result());
}
@@ -452,11 +450,12 @@ REGISTER_OP("AddNInts")
REGISTER_KERNEL_BUILDER(Name("AddNInts").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AddNInts) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_n_ints", &old_op_def).Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op));
+ TF_ASSERT_OK(
+ NodeDefBuilder("add_n_ints", &old_op.op_def).Finalize(node_def()));
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_n_ints = AddNInts[N=0]()", Result());
}
@@ -470,11 +469,12 @@ REGISTER_OP("AddNSame")
REGISTER_KERNEL_BUILDER(Name("AddNSame").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AddNSame) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(
+ OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op));
TF_ASSERT_OK(
- OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_n_same", &old_op_def).Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ NodeDefBuilder("add_n_same", &old_op.op_def).Finalize(node_def()));
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_n_same = AddNSame[N=0, T=DT_BOOL]()", Result());
}
@@ -490,16 +490,16 @@ REGISTER_KERNEL_BUILDER(Name("AddNSameAsExisting").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, AddNSameAsExisting) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AddNSameAsExisting")
.Input("a: T")
.Output("ndef: string")
.Attr("T: type")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op.op_def)
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_n_same_as_existing = AddNSameAsExisting[N=0, T=DT_STRING](a)",
Result());
}
@@ -513,12 +513,12 @@ REGISTER_OP("AddAnyList")
REGISTER_KERNEL_BUILDER(Name("AddAnyList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AddAnyList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op_def));
+ OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op));
TF_ASSERT_OK(
- NodeDefBuilder("add_any_list", &old_op_def).Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ NodeDefBuilder("add_any_list", &old_op.op_def).Finalize(node_def()));
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("add_any_list = AddAnyList[T=[]]()", Result());
}
@@ -530,16 +530,16 @@ REGISTER_OP("ShorterAnyList")
REGISTER_KERNEL_BUILDER(Name("ShorterAnyList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, ShorterAnyList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("ShorterAnyList")
.Input("a: T")
.Output("ndef: string")
.Attr("T: list(type) >= 2")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op.op_def)
.Input(FakeInput(2, DT_BOOL))
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("shorter_any_list = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, a:1)",
Result());
}
@@ -551,16 +551,16 @@ REGISTER_OP("ShorterSameList")
REGISTER_KERNEL_BUILDER(Name("ShorterSameList").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, ShorterSameList) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("ShorterSameList")
.Input("a: N * int32")
.Output("ndef: string")
.Attr("N: int >= 2")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op.op_def)
.Input(FakeInput(2))
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("shorter_same_list = ShorterSameList[N=2](a, a:1)", Result());
}
@@ -571,15 +571,15 @@ REGISTER_KERNEL_BUILDER(Name("AttrRemoveRestriction").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, AttrRemoveRestriction) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrRemoveRestriction")
.Attr("t: {int32,int64}")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def)
.Attr("t", DT_INT32)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("remove_restriction = AttrRemoveRestriction[t=DT_INT32]()",
Result());
}
@@ -593,15 +593,15 @@ REGISTER_KERNEL_BUILDER(Name("AttrLessRestrictive").Device(DEVICE_CPU),
TestKernel);
TEST_F(OpCompatibilityTest, AttrLessRestrictive) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrLessRestrictive")
.Attr("t: {int32, int64}")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("less_restrictive", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("less_restrictive", &old_op.op_def)
.Attr("t", DT_INT32)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("less_restrictive = AttrLessRestrictive[t=DT_INT32]()", Result());
}
@@ -611,15 +611,15 @@ REGISTER_OP("AttrRemoveMin").Attr("n: int").Output("ndef: string");
REGISTER_KERNEL_BUILDER(Name("AttrRemoveMin").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AttrRemoveMin) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrRemoveMin")
.Attr("n: int >= 3")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("remove_min", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("remove_min", &old_op.op_def)
.Attr("n", 4)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("remove_min = AttrRemoveMin[n=4]()", Result());
}
@@ -629,15 +629,15 @@ REGISTER_OP("AttrLowerMin").Attr("n: int >= 1").Output("ndef: string");
REGISTER_KERNEL_BUILDER(Name("AttrLowerMin").Device(DEVICE_CPU), TestKernel);
TEST_F(OpCompatibilityTest, AttrLowerMin) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrLowerMin")
.Attr("n: int >= 3")
.Output("ndef: string")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("lower_min", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("lower_min", &old_op.op_def)
.Attr("n", 4)
.Finalize(node_def()));
- ExpectSuccess(old_op_def);
+ ExpectSuccess(old_op.op_def);
EXPECT_EQ("lower_min = AttrLowerMin[n=4]()", Result());
}
@@ -647,11 +647,12 @@ TEST_F(OpCompatibilityTest, AttrLowerMin) {
REGISTER_OP("RemoveAttr");
TEST_F(OpCompatibilityTest, RemoveAttrFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op_def));
- TF_ASSERT_OK(
- NodeDefBuilder("fails", &old_op_def).Attr("a", 3).Finalize(node_def()));
- ExpectInvalid(old_op_def, "NodeDef mentions attr 'a' not in",
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
+ .Attr("a", 3)
+ .Finalize(node_def()));
+ ExpectInvalid(old_op.op_def, "NodeDef mentions attr 'a' not in",
"Attr 'a' removed");
}
@@ -659,10 +660,10 @@ TEST_F(OpCompatibilityTest, RemoveAttrFails) {
REGISTER_OP("AddAttrNoDefault").Attr("a: int");
TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def, "NodeDef missing attr 'a'",
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def()));
+ ExpectInvalid(old_op.op_def, "NodeDef missing attr 'a'",
"Attr 'a' added without default");
}
@@ -670,10 +671,10 @@ TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) {
REGISTER_OP("AddSingleInput").Input("a: int32");
TEST_F(OpCompatibilityTest, AddSingleInputFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
"expected inputs 'int32' do not match 0 inputs specified",
"Input signature mismatch '' vs. 'int32'");
}
@@ -690,28 +691,28 @@ REGISTER_OP("AddListBigDefault")
.Attr("T: list(type) = [DT_INT32]");
TEST_F(OpCompatibilityTest, AddNIntsBigDefaultFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
"expected inputs 'int32' do not match 0 inputs specified",
"Input signature mismatch '' vs. 'int32'");
}
TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
"expected inputs 'int32' do not match 0 inputs specified",
"Input signature mismatch '' vs. 'int32'");
}
TEST_F(OpCompatibilityTest, AddListBigDefaultFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
"expected inputs 'int32' do not match 0 inputs specified",
"Input signature mismatch '' vs. 'int32'");
}
@@ -721,13 +722,12 @@ TEST_F(OpCompatibilityTest, AddListBigDefaultFails) {
REGISTER_OP("ChangeType").Input("a: float");
TEST_F(OpCompatibilityTest, ChangeTypeFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(
- OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectTypeMismatch(old_op_def,
+ ExpectTypeMismatch(old_op.op_def,
"Input signature mismatch 'int32' vs. 'float'");
}
@@ -736,17 +736,18 @@ TEST_F(OpCompatibilityTest, ChangeTypeFails) {
REGISTER_OP("ChangeOrder").Input("a: float").Input("b: int32");
TEST_F(OpCompatibilityTest, ChangeOrderFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("ChangeOrder")
.Input("b: int32")
.Input("a: float")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Input(FakeInput())
.Input(FakeInput())
.Finalize(node_def()));
ExpectTypeMismatch(
- old_op_def, "Input signature mismatch 'int32, float' vs. 'float, int32'");
+ old_op.op_def,
+ "Input signature mismatch 'int32, float' vs. 'float, int32'");
}
// Can't remove inputs/outputs.
@@ -754,13 +755,12 @@ TEST_F(OpCompatibilityTest, ChangeOrderFails) {
REGISTER_OP("RemoveInput");
TEST_F(OpCompatibilityTest, RemoveInputFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(
- OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Input(FakeInput())
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ ExpectInvalid(old_op.op_def,
"expected inputs '' do not match 1 inputs specified",
"Input signature mismatch 'float' vs. ''");
}
@@ -770,13 +770,13 @@ TEST_F(OpCompatibilityTest, RemoveInputFails) {
REGISTER_OP("ChangeAttrType").Attr("a: int");
TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Attr("a", true)
.Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'bool' when 'int' expected",
+ ExpectInvalid(old_op.op_def, "value with type 'bool' when 'int' expected",
"Attr 'a' changed type 'bool' -> 'int'");
}
@@ -785,12 +785,14 @@ TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) {
REGISTER_OP("AttrFromList").Attr("a: int");
TEST_F(OpCompatibilityTest, AttrFromListFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(
- OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op_def));
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- NodeDefBuilder("fails", &old_op_def).Attr("a", {5}).Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'list(int)' when 'int' expected",
+ OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
+ .Attr("a", {5})
+ .Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
+ "value with type 'list(int)' when 'int' expected",
"Attr 'a' changed type 'list(int)' -> 'int'");
}
@@ -799,11 +801,13 @@ TEST_F(OpCompatibilityTest, AttrFromListFails) {
REGISTER_OP("AttrToList").Attr("a: list(int)");
TEST_F(OpCompatibilityTest, AttrToListFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op_def));
- TF_ASSERT_OK(
- NodeDefBuilder("fails", &old_op_def).Attr("a", 5).Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'int' when 'list(int)' expected",
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
+ .Attr("a", 5)
+ .Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
+ "value with type 'int' when 'list(int)' expected",
"Attr 'a' changed type 'int' -> 'list(int)'");
}
@@ -812,15 +816,16 @@ TEST_F(OpCompatibilityTest, AttrToListFails) {
REGISTER_OP("PolymorphicToAnyList").Input("a: T").Attr("T: list(type)");
TEST_F(OpCompatibilityTest, PolymorphicToAnyListFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("PolymorphicToAnyList")
.Input("a: T")
.Attr("T: type")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected",
+ ExpectInvalid(old_op.op_def,
+ "value with type 'type' when 'list(type)' expected",
"Attr 'T' changed type 'type' -> 'list(type)'");
}
@@ -832,16 +837,17 @@ REGISTER_OP("SameToAnyList")
.Attr("N: int = 1");
TEST_F(OpCompatibilityTest, SameToAnyListFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("SameToAnyList")
.Input("a: N * T")
.Attr("T: type")
.Attr("N: int")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def)
.Input(FakeInput(1, DT_INT32))
.Finalize(node_def()));
- ExpectInvalid(old_op_def, "value with type 'type' when 'list(type)' expected",
+ ExpectInvalid(old_op.op_def,
+ "value with type 'type' when 'list(type)' expected",
"Attr 'T' changed type 'type' -> 'list(type)'");
}
@@ -850,13 +856,13 @@ TEST_F(OpCompatibilityTest, SameToAnyListFails) {
REGISTER_OP("AttrAddRestriction").Attr("t: {int32, int64}");
TEST_F(OpCompatibilityTest, AttrAddRestrictionFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("AttrAddRestriction").Attr("t: type").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("add_restriction", &old_op_def)
+ OpDefBuilder("AttrAddRestriction").Attr("t: type").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("add_restriction", &old_op.op_def)
.Attr("t", DT_BOOL)
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ ExpectInvalid(old_op.op_def,
"Value for attr 't' of bool is not in the list of allowed "
"values: int32, int64",
"Attr 't' has a stricter set of allowed values; from "
@@ -868,14 +874,14 @@ TEST_F(OpCompatibilityTest, AttrAddRestrictionFails) {
REGISTER_OP("AttrMoreRestrictive").Attr("t: {int32, int64}");
TEST_F(OpCompatibilityTest, AttrMoreRestrictiveFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(OpDefBuilder("AttrMoreRestrictive")
.Attr("t: {int32, int64, bool}")
- .Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("more_restrictive", &old_op_def)
+ .Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("more_restrictive", &old_op.op_def)
.Attr("t", DT_BOOL)
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ ExpectInvalid(old_op.op_def,
"Value for attr 't' of bool is not in the list of allowed "
"values: int32, int64",
"Attr 't' has a stricter set of allowed values; from "
@@ -887,11 +893,12 @@ TEST_F(OpCompatibilityTest, AttrMoreRestrictiveFails) {
REGISTER_OP("AttrAddMin").Attr("n: int >= 3");
TEST_F(OpCompatibilityTest, AttrAddMinFails) {
- OpDef old_op_def;
- TF_ASSERT_OK(OpDefBuilder("AttrAddMin").Attr("n: int").Finalize(&old_op_def));
- TF_ASSERT_OK(
- NodeDefBuilder("add_min", &old_op_def).Attr("n", 2).Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ OpRegistrationData old_op;
+ TF_ASSERT_OK(OpDefBuilder("AttrAddMin").Attr("n: int").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("add_min", &old_op.op_def)
+ .Attr("n", 2)
+ .Finalize(node_def()));
+ ExpectInvalid(old_op.op_def,
"Value for attr 'n' of 2 must be at least minimum 3",
"Attr 'n' has a higher minimum; from no minimum to 3");
}
@@ -901,13 +908,13 @@ TEST_F(OpCompatibilityTest, AttrAddMinFails) {
REGISTER_OP("AttrRaiseMin").Attr("n: int >= 3");
TEST_F(OpCompatibilityTest, AttrRaiseMinFails) {
- OpDef old_op_def;
+ OpRegistrationData old_op;
TF_ASSERT_OK(
- OpDefBuilder("AttrRaiseMin").Attr("n: int >= 1").Finalize(&old_op_def));
- TF_ASSERT_OK(NodeDefBuilder("raise_min", &old_op_def)
+ OpDefBuilder("AttrRaiseMin").Attr("n: int >= 1").Finalize(&old_op));
+ TF_ASSERT_OK(NodeDefBuilder("raise_min", &old_op.op_def)
.Attr("n", 2)
.Finalize(node_def()));
- ExpectInvalid(old_op_def,
+ ExpectInvalid(old_op.op_def,
"Value for attr 'n' of 2 must be at least minimum 3",
"Attr 'n' has a higher minimum; from 1 to 3");
}
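
Note: every op_compatibility_test.cc change above is the same mechanical
migration: OpDefBuilder::Finalize() now fills an OpRegistrationData instead
of an OpDef, and callers read the OpDef from its op_def field. A minimal
sketch of the pattern ("MyOp" is a placeholder op name):

    // Before: Finalize() wrote directly into an OpDef.
    OpDef old_op_def;
    TF_ASSERT_OK(OpDefBuilder("MyOp").Input("a: float").Finalize(&old_op_def));

    // After: Finalize() fills an OpRegistrationData, which bundles the
    // OpDef with an optional shape inference function.
    OpRegistrationData old_op;
    TF_ASSERT_OK(OpDefBuilder("MyOp").Input("a: float").Finalize(&old_op));
    const OpDef& op_def = old_op.op_def;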
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 2ec8b22458..7e021f7b64 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -491,7 +491,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def_.set_name(op_name.ToString()); // NOLINT
+ op_def()->set_name(op_name.ToString()); // NOLINT
}
OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
@@ -513,7 +513,7 @@ OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
if (!doc_.empty()) {
errors_.push_back(
- strings::StrCat("Extra call to Doc() for Op ", op_def_.name()));
+ strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
doc_.assign(text.data(), text.size());
}
@@ -522,41 +522,52 @@ OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
#endif
OpDefBuilder& OpDefBuilder::SetIsCommutative() {
- op_def_.set_is_commutative(true);
+ op_def()->set_is_commutative(true);
return *this;
}
OpDefBuilder& OpDefBuilder::SetIsAggregate() {
- op_def_.set_is_aggregate(true);
+ op_def()->set_is_aggregate(true);
return *this;
}
OpDefBuilder& OpDefBuilder::SetIsStateful() {
- op_def_.set_is_stateful(true);
+ op_def()->set_is_stateful(true);
return *this;
}
OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
- op_def_.set_allows_uninitialized_input(true);
+ op_def()->set_allows_uninitialized_input(true);
return *this;
}
OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
- if (op_def_.has_deprecation()) {
+ if (op_def()->has_deprecation()) {
errors_.push_back(
- strings::StrCat("Deprecated called twice for Op ", op_def_.name()));
+ strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
- OpDeprecation* deprecation = op_def_.mutable_deprecation();
+ OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
deprecation->set_explanation(explanation.ToString());
}
return *this;
}
-Status OpDefBuilder::Finalize(OpDef* op_def) const {
+OpDefBuilder& OpDefBuilder::SetShapeFn(const OpShapeInferenceFn& fn) {
+ if (op_reg_data_.shape_inference_fn != nullptr) {
+ errors_.push_back(
+ strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
+ } else {
+ op_reg_data_.shape_inference_fn = fn;
+ }
+ return *this;
+}
+
+Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
std::vector<string> errors = errors_;
- *op_def = op_def_;
+ *op_reg_data = op_reg_data_;
+ OpDef* op_def = &op_reg_data->op_def;
for (StringPiece attr : attrs_) {
FinalizeAttr(attr, op_def, &errors);
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 4ca5927d8a..14950eff28 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Class and associated machinery for specifying an Op's OpDef for Op
-// registration.
+// Class and associated machinery for specifying an Op's OpDef and shape
+// inference function for Op registration.
#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
@@ -24,9 +24,25 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+namespace shape_inference {
+class InferenceContext;
+}
+typedef std::function<Status(shape_inference::InferenceContext* c)>
+ OpShapeInferenceFn;
+
+struct OpRegistrationData {
+ public:
+ OpRegistrationData() {}
+ OpRegistrationData(const OpDef& def) : op_def(def) {}
+
+ OpDef op_def;
+ OpShapeInferenceFn shape_inference_fn;
+};
+
// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
public:
@@ -111,14 +127,21 @@ class OpDefBuilder {
OpDefBuilder& Doc(StringPiece text) { return *this; }
#endif
- // Sets *op_def to the requested OpDef, or returns an error.
+ OpDefBuilder& SetShapeFn(const OpShapeInferenceFn& fn);
+
+ // Sets op_reg_data->op_def to the requested OpDef and
+ // op_reg_data->shape_inference_fn to the requested shape inference function,
+ // or returns an error.
// Must be called after all of the above methods.
+ //
// Note that OpDefBuilder only reports parsing errors. You should also
// call ValidateOpDef() to detect other problems.
- Status Finalize(OpDef* op_def) const;
+ Status Finalize(OpRegistrationData* op_reg_data) const;
private:
- OpDef op_def_;
+ OpDef* op_def() { return &op_reg_data_.op_def; }
+
+ OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
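
With this change, a REGISTER_OP() call can attach a shape inference function
via SetShapeFn(). A minimal sketch, assuming a hypothetical "ExampleOp" whose
output shape is simply its input shape (see math_ops.cc below for the real
AddN registration):

    REGISTER_OP("ExampleOp")
        .Input("a: float")
        .Output("out: float")
        .SetShapeFn(OpShapeInferenceFn(
            [](shape_inference::InferenceContext* c) {
              // Pass the (possibly unknown) input shape through unchanged.
              c->set_output(0, c->input(0));
              return Status::OK();
            }));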
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
index ff986746ec..15f9c9671f 100644
--- a/tensorflow/core/framework/op_def_builder_test.cc
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -38,10 +39,12 @@ class OpDefBuilderTest : public ::testing::Test {
protected:
OpDefBuilder b() { return OpDefBuilder("Test"); }
- void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) {
- OpDef op_def;
- Status status = builder.Finalize(&op_def);
+ void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto,
+ OpShapeInferenceFn* shape_fn_out = nullptr) {
+ OpRegistrationData op_reg_data;
+ Status status = builder.Finalize(&op_reg_data);
TF_EXPECT_OK(status);
+ OpDef& op_def = op_reg_data.op_def;
if (status.ok()) {
OpDef expected;
protobuf::TextFormat::ParseFromString(
@@ -50,13 +53,18 @@ class OpDefBuilderTest : public ::testing::Test {
CanonicalizeAttrTypeListOrder(&op_def);
CanonicalizeAttrTypeListOrder(&expected);
EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
+
+ if (shape_fn_out) {
+ *shape_fn_out = op_reg_data.shape_inference_fn;
+ }
}
}
void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) {
- OpDef op_def;
- Status status = builder.Finalize(&op_def);
+ OpRegistrationData op_reg_data;
+ Status status = builder.Finalize(&op_reg_data);
TF_EXPECT_OK(status);
+ OpDef& op_def = op_reg_data.op_def;
if (status.ok()) {
OpDef expected;
protobuf::TextFormat::ParseFromString(
@@ -66,8 +74,8 @@ class OpDefBuilderTest : public ::testing::Test {
}
void ExpectFailure(const OpDefBuilder& builder, string error) {
- OpDef op_def;
- Status status = builder.Finalize(&op_def);
+ OpRegistrationData op_reg_data;
+ Status status = builder.Finalize(&op_reg_data);
EXPECT_FALSE(status.ok());
if (!status.ok()) {
EXPECT_EQ(status.error_message(), error);
@@ -573,5 +581,26 @@ attr {
)proto");
}
+TEST_F(OpDefBuilderTest, SetShapeFn) {
+ auto fn = OpShapeInferenceFn([](shape_inference::InferenceContext* c) {
+ return errors::Unknown("ShapeFn was called");
+ });
+ OpShapeInferenceFn fn_out;
+ ExpectSuccess(
+ b().SetShapeFn(fn).Attr("dtype: type"),
+ "attr { name: \"dtype\" type: \"type\" allowed_values { list { } } }",
+ &fn_out);
+ ASSERT_TRUE(fn_out != nullptr);
+ EXPECT_EQ("ShapeFn was called", fn_out(nullptr).error_message());
+}
+
+TEST_F(OpDefBuilderTest, SetShapeFnCalledTwiceFailure) {
+ auto fn = OpShapeInferenceFn([](shape_inference::InferenceContext* c) {
+ return errors::Unknown("ShapeFn was called");
+ });
+ ExpectFailure(b().SetShapeFn(fn).SetShapeFn(fn),
+ "SetShapeFn called twice for Op Test");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
index 71924b4c98..e24b645683 100644
--- a/tensorflow/core/framework/op_def_util_test.cc
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -37,13 +37,13 @@ class ValidateOpDefTest : public ::testing::Test {
Status TestProto(const string& text) { return ValidateOpDef(FromText(text)); }
Status TestBuilder(const OpDefBuilder& builder) {
- OpDef op_def;
- Status status = builder.Finalize(&op_def);
+ OpRegistrationData op_reg_data;
+ Status status = builder.Finalize(&op_reg_data);
TF_EXPECT_OK(status);
if (!status.ok()) {
return status;
} else {
- return ValidateOpDef(op_def);
+ return ValidateOpDef(op_reg_data.op_def);
}
}
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 30de279e7f..4df0ffeda7 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -757,9 +757,9 @@ Status SupportedDeviceTypesForNode(
// DynamicPlacer) to consider the possibility that 'def' is a call to
// a user-defined function and only calls this
// SupportedDeviceTypesForNode for primitive ops.
- Status s;
- const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s);
- if (op_def) {
+ const OpRegistrationData* op_reg_data;
+ const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
+ if (s.ok()) {
for (const DeviceType& device_type : prioritized_types) {
const KernelRegistration* reg = nullptr;
TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, &reg));
@@ -790,9 +790,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
// Look up the Op registered for this op name.
- Status s;
- const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s);
- if (op_def == nullptr) return s;
+ const OpDef* op_def = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
+ if (!s.ok()) return s;
// Validate node_def against OpDef.
s = ValidateNodeDef(node_def, *op_def);
@@ -858,22 +858,23 @@ bool FindArgInOp(StringPiece arg_name,
} // namespace
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
- Status unused_status;
for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
const KernelDef& kernel_def(key_registration.second.def);
- const OpDef* op_def = op_registry.LookUp(kernel_def.op(), &unused_status);
- if (op_def == nullptr) {
+ const OpRegistrationData* op_reg_data;
+ const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
+ if (!status.ok()) {
// TODO(josh11b): Make this a hard error.
LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
<< "') for unknown op: " << kernel_def.op();
continue;
}
+ const OpDef& op_def = op_reg_data->op_def;
for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
- if (!FindArgInOp(host_memory_arg, op_def->input_arg()) &&
- !FindArgInOp(host_memory_arg, op_def->output_arg())) {
+ if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
+ !FindArgInOp(host_memory_arg, op_def.output_arg())) {
return errors::InvalidArgument("HostMemory arg '", host_memory_arg,
"' not found in OpDef: ",
- SummarizeOpDef(*op_def));
+ SummarizeOpDef(op_def));
}
}
}
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 17b9d33386..5269314583 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -125,8 +125,8 @@ Status InferenceContext::WithValue(const Dimension* dim, int64 value,
return Status::OK();
}
*out = nullptr;
- return errors::InvalidArgument("Dimension must be size ", value,
- " but is size ", existing);
+ return errors::InvalidArgument("Dimension must be ", value, " but is ",
+ existing);
}
Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
@@ -142,7 +142,7 @@ Status InferenceContext::Merge(const Dimension* d0, const Dimension* d1,
return Status::OK();
} else {
*out = nullptr;
- return errors::InvalidArgument("Dimensions must be equal size, but are ",
+ return errors::InvalidArgument("Dimensions must be equal, but are ",
Value(d0), " and ", Value(d1));
}
}
@@ -181,7 +181,8 @@ Status InferenceContext::Merge(const Shape* s0, const Shape* s1,
return_s1 = false;
} else if (v0 != v1) {
*out = nullptr;
- return errors::InvalidArgument("Dimensions must be equal size, but are ",
+ return errors::InvalidArgument("Dimension ", i,
+ " in both shapes must be equal, but are ",
Value(d0), " and ", Value(d1));
}
}
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index c0681a91fc..55e8b90326 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -91,8 +91,9 @@ class InferenceContext {
// not available at the time of shape inference.
const Tensor* input_tensor(int idx) const { return input_tensors_[idx]; }
- void set_output(int idx, const Shape* shape);
+ void set_output(int idx, const Shape* shape) { outputs_[idx] = shape; }
int num_outputs() const { return outputs_.size(); }
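+ // Returns the shape recorded for output <idx> by set_output().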
+ const Shape* output(int idx) { return outputs_[idx]; }
// idx can be negative for an offset from end of dimensions.
const Dimension* Dim(const Shape* s, int32 idx) { return s->dims_[idx]; }
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index b8d21e5e38..27788ac03d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -113,11 +113,11 @@ TEST(ShapeInferenceTest, WithValue) {
// WithRank on dimension with known size.
out1 = d0;
- EXPECT_EQ("Invalid argument: Dimension must be size 0 but is size 1",
+ EXPECT_EQ("Invalid argument: Dimension must be 0 but is 1",
c.WithValue(d0, 0, &out1).ToString());
EXPECT_TRUE(out1 == nullptr);
out1 = d0;
- EXPECT_EQ("Invalid argument: Dimension must be size 2 but is size 1",
+ EXPECT_EQ("Invalid argument: Dimension must be 2 but is 1",
c.WithValue(d0, 2, &out1).ToString());
EXPECT_TRUE(out1 == nullptr);
EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok());
@@ -159,10 +159,10 @@ TEST(ShapeInferenceTest, MergeDim) {
EXPECT_TRUE(d2_b == out);
// Merging inequal values is an error.
- EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 1",
+ EXPECT_EQ("Invalid argument: Dimensions must be equal, but are 2 and 1",
c.Merge(d2, d1, &out).ToString());
EXPECT_TRUE(out == nullptr);
- EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 1 and 2",
+ EXPECT_EQ("Invalid argument: Dimensions must be equal, but are 1 and 2",
c.Merge(d1, d2, &out).ToString());
EXPECT_TRUE(out == nullptr);
}
@@ -210,11 +210,13 @@ TEST(ShapeInferenceTest, MergeShape) {
// Incompatible merges give errors and set out to nullptr.
out = s_unknown;
- EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 2 and 3",
+ EXPECT_EQ(("Invalid argument: Dimension 1 in both shapes must be equal, but "
+ "are 2 and 3"),
c.Merge(s_u_2, s_1_3, &out).ToString());
EXPECT_TRUE(out == nullptr);
out = s_unknown;
- EXPECT_EQ("Invalid argument: Dimensions must be equal size, but are 3 and 2",
+ EXPECT_EQ(("Invalid argument: Dimension 1 in both shapes must be equal, but "
+ "are 3 and 2"),
c.Merge(s_1_3, s_u_2, &out).ToString());
EXPECT_TRUE(out == nullptr);
out = s_unknown;
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
new file mode 100644
index 0000000000..9b56014edb
--- /dev/null
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -0,0 +1,153 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+using shape_inference::Dimension;
+using shape_inference::Shape;
+using errors::Unknown;
+
+Status InferShapes(const string& op_name, const string& ins,
+ const string& expected_outs) {
+ const OpRegistrationData* op_reg_data;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
+ const int num_outputs = op_reg_data->op_def.output_arg_size();
+
+ std::vector<string> ins_v = str_util::Split(ins, ';');
+ shape_inference::InferenceContext c(ins_v, num_outputs);
+ TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
+
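+ // Index each input shape and dimension so that output shapes can be
+ // tested for identity with an input ("in0") or an input dim ("d0_1").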
+ std::unordered_map<const Dimension*, std::pair<int, int>>
+ dim_to_input_and_dim_idx;
+ std::unordered_map<const Shape*, int> shape_to_input_idx;
+ for (int i = 0; i < c.num_inputs(); ++i) {
+ auto in = c.input(i);
+ shape_to_input_idx[in] = i;
+ for (int j = 0; j < c.Rank(in); ++j) {
+ dim_to_input_and_dim_idx[c.Dim(in, j)] = std::make_pair(i, j);
+ }
+ }
+ if (expected_outs == "e") {
+ return Unknown("Shape inference should have returned error");
+ }
+
+ // Verify the output shape.
+ std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';');
+ if (num_outputs != expected_outs_v.size()) {
+ return Unknown("Wrong number of expected outputs (", expected_outs_v.size(),
+ " vs ", num_outputs, ")");
+ }
+ for (int i = 0; i < num_outputs; ++i) {
+ string err_prefix = strings::StrCat("Output ", i);
+ StringPiece expected(expected_outs_v[i]);
+ const shape_inference::Shape* out = c.output(i);
+ const int in_index = gtl::FindWithDefault(shape_to_input_idx, out, -1);
+ if (expected.starts_with("in")) {
+ if (in_index == -1) {
+ return Unknown(err_prefix, " did not match any input shape");
+ }
+ auto v = str_util::Split(expected, '|');
+ if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) ==
+ v.end()) {
+ return Unknown(err_prefix, " matched input ", in_index,
+ " and should have matched one of (", expected, ")");
+ }
+ continue;
+ }
+ if (in_index != -1) {
+ return Unknown(err_prefix, " matched input ", in_index,
+ " and should have not matched an input shape");
+ }
+ if (expected == "?") {
+ if (c.RankKnown(out)) {
+ return Unknown(err_prefix, " expected to be unknown but was ",
+ c.DebugString(out));
+ }
+ continue;
+ }
+
+ // Verify the dimensions.
+ CHECK(expected.starts_with("[") && expected.ends_with("]"));
+ expected.remove_prefix(1);
+ expected.remove_suffix(1);
+
+ // Split expected as a dimension.
+ auto expected_dims = str_util::Split(expected, ',');
+ if (!c.RankKnown(out)) {
+ return Unknown(err_prefix, " expected rank ", expected_dims.size(),
+ " but was ?");
+ }
+ if (c.Rank(out) != expected_dims.size()) {
+ return Unknown(err_prefix, " expected rank ", expected_dims.size(),
+ " but was ", c.Rank(out));
+ }
+ for (int j = 0; j < expected_dims.size(); ++j) {
+ err_prefix = strings::StrCat("Output dim ", i, ",", j);
+ StringPiece expected_dim(expected_dims[j]);
+ const Dimension* out_dim = c.Dim(out, j);
+ std::pair<int, int> in_dim_idx = gtl::FindWithDefault(
+ dim_to_input_and_dim_idx, out_dim, std::make_pair(-1, -1));
+ if (expected_dim == "?") {
+ if (in_dim_idx.first != -1) {
+ return Unknown(err_prefix,
+ " expected to be unknown but matched input d",
+ in_dim_idx.first, "_", in_dim_idx.second);
+ } else if (c.ValueKnown(out_dim)) {
+ return Unknown(err_prefix, " expected to be unknown but was ",
+ c.Value(out_dim));
+ }
+ } else if (expected_dim.starts_with("d")) {
+ // Compare the dimension values.
+ auto v = str_util::Split(expected_dim, '|');
+ if (in_dim_idx.first == -1) {
+ return Unknown(err_prefix, " did not match any input dim");
+ }
+ if (std::find(v.begin(), v.end(),
+ strings::StrCat("d", in_dim_idx.first, "_",
+ in_dim_idx.second)) == v.end()) {
+ return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_",
+ in_dim_idx.second, " and should have matched one of ",
+ expected_dim);
+ }
+ } else {
+ // Parse it as a value.
+ int64 value = -1;
+ if (!strings::safe_strto64(expected_dim, &value)) {
+ return Unknown(err_prefix, " expected dim failed to parse as int64");
+ }
+ if (in_dim_idx.first != -1) {
+ return Unknown(err_prefix, " expected to be ", value,
+ " but matched input d", in_dim_idx.first, "_",
+ in_dim_idx.second);
+ } else if (value != c.Value(out_dim)) {
+ return Unknown(err_prefix, " expected to be ", value, " but was ",
+ c.DebugString(out_dim));
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
new file mode 100644
index 0000000000..f2581247d9
--- /dev/null
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -0,0 +1,56 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
+
+#include <vector>
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+// Contains utilities for writing tests for shape inference functions.
+
+namespace tensorflow {
+
+// Runs shape inference for <op_name>, given inputs specified by <ins>,
+// and returns an error if the inferred shapes do not match <expected_outs>.
+//
+// <ins> is a semicolon separated list of shapes. Each shape is formatted
+// as documented for the shape_inference::InferenceContext constructor.
+//
+// <expected_outs> is a semicolon separated list of shapes. Each shape is
+// formatted as one of:
+// * ? - an unknown shape, but not matching an input shape
+// * in0|in2|... - output shape must be the same as one of these input shapes.
+// * [1,?,d0_0|d0_1] - output shape is of known rank, with comma-separated
+// dimension values.
+// Each dimension value is one of:
+// * a constant, which means the dim must equal that constant and must
+// not match a specific input dimension
+// * ?, which means the dim size must be unknown and must not match a
+// specific input dimension
+// * d0_0|d1_2, indicating that the dim size must be equal to one of
+// the given input dimensions; the first number is the input # and
+// the second is which dimension in that input it corresponds to.
+// <expected_outs> can be "e"; this is used to indicate that shape inference
+// should have failed.
+Status InferShapes(const string& op_name, const string& ins,
+ const string& expected_outs);
+
+#define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
+#define INFER_ERROR(s, op, i) \
+ EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
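
With these macros, a shape function test becomes a set of one-line
expectations. An illustrative sketch for a hypothetical op "MyOp" (see
math_ops_test.cc below for real uses):

    // Output 0 must be the same Shape object as input 0 or input 1.
    INFER_OK("MyOp", "[1,2];[1,2]", "in0|in1");
    // Output 0 has rank 2; dim 0 must match input 0's dim 0, dim 1 must be 2.
    INFER_OK("MyOp", "[1,?];[?,2]", "[d0_0,2]");
    // Shape inference itself is expected to fail with exactly this message.
    INFER_ERROR("Dimension 0 in both shapes must be equal, but are 2 and 3",
                "MyOp", "[2];[3]");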
diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc
new file mode 100644
index 0000000000..d16ec04264
--- /dev/null
+++ b/tensorflow/core/framework/shape_inference_testutil_test.cc
@@ -0,0 +1,154 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+namespace {
+
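+// Each test installs its own shape function here; the ops registered below
+// just forward to it.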
+static OpShapeInferenceFn* global_fn_ptr = nullptr;
+REGISTER_OP("OpOneOut")
+ .Input("inputs: N * T")
+ .Output("o1: T")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return (*global_fn_ptr)(c);
+ }));
+REGISTER_OP("OpTwoOut")
+ .Input("inputs: N * T")
+ .Output("o1: T")
+ .Output("o2: T")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return (*global_fn_ptr)(c);
+ }));
+
+string RunInferShapes(const string& op_name, const string& ins,
+ const string& expected_outs, OpShapeInferenceFn fn) {
+ global_fn_ptr = &fn;
+ return InferShapes(op_name, ins, expected_outs).error_message();
+}
+
+} // namespace
+
+TEST(ShapeInferenceTestutilTest, Failures) {
+ auto fn_copy_input_0 = [](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ };
+ auto fn_copy_input_2 = [](InferenceContext* c) {
+ c->set_output(0, c->input(2));
+ return Status::OK();
+ };
+ auto fn_output_unknown_shapes = [](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->CreateUnknownShape());
+ }
+ return Status::OK();
+ };
+ auto fn_output_1_2 = [](InferenceContext* c) {
+ c->set_output(0, c->CreateShape({c->CreateDim(1), c->CreateDim(2)}));
+ return Status::OK();
+ };
+ auto fn_output_u_2 = [](InferenceContext* c) {
+ c->set_output(0, c->CreateShape({c->CreateUnknownDim(), c->CreateDim(2)}));
+ return Status::OK();
+ };
+ const string& op = "OpOneOut";
+
+ EXPECT_EQ("Shape inference should have returned error",
+ RunInferShapes(op, "[1];[2];[1]", "e", fn_copy_input_0));
+ EXPECT_EQ("Wrong number of expected outputs (2 vs 1)",
+ RunInferShapes(op, "[1];[2];[1]", "[1];[2]", fn_copy_input_0));
+ EXPECT_EQ("Op type not registered 'NoSuchOp'",
+ RunInferShapes("NoSuchOp", "", "", fn_copy_input_0));
+
+ // Wrong shape error messages.
+ EXPECT_EQ(
+ "Output 0 matched input 0 and should have not matched an input shape",
+ RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0));
+ EXPECT_EQ("Output 0 matched input 0 and should have matched one of (in2)",
+ RunInferShapes(op, "[1];[2];[1]", "in2", fn_copy_input_0));
+ EXPECT_EQ("Output 0 matched input 0 and should have matched one of (in1|in2)",
+ RunInferShapes(op, "[1];[2];[1]", "in1|in2", fn_copy_input_0));
+ EXPECT_EQ(
+ "Output 0 matched input 2 and should have not matched an input shape",
+ RunInferShapes(op, "[1];[2];[1]", "[1]", fn_copy_input_2));
+ EXPECT_EQ("Output 0 did not match any input shape",
+ RunInferShapes(op, "[1];[2];[1]", "in0|in1", fn_output_1_2));
+ EXPECT_EQ("Output 0 expected to be unknown but was [1,2]",
+ RunInferShapes(op, "[1];[2];[1]", "?", fn_output_1_2));
+ EXPECT_EQ("Output 0 expected rank 3 but was 2",
+ RunInferShapes(op, "[1];[2];[1]", "[1,2,3]", fn_output_1_2));
+ EXPECT_EQ(
+ "Output 0 expected rank 2 but was ?",
+ RunInferShapes(op, "[1];[2];[1]", "[1,2]", fn_output_unknown_shapes));
+
+ // Wrong shape error messages on the second output.
+ EXPECT_EQ("Output 1 expected rank 3 but was ?",
+ RunInferShapes("OpTwoOut", "[1];[2];[1]", "?;[1,2,3]",
+ fn_output_unknown_shapes));
+
+ // Wrong dimension error messages.
+ EXPECT_EQ("Output dim 0,1 expected to be 3 but was 2",
+ RunInferShapes(op, "[1];[2];[1]", "[1,3]", fn_output_1_2));
+ EXPECT_EQ("Output dim 0,0 expected to be 2 but was 1",
+ RunInferShapes(op, "[1];[2];[1]", "[2,2]", fn_output_1_2));
+ EXPECT_EQ("Output dim 0,0 expected to be unknown but was 1",
+ RunInferShapes(op, "[1];[2];[1]", "[?,2]", fn_output_1_2));
+ EXPECT_EQ("Output dim 0,1 expected to be 1 but was 2",
+ RunInferShapes(op, "[1];[2];[1]", "[?,1]", fn_output_u_2));
+ EXPECT_EQ("Output dim 0,0 expected to be 1 but was ?",
+ RunInferShapes(op, "[0,1,?];[2];[1]", "[1,2]", fn_output_u_2));
+ auto fn = [](InferenceContext* c) {
+ c->set_output(
+ 0, c->CreateShape({c->Dim(c->input(0), 1), c->CreateDim(2),
+ c->CreateUnknownDim(), c->Dim(c->input(2), 0)}));
+ return Status::OK();
+ };
+ const string ins = "[0,1,?];[2];[1]";
+ EXPECT_EQ("Output dim 0,0 expected to be unknown but matched input d0_1",
+ RunInferShapes(op, ins, "[?,2,?,d2_0]", fn));
+ EXPECT_EQ("Output dim 0,0 expected to be 0 but matched input d0_1",
+ RunInferShapes(op, ins, "[0,2,?,d2_0]", fn));
+ EXPECT_EQ(
+ "Output dim 0,0 matched input d0_1 and should have matched one of d0_0",
+ RunInferShapes(op, ins, "[d0_0,2,?,d2_0]", fn));
+ EXPECT_EQ("Output dim 0,0 expected dim failed to parse as int64",
+ RunInferShapes(op, ins, "[x,2,?,d2_0]", fn));
+ EXPECT_EQ(
+ "Output dim 0,0 matched input d0_1 and should have matched one of "
+ "d0_0|d0_2",
+ RunInferShapes(op, ins, "[d0_0|d0_2,2,?,d2_0]", fn));
+ EXPECT_EQ("Output dim 0,1 expected to be unknown but was 2",
+ RunInferShapes(op, ins, "[d0_1,?,?,d0_0|d2_0]", fn));
+ EXPECT_EQ("Output dim 0,2 expected to be 8 but was ?",
+ RunInferShapes(op, ins, "[d0_1,2,8,d0_0|d2_0]", fn));
+ EXPECT_EQ("Output dim 0,2 did not match any input dim",
+ RunInferShapes(op, ins, "[d0_1,2,d0_1|d2_0,d0_0|d2_0]", fn));
+ EXPECT_EQ("", // OK, no error.
+ RunInferShapes(op, ins, "[d0_1,2,?,d0_0|d2_0]", fn));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 363fe5623c..3607db7e81 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -205,8 +205,9 @@ Graph::~Graph() {
}
Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
- const OpDef* op_def = ops_->LookUp(node_def.op(), status);
- if (op_def == nullptr) return nullptr;
+ const OpDef* op_def;
+ status->Update(ops_->LookUpOpDef(node_def.op(), &op_def));
+ if (!status->ok()) return nullptr;
DataTypeVector inputs;
DataTypeVector outputs;
diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc
index 21d9e25284..b5fbb5fa27 100644
--- a/tensorflow/core/graph/validate.cc
+++ b/tensorflow/core/graph/validate.cc
@@ -30,8 +30,8 @@ Status ValidateGraphDef(const GraphDef& graph_def,
const int version = graph_def.versions().producer();
for (const NodeDef& node_def : graph_def.node()) {
// Look up the OpDef for the node_def's op name.
- const OpDef* op_def = op_registry.LookUp(node_def.op(), &s);
- TF_RETURN_IF_ERROR(s);
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def.op(), &op_def));
TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, version));
}
diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc
index a7357df15d..cb6d107cad 100644
--- a/tensorflow/core/graph/validate_test.cc
+++ b/tensorflow/core/graph/validate_test.cc
@@ -95,8 +95,10 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
}
TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
+ OpRegistrationData op_reg_data;
+ TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(&op_reg_data));
OpList op_list;
- TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(op_list.add_op()));
+ *op_list.add_op() = op_reg_data.op_def;
const string graph_def_str = "node { name: 'A' op: 'UniqueSnowflake' }";
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
@@ -105,8 +107,10 @@ TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
}
TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) {
+ OpRegistrationData op_reg_data;
+ TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(&op_reg_data));
OpList op_list;
- TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(op_list.add_op()));
+ *op_list.add_op() = op_reg_data.op_def;
const string graph_def_str = "node { name: 'A' op: 'FloatInput' }";
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 46398ca5a9..7743173996 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -15,9 +15,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+
REGISTER_OP("AddN")
.Input("inputs: N * T")
.Output("sum: T")
@@ -25,6 +30,16 @@ REGISTER_OP("AddN")
.Attr("T: numbertype")
.SetIsCommutative()
.SetIsAggregate()
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
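+ // The output shape is the merge of all the input shapes; Merge fails
+ // if the ranks differ or if a known dimension size disagrees. Start
+ // from the last input and fold the rest in.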
+ const Shape* cur = c->input(c->num_inputs() - 1);
+ for (int i = c->num_inputs() - 2; i >= 0; --i) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+ "From merging shape ", i,
+ " with other shapes.");
+ }
+ c->set_output(0, cur);
+ return Status::OK();
+ }))
.Doc(R"doc(
Add all input tensors element-wise.
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
new file mode 100644
index 0000000000..89d6ed14d1
--- /dev/null
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -0,0 +1,44 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(MathOpsTest, AddN_ShapeFn) {
+ INFER_OK("AddN", "?;?", "in0|in1");
+ INFER_OK("AddN", "[1,?]", "in0");
+ INFER_OK("AddN", "[1,2];[?,2]", "in0");
+ INFER_OK("AddN", "[1,2];[1,2]", "in0|in1");
+ INFER_OK("AddN", "[?,2];[1,2]", "in1");
+ INFER_OK("AddN", "[1,?];[?,2];[1,2]", "in2");
+ INFER_OK("AddN", "[1,2];[?,2];[1,?]", "in0");
+ INFER_OK("AddN", "?;?;[1,2]", "in2");
+ INFER_OK("AddN", "[1,?];[?,2]", "[d0_0,d1_1]");
+ INFER_OK("AddN", "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]");
+ INFER_OK("AddN", "[?,2];[1,?]", "[d1_0,d0_1]");
+
+ INFER_ERROR(("Dimension 1 in both shapes must be equal, but are 2 and "
+ "4\n\tFrom merging shape 0 with other shapes."),
+ "AddN", "[1,2];?;[1,4]");
+ INFER_ERROR(("Shapes must be equal rank, but are 2 and 3\n\tFrom merging "
+ "shape 1 with other shapes."),
+ "AddN", "?;[1,2];?;[1,2,3]");
+}
+
+} // end namespace tensorflow