aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/function.cc6
-rw-r--r--tensorflow/core/framework/function.h3
-rw-r--r--tensorflow/core/framework/function_test.cc6
-rw-r--r--tensorflow/core/framework/op_def_builder.h2
4 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 1774f74ca8..0fdc2c820c 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -867,6 +868,11 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
return Status::OK();
}
+FunctionLibraryDefinition::FunctionDefAndOpRegistration::
+ FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
+ : fdef(fdef_in),
+ op_registration_data(fdef.signature(), shape_inference::UnknownShape) {}
+
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionLibraryDefinition& other)
: default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index d840d2f001..045976dd06 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -348,8 +348,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
private:
// TODO(cwhipkey): support shape functions in FunctionDefLibrary.
struct FunctionDefAndOpRegistration {
- FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
- : fdef(fdef_in), op_registration_data(fdef.signature()) {}
+ FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
FunctionDef fdef;
OpRegistrationData op_registration_data;
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 140dbd8932..8e15bf04ab 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -938,6 +938,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) {
ASSERT_NE(op_def, nullptr);
EXPECT_EQ(op_def->DebugString(),
test::function::XTimesTwo().signature().DebugString());
+
+ const OpRegistrationData* op_reg_data;
+ TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
+ ASSERT_NE(op_reg_data, nullptr);
+ // Shape inference function is initialized to UnknownShape.
+ ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
}
TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index e64255e96f..0d492208d4 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -38,6 +38,8 @@ struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
+ OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn)
+ : op_def(def), shape_inference_fn(fn) {}
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;