diff options
-rw-r--r-- | tensorflow/core/framework/function.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/function_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_builder.h | 2 |
4 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 1774f74ca8..0fdc2c820c 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb_text.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -867,6 +868,11 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { return Status::OK(); } +FunctionLibraryDefinition::FunctionDefAndOpRegistration:: + FunctionDefAndOpRegistration(const FunctionDef& fdef_in) + : fdef(fdef_in), + op_registration_data(fdef.signature(), shape_inference::UnknownShape) {} + FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) : default_registry_(other.default_registry_), func_grad_(other.func_grad_) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index d840d2f001..045976dd06 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -348,8 +348,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { private: // TODO(cwhipkey): support shape functions in FunctionDefLibrary. struct FunctionDefAndOpRegistration { - FunctionDefAndOpRegistration(const FunctionDef& fdef_in) - : fdef(fdef_in), op_registration_data(fdef.signature()) {} + FunctionDefAndOpRegistration(const FunctionDef& fdef_in); FunctionDef fdef; OpRegistrationData op_registration_data; diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 140dbd8932..8e15bf04ab 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -938,6 +938,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) { ASSERT_NE(op_def, nullptr); EXPECT_EQ(op_def->DebugString(), test::function::XTimesTwo().signature().DebugString()); + + const OpRegistrationData* op_reg_data; + TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data)); + ASSERT_NE(op_reg_data, nullptr); + // Shape inference function is initialized to UnknownShape. + ASSERT_NE(op_reg_data->shape_inference_fn, nullptr); } TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index e64255e96f..0d492208d4 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -38,6 +38,8 @@ struct OpRegistrationData { public: OpRegistrationData() {} OpRegistrationData(const OpDef& def) : op_def(def) {} + OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn) + : op_def(def), shape_inference_fn(fn) {} OpDef op_def; OpShapeInferenceFn shape_inference_fn; |