diff options
author | 2016-10-31 13:11:56 -0800 | |
---|---|---|
committer | 2016-10-31 14:19:03 -0700 | |
commit | 41734d78d3facf652c25b2a2761aadd978b3f2ef (patch) | |
tree | 931685da30dfcb84077664dede0b33d93c609397 | |
parent | 1dd44f3ecc38cdb3df95cb3599f7b843f6d5062a (diff) |
Automated rollback of change 137740850
Change: 137747341
28 files changed, 714 insertions, 162 deletions
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8cdecf706a..907158b646 100644 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -94,6 +94,38 @@ function(RELATIVE_PROTOBUF_GENERATE_PYTHON ROOT_DIR SRCS) set(${SRCS} ${${SRCS}} PARENT_SCOPE) endfunction() +function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) + if(NOT ARGN) + message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files") + return() + endif() + + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${ARGN}) + set(ABS_FIL ${ROOT_DIR}/${FIL}) + get_filename_component(FIL_WE ${FIL} NAME_WE) + get_filename_component(FIL_DIR ${ABS_FIL} PATH) + file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h") + + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} + DEPENDS ${ABS_FIL} protobuf + COMMENT "Running C++ protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() + file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" "${tensorflow_source_dir}/tensorflow/python/*.proto" @@ -102,6 +134,12 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON( ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} ) +RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS + ${tensorflow_source_dir} ${tf_protos_python_srcs} +) + +add_library(tf_python_protos_cc ${PROTO_SRCS} ${PROTO_HDRS}) + # tf_python_touchup_modules adds empty __init__.py files to all # directories containing Python code, so that Python will recognize # them as modules. @@ -201,6 +239,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ) target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE tf_protos_cc + tf_python_protos_cc ${tensorflow_EXTERNAL_LIBRARIES} ) @@ -312,6 +351,7 @@ target_link_libraries(pywrap_tensorflow ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} tf_protos_cc + tf_python_protos_cc ${PYTHON_LIBRARIES} ) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1c37921afc..79546ccd20 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -379,6 +379,7 @@ tf_gen_op_libs( "no_op", "parsing_ops", "random_ops", + "resource_variable_ops", "sdca_ops", "script_ops", "sendrecv_ops", @@ -542,6 +543,7 @@ cc_library( "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:required", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sdca_ops", "//tensorflow/core/kernels:sparse", "//tensorflow/core/kernels:state", diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 1ddd483076..e1500ed1ad 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -42,6 +42,8 @@ Status ShapeRefiner::AddNode(const Node* node) { // indexed by 'node's input. std::vector<Node*> input_nodes(node->num_inputs()); std::vector<ShapeHandle> input_shapes(node->num_inputs()); + std::vector<DataType> input_handle_dtypes(node->num_inputs()); + std::vector<ShapeHandle> input_handle_shapes(node->num_inputs()); for (const Edge* e : node->in_edges()) { if (e->IsControlEdge()) continue; @@ -57,6 +59,15 @@ Status ShapeRefiner::AddNode(const Node* node) { DCHECK_GE(e->dst_input(), 0); input_nodes[e->dst_input()] = input; input_shapes[e->dst_input()] = c->output(e->src_output()); + + // Only propagate handle xshape and dtype of edges which are carrying + // resource handles. + if (e->src()->output_type(e->src_output()) == DT_RESOURCE) { + input_handle_dtypes[e->dst_input()] = + c->output_handle_dtype(e->src_output()); + input_handle_shapes[e->dst_input()] = + c->output_handle_shape(e->src_output()); + } } // Get the shape function for this node @@ -76,9 +87,9 @@ Status ShapeRefiner::AddNode(const Node* node) { std::vector<ShapeHandle> input_tensors_as_shapes; // Create the inference context for this node with the existing input shapes. - std::unique_ptr<InferenceContext> c( - new InferenceContext(&node->def(), node->op_def(), input_shapes, - input_tensors, input_tensors_as_shapes)); + std::unique_ptr<InferenceContext> c(new InferenceContext( + &node->def(), node->op_def(), input_shapes, input_tensors, + input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes)); if (!c->construction_status().ok()) { return c->construction_status(); } diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 7196bc8304..ca1326844c 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) { .Input({{"data", 0, DT_FLOAT}}) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}); + InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {}); TF_EXPECT_OK(NoOutputs(&c)); EXPECT_EQ(0, c.num_outputs()); } @@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) { NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def)); { - InferenceContext c(&def, op_def, {S({})}, {}, {}); + InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); } { - InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}); + InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); @@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown inner dimension for one - InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Invalid rank. - InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown outer dimension - InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -156,7 +156,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, + {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -174,7 +175,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}); + InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -191,7 +192,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -215,7 +216,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -224,7 +225,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Unknown ranks. - InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}); + InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_FALSE(c.RankKnown(output)); @@ -232,7 +233,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Rank > 2 - InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}); + InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, + {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output)); @@ -245,7 +247,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[2,3,4,5]", c.DebugString(output)); @@ -258,8 +260,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, - {}); + InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, + {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output)); @@ -272,7 +274,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}); + InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {}, + {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[10,11,12]", c.DebugString(output)); @@ -280,7 +283,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Input rank not high enough - InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}); + InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } @@ -292,7 +295,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } } @@ -311,7 +314,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -319,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Rank > 2 - InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}); + InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -331,7 +334,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(3, c.Value(c.Dim(output, 0))); @@ -343,7 +346,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}); + InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, + {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(3, c.Value(c.Dim(output, 0))); @@ -355,7 +359,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}); + InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -363,7 +367,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Input rank not high enough - InferenceContext c(&def, op_def, {S({3})}, {}, {}); + InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } @@ -374,7 +378,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}); + InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 2d16760625..85b83c4d74 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -66,6 +66,14 @@ void AddNodeAttr(StringPiece name, std::initializer_list<T> value, AttrValueMap::value_type(name.ToString(), attr_value)); } +// Adds an attr to an attr value map. +template <class T> +void AddAttr(StringPiece name, T&& value, AttrValueMap* map) { + AttrValue attr_value; + SetAttrValue(value, &attr_value); + map->insert(AttrValueMap::value_type(name.ToString(), attr_value)); +} + class AttrSlice { public: AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index f823462079..30bb2cb708 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -21,6 +21,7 @@ limitations under the License. #include <typeinfo> #include <unordered_map> +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index da88b6a7ca..4aa32f6a84 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -31,7 +31,9 @@ InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes) + const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<TensorShapeProto>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes) : node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; @@ -43,19 +45,30 @@ InferenceContext::InferenceContext( } inputs_.push_back(shape); } - PostInputInit(); + std::vector<ShapeHandle> handle_shapes; + for (const auto& p : input_handle_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + handle_shapes.push_back(shape); + } + PostInputInit(handle_shapes, input_handle_dtypes); } InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes) + const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<ShapeHandle>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes) : node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; inputs_ = input_shapes; - PostInputInit(); + PostInputInit(input_handle_shapes, input_handle_dtypes); } InferenceContext::~InferenceContext() { @@ -124,15 +137,44 @@ void InferenceContext::PreInputInit( for (int i = 0; i < num_outputs; ++i) { outputs_.push_back(nullptr); } + output_handle_shape_.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + output_handle_shape_.push_back(UnknownShape()); + } + output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID); } -void InferenceContext::PostInputInit() { +void InferenceContext::PostInputInit( + const std::vector<ShapeHandle>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes) { int num_inputs_from_node_def = 0; for (const auto& e : input_name_map_) { num_inputs_from_node_def = std::max(num_inputs_from_node_def, e.second.second); } + // Allow passing empty shapes/dtypes to avoid changing every single test. + if (input_handle_shapes.empty()) { + input_handle_shape_.resize(inputs_.size()); + } else { + input_handle_shape_ = input_handle_shapes; + if (input_handle_shape_.size() != inputs_.size()) { + construction_status_ = errors::InvalidArgument( + "Wrong number of handle shapes passed; expected ", inputs_.size(), + " got ", input_handle_shape_.size()); + } + } + if (input_handle_dtypes.empty()) { + input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID); + } else { + input_handle_dtype_ = input_handle_dtypes; + if (input_handle_dtype_.size() != inputs_.size()) { + construction_status_ = errors::InvalidArgument( + "Wrong number of handle dtypes passed; expected ", inputs_.size(), + " got ", input_handle_dtype_.size()); + } + } + if (inputs_.size() != num_inputs_from_node_def) { construction_status_ = errors::InvalidArgument( "Wrong number of inputs passed: ", inputs_.size(), " while ", @@ -737,6 +779,13 @@ Status InferenceContext::AttachContext(const Status& status) { strings::StrCat(status.error_message(), error_context)); } +ShapeHandle InferenceContext::input_handle_shape(int idx) { + if (!input_handle_shape_[idx].IsSet()) { + input_handle_shape_[idx] = UnknownShape(); + } + return input_handle_shape_[idx]; +} + // ----------------------------------------------------------------------------- // ShapeManager // ----------------------------------------------------------------------------- diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index f5befc15a1..e02490efd9 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -147,7 +147,9 @@ class InferenceContext { InferenceContext(const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes); + const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<ShapeHandle>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes); // <input_tensors> is NULL-padded to be the same size as <input_shapes>. // @@ -162,7 +164,9 @@ class InferenceContext { InferenceContext(const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes); + const std::vector<ShapeHandle>& input_tensors_as_shapes, + const std::vector<TensorShapeProto>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes); ~InferenceContext(); @@ -231,12 +235,12 @@ class InferenceContext { } return s->dims_[idx]; } - int32 Rank(ShapeHandle s) { return s->rank_; } - bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; } - inline int64 Value(DimensionOrConstant d) { + int32 Rank(ShapeHandle s) const { return s->rank_; } + bool RankKnown(ShapeHandle s) const { return Rank(s) != kUnknownRank; } + inline int64 Value(DimensionOrConstant d) const { return d.dim.IsSet() ? d.dim->value_ : d.val; } - inline bool ValueKnown(DimensionOrConstant d) { + inline bool ValueKnown(DimensionOrConstant d) const { return Value(d) != kUnknownDim; } @@ -391,6 +395,30 @@ class InferenceContext { Status construction_status() const { return construction_status_; } + // Methods to propagate shape and dtype on edges of handles. Handles are the + // dtype DT_RESOURCE which can be used to access state stored in a + // ResourceManager. When ops (such as variables) consume these handles to + // produce tensors they might need to know side-information about the shapes + // and dtypes of tensors which can be accessed via the handle. These methods + // propagate that information. Output handle dtypes and shapes are ignored if + // the output tensor is not of type DT_RESOURCE. + ShapeHandle input_handle_shape(int idx); + DataType input_handle_dtype(int idx) const { + return input_handle_dtype_[idx]; + } + void set_output_handle_shape(int idx, ShapeHandle shape) { + output_handle_shape_[idx] = shape; + } + void set_output_handle_dtype(int idx, DataType dtype) { + output_handle_dtype_[idx] = dtype; + } + ShapeHandle output_handle_shape(int idx) const { + return output_handle_shape_[idx]; + } + DataType output_handle_dtype(int idx) const { + return output_handle_dtype_[idx]; + } + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(ShapeHandle indices_shape, @@ -481,7 +509,8 @@ class InferenceContext { void PreInputInit(const OpDef& op_def, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes); - void PostInputInit(); + void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes, + const std::vector<DataType>& input_handle_dtypes); DimensionHandle GetDimension(const DimensionOrConstant& d); @@ -510,6 +539,11 @@ class InferenceContext { std::vector<ShapeHandle> input_tensors_as_shapes_; std::vector<bool> requested_input_tensor_as_partial_shape_; + std::vector<ShapeHandle> input_handle_shape_; + std::vector<DataType> input_handle_dtype_; + std::vector<ShapeHandle> output_handle_shape_; + std::vector<DataType> output_handle_dtype_; + const NodeDef& node_def_; NameRangeMap input_name_map_; NameRangeMap output_name_map_; diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 06096bfdcc..8d6b4ac021 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -71,7 +71,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) { .Attr("N", 3) .Input(FakeInput(DT_FLOAT)) .Finalize(&def); - InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {}); + InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {}, + {}, {}); EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1)))); @@ -107,7 +108,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) { TEST_F(ShapeInferenceTest, DimensionOrConstant) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {}); EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(InferenceContext::kUnknownDim)); EXPECT_EQ(1, c.Value(1)); @@ -122,7 +123,7 @@ TEST_F(ShapeInferenceTest, Run) { NodeDef def; def.set_name("foo"); def.set_op("foo_op"); - InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}); + InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {}); { auto fn = [](InferenceContext* c) { @@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, Run) { TEST_F(ShapeInferenceTest, RankAndDimInspection) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})}, - {}, {}); + {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(2, c.num_outputs()); @@ -195,7 +196,8 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) { TEST_F(ShapeInferenceTest, NumElements) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), - {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}); + {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}, + {}); EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0)))); EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1)))); @@ -208,7 +210,8 @@ TEST_F(ShapeInferenceTest, NumElements) { TEST_F(ShapeInferenceTest, WithRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}); + InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, + {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -246,7 +249,8 @@ TEST_F(ShapeInferenceTest, WithRank) { TEST_F(ShapeInferenceTest, WithRankAtMost) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}); + InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, + {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -284,7 +288,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) { TEST_F(ShapeInferenceTest, WithRankAtLeast) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}); + InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, + {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -322,7 +327,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) { TEST_F(ShapeInferenceTest, WithValue) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {}); auto d0 = c.Dim(c.input(0), 0); auto d1 = c.Dim(c.input(0), 1); @@ -363,7 +368,8 @@ TEST_F(ShapeInferenceTest, WithValue) { TEST_F(ShapeInferenceTest, MergeDim) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {}, + {}); auto d2 = c.Dim(c.input(0), 0); auto d_unknown = c.Dim(c.input(0), 1); @@ -412,7 +418,7 @@ TEST_F(ShapeInferenceTest, MergeShape) { InferenceContext c(&def, MakeOpDef(7, 2), {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), Unknown(), S({1})}, - {}, {}); + {}, {}, {}, {}); auto s_unknown = c.input(0); auto s_1_2 = c.input(1); @@ -483,7 +489,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) { { Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}), }, - {}, {}); + {}, {}, {}, {}); auto s_unknown = c.input(0); auto s_u_2 = c.input(1); @@ -536,7 +542,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) { TEST_F(ShapeInferenceTest, Subshape) { NodeDef def; InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()}, - {}, {}); + {}, {}, {}, {}); ShapeHandle unknown = c.input(1); ShapeHandle out; @@ -611,7 +617,7 @@ TEST_F(ShapeInferenceTest, Subshape) { TEST_F(ShapeInferenceTest, Concatenate) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), - {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}); + {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -637,7 +643,8 @@ TEST_F(ShapeInferenceTest, Concatenate) { TEST_F(ShapeInferenceTest, ReplaceDim) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {}); + InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {}, + {}, {}); auto in = c.input(0); auto unknown = c.input(1); @@ -668,7 +675,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) { TEST_F(ShapeInferenceTest, MakeShape) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {}, + {}); std::vector<DimensionHandle> dims; auto in0 = c.input(0); @@ -693,7 +701,7 @@ TEST_F(ShapeInferenceTest, MakeShape) { TEST_F(ShapeInferenceTest, UnknownShape) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto u0 = c.UnknownShape(); auto u1 = c.UnknownShape(); @@ -705,7 +713,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) { TEST_F(ShapeInferenceTest, Scalar) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Scalar(); EXPECT_EQ("[]", c.DebugString(s0)); @@ -716,7 +724,7 @@ TEST_F(ShapeInferenceTest, Scalar) { TEST_F(ShapeInferenceTest, Vector) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Vector(1); EXPECT_EQ("[1]", c.DebugString(s0)); @@ -732,7 +740,7 @@ TEST_F(ShapeInferenceTest, Vector) { TEST_F(ShapeInferenceTest, Matrix) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Matrix(1, 2); EXPECT_EQ("[1,2]", c.DebugString(s0)); @@ -754,7 +762,7 @@ TEST_F(ShapeInferenceTest, Matrix) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { auto create = [&](Tensor* t) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}); + InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {}); ShapeHandle out; Status s = c.MakeShapeFromShapeTensor(0, &out); if (s.ok()) { @@ -806,7 +814,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { // Test when the input shape is wrong. { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}); + InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {}, + {}); ShapeHandle out; EXPECT_EQ("Shape must be rank 1 but is rank 2", c.MakeShapeFromShapeTensor(0, &out).error_message()); @@ -816,7 +825,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); TensorShapeProto proto; // With a set unknown rank. @@ -852,7 +861,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { TEST_F(ShapeInferenceTest, MakeDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto d0 = c.MakeDim(1); auto d1 = c.MakeDim(1); @@ -866,7 +875,7 @@ TEST_F(ShapeInferenceTest, MakeDim) { TEST_F(ShapeInferenceTest, UnknownDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto d0 = c.UnknownDim(); auto d1 = c.UnknownDim(); @@ -878,7 +887,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) { TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); @@ -892,7 +901,7 @@ TEST_F(ShapeInferenceTest, InputTensors) { const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, - {&t1, &t2}, {}); + {&t1, &t2}, {}, {}, {}); EXPECT_TRUE(c.input_tensor(0) == &t1); EXPECT_TRUE(c.input_tensor(1) == &t2); @@ -903,7 +912,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { Tensor t1 = tensorflow::test::AsScalar<int32>(20); Tensor t2 = tensorflow::test::AsScalar<int32>(-1); NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}); + InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {}, + {}); DimensionHandle d; EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); @@ -934,7 +944,7 @@ TEST_F(ShapeInferenceTest, GetAttr) { .ok()); std::vector<ShapeHandle> empty; - InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}); + InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {}); string value; EXPECT_TRUE(c.GetAttr("foo", &value).ok()); EXPECT_EQ("bar", value); @@ -942,7 +952,8 @@ TEST_F(ShapeInferenceTest, GetAttr) { TEST_F(ShapeInferenceTest, Divide) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {}, + {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1004,7 +1015,7 @@ TEST_F(ShapeInferenceTest, Divide) { TEST_F(ShapeInferenceTest, Add) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1055,7 +1066,7 @@ TEST_F(ShapeInferenceTest, Add) { TEST_F(ShapeInferenceTest, Subtract) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1104,7 +1115,7 @@ TEST_F(ShapeInferenceTest, Subtract) { TEST_F(ShapeInferenceTest, Multiply) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1157,7 +1168,7 @@ TEST_F(ShapeInferenceTest, Multiply) { TEST_F(ShapeInferenceTest, FullyDefined) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}); + InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); // No rank or missing dimension information should return false. EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); @@ -1170,7 +1181,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) { TEST_F(ShapeInferenceTest, Min) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {}); auto s = c.input(0); auto d_1 = c.Dim(s, 0); @@ -1218,7 +1229,7 @@ TEST_F(ShapeInferenceTest, Min) { TEST_F(ShapeInferenceTest, Max) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {}); auto s = c.input(0); auto d_1 = c.Dim(s, 0); @@ -1256,7 +1267,7 @@ TEST_F(ShapeInferenceTest, Max) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()}, - {}, {}); + {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1269,7 +1280,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {}, - {}); + {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1281,8 +1292,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, - {}); + InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {}, + {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1295,8 +1306,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, - {}); + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {}, + {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1309,8 +1320,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, - {}); + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {}, + {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1324,7 +1335,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {}, - {}); + {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1337,7 +1348,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {}, - {}); + {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1350,7 +1361,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {}, - {}); + {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1363,7 +1374,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) { TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {}, - {}); + {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -1375,8 +1386,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) { TEST_F(ShapeInferenceTest, ValidateSparseTensor) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, - {}); + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {}, + {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index ed1d3ec520..a225824f82 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -44,8 +44,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, - in_shapes, op.input_tensors, - {} /* input_tensors_as_shapes */); + in_shapes, op.input_tensors, {}, {}, {}); TF_RETURN_IF_ERROR(c.construction_status()); if (op_reg_data->shape_inference_fn == nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 92915685f5..47d74d4def 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -435,6 +435,7 @@ class Tensor { friend class VariableOp; // For access to set_shape friend class AutoReloadVariableOp; // For access to set_shape friend class TensorTestHelper; // For access to set_shape + friend class CreateVariableOp; // Creates a tensor with the input datatype, shape and buf. // diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 34954f0066..1954ebdc10 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1031,6 +1031,18 @@ tf_kernel_library( ) tf_kernel_library( + name = "resource_variable_ops", + srcs = ["resource_variable_ops.cc"], + deps = [ + ":variable_ops", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:resource_variable_ops_op_lib", + "//third_party/eigen3", + ], +) + +tf_kernel_library( name = "fact_op", prefix = "fact_op", deps = [ diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc new file mode 100644 index 0000000000..fbe66e8386 --- /dev/null +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -0,0 +1,49 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +REGISTER_RESOURCE_HANDLE_KERNEL(Var); + +class CreateVariableOp : public OpKernel { + public: + CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); + } + + void Compute(OpKernelContext* c) override { + Var* var = new Var(dtype_); + var->Ref(); + core::ScopedUnref ur(var); + OP_REQUIRES_OK(c, CreateResource<Var>(c, HandleFromInput(c, 0), var)); + // TODO(apassos): this currently does not initialize the tensor, so it's + // pointless, other than checking construction in tests. Fix this. + } + + private: + DataType dtype_; +}; +REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU), + CreateVariableOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index f754629a72..f44f94c51b 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -26,6 +26,26 @@ limitations under the License. namespace tensorflow { +// Resource stored by variables in the resource manager. +class Var : public ResourceBase { + public: + explicit Var(DataType dtype) : tensor_(dtype) {} + mutex* mu() { return &mu_; } + Tensor* tensor() { return &tensor_; } + + string DebugString() override { + return strings::StrCat(DataTypeString(tensor_.dtype()), "/", + tensor_.shape().DebugString()); + } + + private: + mutex mu_; + Tensor tensor_; + + ~Var() override {} + TF_DISALLOW_COPY_AND_ASSIGN(Var); +}; + class VariableOp : public OpKernel { public: explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) { @@ -59,25 +79,6 @@ class VariableOp : public OpKernel { } private: - class Var : public ResourceBase { - public: - explicit Var(DataType dtype) : tensor_(dtype) {} - mutex* mu() { return &mu_; } - Tensor* tensor() { return &tensor_; } - - string DebugString() override { - return strings::StrCat(DataTypeString(tensor_.dtype()), "/", - tensor_.shape().DebugString()); - } - - private: - mutex mu_; - Tensor tensor_; - - ~Var() override {} - TF_DISALLOW_COPY_AND_ASSIGN(Var); - }; - DataType dtype_; TensorShape shape_; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 6e076a092e..58b95fcff1 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1252,7 +1252,12 @@ REGISTER_OP("Identity") .Input("input: T") .Output("output: T") .Attr("T: type") - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output_handle_dtype(0, c->input_handle_dtype(0)); + c->set_output_handle_shape(0, c->input_handle_shape(0)); + return Status::OK(); + }) .Doc(R"Doc( Return a tensor with the same shape and contents as the input tensor or value. )Doc"); diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 8679739b70..90f3c60b64 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -153,6 +155,21 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) { INFER_OK(op, "[1,2,?,4,5];?;?", "in0"); } +TEST(ArrayOpsTest, Identity_ShapeFnHandles) { + const char* op_name = "Identity"; + ShapeInferenceTestOp op(op_name); + // Check that handle dtypes are preserved. + const OpRegistrationData* op_reg_data; + TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); + shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, + {TensorShapeProto()}, {}, {}, {}, + {DT_BOOL}); + TF_ASSERT_OK(c.construction_status()); + ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); + TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn)); + EXPECT_TRUE(c.output_handle_dtype(0) == DT_BOOL); +} + TEST(ArrayOpsTest, Diag_ShapeFn) { ShapeInferenceTestOp op("Diag"); INFER_OK(op, "?", "?"); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 8d3d9310a4..ff00214da3 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -911,6 +911,19 @@ REGISTER_OP("Select") .Output("output: T") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { + // Merge handle shape and dtype if applicable. + if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) { + // TODO(apassos) resolve this in the manner of b/32476923 + return errors::InvalidArgument( + "Trying to merge handles pointing to different dtypes."); + } + c->set_output_handle_dtype(0, c->input_handle_dtype(1)); + ShapeHandle output_handle_shape; + TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1), + c->input_handle_shape(2), + &output_handle_shape)); + c->set_output_handle_shape(0, output_handle_shape); + // The inputs 'then' and 'else' must have the same shape. ShapeHandle data = c->input(1); ShapeHandle other = c->input(2); @@ -961,8 +974,9 @@ REGISTER_OP("Select") } c->set_output(0, data); + return Status::OK(); - }) + }) .Doc(R"doc( Selects elements from `t` or `e`, depending on `condition`. diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 79ae187342..6a1cc8e7eb 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -216,6 +216,37 @@ TEST(MathOpsTest, Select_ShapeFn) { INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]"); INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op, "[2,?,5];[?,?,3];[?,2,?]"); + + // Test that handle shapes were merged. + const OpRegistrationData* op_reg_data; + TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); + TensorShapeProto i0; + i0.add_dim()->set_size(1); + i0.add_dim()->set_size(-1); + TensorShapeProto i1; + i1.add_dim()->set_size(-1); + i1.add_dim()->set_size(2); + + ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); + shape_inference::InferenceContext c( + &op.node_def, op_reg_data->op_def, + {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, + {TensorShapeProto(), i0, i1}, {}); + TF_ASSERT_OK(c.construction_status()); + TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn)); + EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0))); + EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0))); + + // Expect an error when the shapes can't be merged. + TensorShapeProto i2; + i1.add_dim()->set_size(2); + i1.add_dim()->set_size(2); + shape_inference::InferenceContext c2( + &op.node_def, op_reg_data->op_def, + {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, + {TensorShapeProto(), i0, i2}, {}); + TF_ASSERT_OK(c.construction_status()); + EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok()); } TEST(MathOpsTest, Range_ShapeFn) { diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc new file mode 100644 index 0000000000..6211b07ac5 --- /dev/null +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -0,0 +1,80 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("VarHandleOp") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("dtype: type") + .Attr("shape: shape") + .Output("resource: resource") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + DataType t; + c->GetAttr("dtype", &t); + c->set_output_handle_dtype(0, t); + TensorShapeProto p; + c->GetAttr("shape", &p); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s)); + c->set_output_handle_shape(0, s); + return Status::OK(); + }) + .Doc(R"( +Creates a handle to a Variable resource. + +container: the container this variable is placed in. +shared_name: the name by which this variable is referred to. +dtype: the type of this variable. Must agree with the dtypes + of all ops using this variable. +shape: The (possibly partially specified) shape of this variable. +)"); + +REGISTER_OP("CreateVariableOp") + .Input("resource: resource") + .Input("value: dtype") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType handle_dtype = c->input_handle_dtype(0); + DataType value_dtype; + c->GetAttr("dtype", &value_dtype); + if (handle_dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to initialize handle for variable with wrong dtype. " + "Expected ", + handle_dtype, " got ", value_dtype); + } + shape_inference::ShapeHandle s = c->input_handle_shape(0); + shape_inference::ShapeHandle value_shape = c->input(1); + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); + return Status::OK(); + }) + .Doc(R"( +Creates a variable resource. + +resource: handle to the resource in which to store the variable. +value: the value to set the new tensor to use. +dtype: the dtype of the value. +)"); + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index dcb6f52605..da875c081a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -218,23 +218,6 @@ cc_library( ) cc_library( - name = "cpp_shape_inference", - srcs = ["framework/cpp_shape_inference.cc"], - hdrs = ["framework/cpp_shape_inference.h"], - copts = ["-Wno-sign-compare"], - visibility = ["//visibility:public"], - deps = [ - ":numpy_lib", - ":py_func_lib", - "//tensorflow/c:tf_status_helper", - "//tensorflow/core:framework", - "//tensorflow/core:protos_cc", - "//third_party/py/numpy:headers", - "//util/python:python_headers", - ], -) - -cc_library( name = "python_op_gen_main", srcs = [ "framework/python_op_gen_main.cc", @@ -284,6 +267,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":cpp_shape_inference_proto_py", ":framework_for_generated_wrappers", ":pywrap_tensorflow", ], @@ -661,6 +645,11 @@ tf_gen_op_wrapper_private_py( ) tf_gen_op_wrapper_private_py( + name = "resource_variable_ops_gen", + require_shape_functions = True, +) + +tf_gen_op_wrapper_private_py( name = "script_ops_gen", require_shape_functions = True, ) @@ -990,6 +979,16 @@ py_library( ) py_library( + name = "resource_variable_ops", + srcs = ["ops/resource_variable_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":resource_variable_ops_gen", + ], +) + +py_library( name = "nn", srcs = ["ops/nn.py"], srcs_version = "PY2AND3", @@ -1409,6 +1408,7 @@ py_library( ":partitioned_variables", ":random_ops", ":random_ops_gen", + ":resource_variable_ops", ":resources", ":rnn", ":rnn_cell", @@ -1690,6 +1690,7 @@ tf_proto_library( ["**/*.proto"], exclude = [ "util/protobuf/compare_test.proto", + "framework/cpp_shape_inference.proto", ], ), go_api_version = 2, @@ -1701,6 +1702,13 @@ tf_proto_library_py( srcs = ["util/protobuf/compare_test.proto"], ) +tf_proto_library( + name = "cpp_shape_inference_proto", + srcs = ["framework/cpp_shape_inference.proto"], + cc_api_version = 2, + cc_libs = ["//tensorflow/core:protos_all_cc"], +) + py_test( name = "protobuf_compare_test", size = "small", @@ -1767,6 +1775,24 @@ py_library( ], ) +cc_library( + name = "cpp_shape_inference", + srcs = ["framework/cpp_shape_inference.cc"], + hdrs = ["framework/cpp_shape_inference.h"], + copts = ["-Wno-sign-compare"], + visibility = ["//visibility:public"], + deps = [ + ":cpp_shape_inference_proto_cc", + ":numpy_lib", + ":py_func_lib", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:protos_cc", + "//third_party/py/numpy:headers", + "//util/python:python_headers", + ], +) + cuda_py_tests( name = "device_lib_test", size = "small", diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index a0c97f5f8f..09afe56b19 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np import six.moves -from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -567,8 +567,12 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, the C++ shape function. Returns: - A TensorShape list of the output shapes of the op, as computed using the - C++ shape inference function registered for the op. + A dictionary with the following keys: + shapes: A TensorShape list of the output shapes of the op, as computed + using the C++ shape inference function registered for the op. + handle_shapes: A TensorShape list of the shapes for handle outputs, if + any. + handle_dtypes: A list of DataType enums for the handle outputs, if any. Raises: ValueError: If the C++ shape function returned an error (e.g. because the @@ -576,8 +580,16 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, according to the shape function). """ node_def_str = op.node_def.SerializeToString() - input_shapes = [i.get_shape().as_proto().SerializeToString() for i in - op.inputs] + + def tensor_to_inference_result(t): + r = cpp_shape_inference_pb2.CppShapeInferenceResult() + r.shape.CopyFrom(t.get_shape().as_proto()) + # pylint: disable=protected-access + r.handle_shape.CopyFrom(t._handle_shape) + r.handle_dtype = t._handle_dtype + # pylint: enable=protected-access + return r.SerializeToString() + input_shapes = [tensor_to_inference_result(i) for i in op.inputs] input_tensors = [None for i in input_shapes] if input_tensors_needed: @@ -596,10 +608,13 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, raise ValueError(err.message) # Convert TensorShapeProto values in output_shapes. - result = [ - tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s)) + result_protos = [ + cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s) for s in output_shapes ] + result = [r.shape for r in result_protos] + result_handle_shapes = [r.handle_shape for r in result_protos] + result_handle_dtypes = [r.handle_dtype for r in result_protos] if debug_python_shape_fn: try: @@ -616,4 +631,6 @@ def call_cpp_shape_fn(op, input_tensors_needed=None, str(result), str(python_result), str(op.node_def), ",".join([str(i.get_shape()) for i in op.inputs]))) - return result + return {"shapes": result, + "handle_shapes": result_handle_shapes, + "handle_dtypes": result_handle_dtypes} diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index bb5a57e617..57b85e8118 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -20,12 +20,32 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" #include "tensorflow/python/lib/core/py_func.h" namespace tensorflow { namespace swig { namespace { +void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s, + tensorflow::shape_inference::InferenceContext* c, + TensorShapeProto* out) { + if (c->RankKnown(s)) { + const int32 rank = c->Rank(s); + for (int i = 0; i < rank; ++i) { + shape_inference::DimensionHandle d = c->Dim(s, i); + auto* out_dim = out->add_dim(); + if (c->ValueKnown(d)) { + out_dim->set_size(c->Value(d)); + } else { + out_dim->set_size(-1); + } + } + } else { + out->set_unknown_rank(true); + } +} + Status RunCppShapeInferenceImpl( const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, @@ -49,12 +69,21 @@ Status RunCppShapeInferenceImpl( // Convert input shapes. std::vector<TensorShapeProto> input_shapes; + std::vector<TensorShapeProto> input_handle_shapes; + std::vector<DataType> input_handle_dtypes; input_shapes.resize(input_serialized_shapes.size()); + input_handle_shapes.resize(input_serialized_shapes.size()); + input_handle_dtypes.resize(input_serialized_shapes.size()); + CppShapeInferenceResult tmp; for (int i = 0; i < input_serialized_shapes.size(); ++i) { - if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) { + tmp.Clear(); + if (!tmp.ParseFromString(input_serialized_shapes[i])) { return errors::InvalidArgument( "Error parsing shape proto during cpp shape inference"); } + input_shapes[i].Swap(tmp.mutable_shape()); + input_handle_dtypes[i] = tmp.handle_dtype(); + input_handle_shapes[i].Swap(tmp.mutable_handle_shape()); } // Convert input tensor values; @@ -73,34 +102,23 @@ Status RunCppShapeInferenceImpl( } // Run shape inference. - // TODO(cwhipkey): pass a value for input_tensors_as_shapes. tensorflow::shape_inference::InferenceContext c( &node, op_reg_data->op_def, input_shapes, input_tensors, - {} /* input_tensors_as_shapes */); + {} /* input_tensors_as_shapes */, input_handle_shapes, + input_handle_dtypes); TF_RETURN_IF_ERROR(c.construction_status()); TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); // Convert output shapes. output_tensor_shape_protos->resize(c.num_outputs()); - TensorShapeProto out; + CppShapeInferenceResult out; for (int i = 0; i < c.num_outputs(); ++i) { - shape_inference::ShapeHandle s = c.output(i); out.Clear(); - if (c.RankKnown(s)) { - const int32 rank = c.Rank(s); - for (int i = 0; i < rank; ++i) { - shape_inference::DimensionHandle d = c.Dim(s, i); - auto* out_dim = out.add_dim(); - if (c.ValueKnown(d)) { - out_dim->set_size(c.Value(d)); - } else { - out_dim->set_size(-1); - } - } - } else { - out.set_unknown_rank(true); - } + ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape()); + ProtoFromShapeHandle(c.output_handle_shape(i), &c, + out.mutable_handle_shape()); + out.set_handle_dtype(c.output_handle_dtype(i)); CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i])); } diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h index a2f52227d6..f91af8e1a8 100644 --- a/tensorflow/python/framework/cpp_shape_inference.h +++ b/tensorflow/python/framework/cpp_shape_inference.h @@ -36,7 +36,7 @@ namespace swig { // inference was successful. // // On success, <*output_shapes> is populated with the inferred output shapes (as -// serialized TensorShapeProtos). +// serialized CppShapeInferenceResult protos). // <*output_shapes> must be empty when this function is called. // // This is temporary code to be used during the migration diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto new file mode 100644 index 0000000000..c7d59cc55e --- /dev/null +++ b/tensorflow/python/framework/cpp_shape_inference.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; + +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; + +message CppShapeInferenceResult { + TensorShapeProto shape = 1; + TensorShapeProto handle_shape = 2; + DataType handle_dtype = 3; +}
\ No newline at end of file diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 6d80da1300..be9de08ba3 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -32,6 +32,8 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes @@ -297,6 +299,10 @@ class Tensor(object): # to easily navigate a computation graph. self._consumers = [] + # Attributes used for C++ shape inference. Not inspected, only forwarded. + self._handle_shape = tensor_shape_pb2.TensorShapeProto() + self._handle_dtype = types_pb2.DT_INVALID + @property def op(self): """The `Operation` that produces this tensor as an output.""" @@ -1791,10 +1797,22 @@ def set_shapes_for_outputs(op): if shapes is None: raise RuntimeError( "Shape function for op %s did not return any shapes" % op) + elif isinstance(shapes, dict): + # Returned by call_cpp_shape_fn + shapes_dict = shapes + shapes = shapes_dict["shapes"] + handle_shapes = shapes_dict["handle_shapes"] + handle_dtypes = shapes_dict["handle_dtypes"] + for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes): + # pylint: disable=protected-access + output._handle_shape = handle_shape + output._handle_dtype = handle_dtype + # pylint: enable=protected-access + if len(op.outputs) != len(shapes): raise RuntimeError( - "Shape function for op %s returned %d shapes but expected %d" % - (op, len(shapes), len(op.outputs))) + "Shape function for op %s returned %d shapes but expected %d %s %s" % + (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes))) for output, s in zip(op.outputs, shapes): output.set_shape(s) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index baa48ec8a4..fe74b3426c 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -256,6 +256,16 @@ tf_py_test( ) tf_py_test( + name = "resource_variable_ops_test", + size = "small", + srcs = ["resource_variable_ops_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:resource_variable_ops", + ], +) + +tf_py_test( name = "save_restore_ops_test", size = "small", srcs = ["save_restore_ops_test.py"], diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py new file mode 100644 index 0000000000..cb4375ce91 --- /dev/null +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -0,0 +1,51 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.resource_variable_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class ResourceVariableOpsTest(test_util.TensorFlowTestCase): + + def testHandleDtypeShapeMatch(self): + with self.test_session(): + handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) + with self.assertRaises(ValueError): + resource_variable_ops.create_variable_op( + handle, constant_op.constant(0.0, dtype=dtypes.float32)).run() + with self.assertRaises(ValueError): + resource_variable_ops.create_variable_op( + handle, constant_op.constant([0], dtype=dtypes.int32)).run() + resource_variable_ops.create_variable_op( + handle, constant_op.constant(0, dtype=dtypes.int32)).run() + + def testDtypeSurvivesIdentity(self): + with self.test_session(): + handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) + id_handle = array_ops.identity(handle) + resource_variable_ops.create_variable_op( + id_handle, constant_op.constant(0, dtype=dtypes.int32)).run() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py new file mode 100644 index 0000000000..7db9731e19 --- /dev/null +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -0,0 +1,30 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops to use variables as resources.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_resource_variable_ops import * +# pylint: enable=wildcard-import + +ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn) |