aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 13:11:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 14:19:03 -0700
commit41734d78d3facf652c25b2a2761aadd978b3f2ef (patch)
tree931685da30dfcb84077664dede0b33d93c609397
parent1dd44f3ecc38cdb3df95cb3599f7b843f6d5062a (diff)
Automated rollback of change 137740850
Change: 137747341
-rw-r--r--tensorflow/contrib/cmake/tf_python.cmake40
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc17
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc58
-rw-r--r--tensorflow/core/framework/node_def_util.h8
-rw-r--r--tensorflow/core/framework/resource_mgr.h1
-rw-r--r--tensorflow/core/framework/shape_inference.cc59
-rw-r--r--tensorflow/core/framework/shape_inference.h48
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc111
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.cc3
-rw-r--r--tensorflow/core/framework/tensor.h1
-rw-r--r--tensorflow/core/kernels/BUILD12
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc49
-rw-r--r--tensorflow/core/kernels/variable_ops.h39
-rw-r--r--tensorflow/core/ops/array_ops.cc7
-rw-r--r--tensorflow/core/ops/array_ops_test.cc17
-rw-r--r--tensorflow/core/ops/math_ops.cc16
-rw-r--r--tensorflow/core/ops/math_ops_test.cc31
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc80
-rw-r--r--tensorflow/python/BUILD60
-rw-r--r--tensorflow/python/framework/common_shapes.py33
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.cc56
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h2
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.proto13
-rw-r--r--tensorflow/python/framework/ops.py22
-rw-r--r--tensorflow/python/kernel_tests/BUILD10
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py51
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py30
28 files changed, 714 insertions, 162 deletions
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 8cdecf706a..907158b646 100644
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -94,6 +94,38 @@ function(RELATIVE_PROTOBUF_GENERATE_PYTHON ROOT_DIR SRCS)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
+function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
+ if(NOT ARGN)
+ message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files")
+ return()
+ endif()
+
+ set(${SRCS})
+ set(${HDRS})
+ foreach(FIL ${ARGN})
+ set(ABS_FIL ${ROOT_DIR}/${FIL})
+ get_filename_component(FIL_WE ${FIL} NAME_WE)
+ get_filename_component(FIL_DIR ${ABS_FIL} PATH)
+ file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
+
+ list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
+ list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
+
+ add_custom_command(
+ OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
+ COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
+ ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
+ DEPENDS ${ABS_FIL} protobuf
+ COMMENT "Running C++ protocol buffer compiler on ${FIL}"
+ VERBATIM )
+ endforeach()
+
+ set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
+ set(${SRCS} ${${SRCS}} PARENT_SCOPE)
+ set(${HDRS} ${${HDRS}} PARENT_SCOPE)
+endfunction()
+
file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
"${tensorflow_source_dir}/tensorflow/python/*.proto"
@@ -102,6 +134,12 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs}
)
+RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
+ ${tensorflow_source_dir} ${tf_protos_python_srcs}
+)
+
+add_library(tf_python_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
+
# tf_python_touchup_modules adds empty __init__.py files to all
# directories containing Python code, so that Python will recognize
# them as modules.
@@ -201,6 +239,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
)
target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE
tf_protos_cc
+ tf_python_protos_cc
${tensorflow_EXTERNAL_LIBRARIES}
)
@@ -312,6 +351,7 @@ target_link_libraries(pywrap_tensorflow
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
+ tf_python_protos_cc
${PYTHON_LIBRARIES}
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1c37921afc..79546ccd20 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -379,6 +379,7 @@ tf_gen_op_libs(
"no_op",
"parsing_ops",
"random_ops",
+ "resource_variable_ops",
"sdca_ops",
"script_ops",
"sendrecv_ops",
@@ -542,6 +543,7 @@ cc_library(
"//tensorflow/core/kernels:parsing",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:required",
+ "//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 1ddd483076..e1500ed1ad 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -42,6 +42,8 @@ Status ShapeRefiner::AddNode(const Node* node) {
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
+ std::vector<DataType> input_handle_dtypes(node->num_inputs());
+ std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
@@ -57,6 +59,15 @@ Status ShapeRefiner::AddNode(const Node* node) {
DCHECK_GE(e->dst_input(), 0);
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
+
+ // Only propagate handle xshape and dtype of edges which are carrying
+ // resource handles.
+ if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
+ input_handle_dtypes[e->dst_input()] =
+ c->output_handle_dtype(e->src_output());
+ input_handle_shapes[e->dst_input()] =
+ c->output_handle_shape(e->src_output());
+ }
}
// Get the shape function for this node
@@ -76,9 +87,9 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
- std::unique_ptr<InferenceContext> c(
- new InferenceContext(&node->def(), node->op_def(), input_shapes,
- input_tensors, input_tensors_as_shapes));
+ std::unique_ptr<InferenceContext> c(new InferenceContext(
+ &node->def(), node->op_def(), input_shapes, input_tensors,
+ input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
if (!c->construction_status().ok()) {
return c->construction_status();
}
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 7196bc8304..ca1326844c 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
+ InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({})}, {}, {});
+ InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
- InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
+ InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
- InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
- InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -156,7 +156,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {},
+ {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -174,7 +175,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
+ InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -191,7 +192,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -215,7 +216,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -224,7 +225,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
- InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
+ InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@@ -232,7 +233,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
+ InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {},
+ {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@@ -245,7 +247,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@@ -258,8 +260,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
- {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {},
+ {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@@ -272,7 +274,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {},
+ {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@@ -280,7 +283,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@@ -292,7 +295,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -311,7 +314,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -319,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
+ InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -331,7 +334,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -343,7 +346,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {},
+ {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -355,7 +359,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -363,7 +367,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@@ -374,7 +378,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 2d16760625..85b83c4d74 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -66,6 +66,14 @@ void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
AttrValueMap::value_type(name.ToString(), attr_value));
}
+// Adds an attr to an attr value map.
+template <class T>
+void AddAttr(StringPiece name, T&& value, AttrValueMap* map) {
+ AttrValue attr_value;
+ SetAttrValue(value, &attr_value);
+ map->insert(AttrValueMap::value_type(name.ToString(), attr_value));
+}
+
class AttrSlice {
public:
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index f823462079..30bb2cb708 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <typeinfo>
#include <unordered_map>
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index da88b6a7ca..4aa32f6a84 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -31,7 +31,9 @@ InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<TensorShapeProto>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
@@ -43,19 +45,30 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
- PostInputInit();
+ std::vector<ShapeHandle> handle_shapes;
+ for (const auto& p : input_handle_shapes) {
+ ShapeHandle shape;
+ construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
+ if (!construction_status_.ok()) {
+ return;
+ }
+ handle_shapes.push_back(shape);
+ }
+ PostInputInit(handle_shapes, input_handle_dtypes);
}
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<ShapeHandle>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
- PostInputInit();
+ PostInputInit(input_handle_shapes, input_handle_dtypes);
}
InferenceContext::~InferenceContext() {
@@ -124,15 +137,44 @@ void InferenceContext::PreInputInit(
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(nullptr);
}
+ output_handle_shape_.reserve(num_outputs);
+ for (int i = 0; i < num_outputs; ++i) {
+ output_handle_shape_.push_back(UnknownShape());
+ }
+ output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
}
-void InferenceContext::PostInputInit() {
+void InferenceContext::PostInputInit(
+ const std::vector<ShapeHandle>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes) {
int num_inputs_from_node_def = 0;
for (const auto& e : input_name_map_) {
num_inputs_from_node_def =
std::max(num_inputs_from_node_def, e.second.second);
}
+ // Allow passing empty shapes/dtypes to avoid changing every single test.
+ if (input_handle_shapes.empty()) {
+ input_handle_shape_.resize(inputs_.size());
+ } else {
+ input_handle_shape_ = input_handle_shapes;
+ if (input_handle_shape_.size() != inputs_.size()) {
+ construction_status_ = errors::InvalidArgument(
+ "Wrong number of handle shapes passed; expected ", inputs_.size(),
+ " got ", input_handle_shape_.size());
+ }
+ }
+ if (input_handle_dtypes.empty()) {
+ input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
+ } else {
+ input_handle_dtype_ = input_handle_dtypes;
+ if (input_handle_dtype_.size() != inputs_.size()) {
+ construction_status_ = errors::InvalidArgument(
+ "Wrong number of handle dtypes passed; expected ", inputs_.size(),
+ " got ", input_handle_dtype_.size());
+ }
+ }
+
if (inputs_.size() != num_inputs_from_node_def) {
construction_status_ = errors::InvalidArgument(
"Wrong number of inputs passed: ", inputs_.size(), " while ",
@@ -737,6 +779,13 @@ Status InferenceContext::AttachContext(const Status& status) {
strings::StrCat(status.error_message(), error_context));
}
+ShapeHandle InferenceContext::input_handle_shape(int idx) {
+ if (!input_handle_shape_[idx].IsSet()) {
+ input_handle_shape_[idx] = UnknownShape();
+ }
+ return input_handle_shape_[idx];
+}
+
// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index f5befc15a1..e02490efd9 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -147,7 +147,9 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes);
+ const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<ShapeHandle>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
@@ -162,7 +164,9 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes);
+ const std::vector<ShapeHandle>& input_tensors_as_shapes,
+ const std::vector<TensorShapeProto>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes);
~InferenceContext();
@@ -231,12 +235,12 @@ class InferenceContext {
}
return s->dims_[idx];
}
- int32 Rank(ShapeHandle s) { return s->rank_; }
- bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; }
- inline int64 Value(DimensionOrConstant d) {
+ int32 Rank(ShapeHandle s) const { return s->rank_; }
+ bool RankKnown(ShapeHandle s) const { return Rank(s) != kUnknownRank; }
+ inline int64 Value(DimensionOrConstant d) const {
return d.dim.IsSet() ? d.dim->value_ : d.val;
}
- inline bool ValueKnown(DimensionOrConstant d) {
+ inline bool ValueKnown(DimensionOrConstant d) const {
return Value(d) != kUnknownDim;
}
@@ -391,6 +395,30 @@ class InferenceContext {
Status construction_status() const { return construction_status_; }
+ // Methods to propagate shape and dtype on edges of handles. Handles are the
+ // dtype DT_RESOURCE which can be used to access state stored in a
+ // ResourceManager. When ops (such as variables) consume these handles to
+ // produce tensors they might need to know side-information about the shapes
+ // and dtypes of tensors which can be accessed via the handle. These methods
+ // propagate that information. Output handle dtypes and shapes are ignored if
+ // the output tensor is not of type DT_RESOURCE.
+ ShapeHandle input_handle_shape(int idx);
+ DataType input_handle_dtype(int idx) const {
+ return input_handle_dtype_[idx];
+ }
+ void set_output_handle_shape(int idx, ShapeHandle shape) {
+ output_handle_shape_[idx] = shape;
+ }
+ void set_output_handle_dtype(int idx, DataType dtype) {
+ output_handle_dtype_[idx] = dtype;
+ }
+ ShapeHandle output_handle_shape(int idx) const {
+ return output_handle_shape_[idx];
+ }
+ DataType output_handle_dtype(int idx) const {
+ return output_handle_dtype_[idx];
+ }
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(ShapeHandle indices_shape,
@@ -481,7 +509,8 @@ class InferenceContext {
void PreInputInit(const OpDef& op_def,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
- void PostInputInit();
+ void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
+ const std::vector<DataType>& input_handle_dtypes);
DimensionHandle GetDimension(const DimensionOrConstant& d);
@@ -510,6 +539,11 @@ class InferenceContext {
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
+ std::vector<ShapeHandle> input_handle_shape_;
+ std::vector<DataType> input_handle_dtype_;
+ std::vector<ShapeHandle> output_handle_shape_;
+ std::vector<DataType> output_handle_dtype_;
+
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 06096bfdcc..8d6b4ac021 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -71,7 +71,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
- InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
+ InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {},
+ {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@@ -107,7 +108,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@@ -122,7 +123,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
- InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {});
{
auto fn = [](InferenceContext* c) {
@@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, Run) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
- {}, {});
+ {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -195,7 +196,8 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
- {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
+ {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
+ {});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
@@ -208,7 +210,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
+ {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -246,7 +249,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
+ {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -284,7 +288,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
+ {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -322,7 +327,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -363,7 +368,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {},
+ {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -412,7 +418,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
InferenceContext c(&def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
- {}, {});
+ {}, {}, {}, {});
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
@@ -483,7 +489,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
- {}, {});
+ {}, {}, {}, {});
auto s_unknown = c.input(0);
auto s_u_2 = c.input(1);
@@ -536,7 +542,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
- {}, {});
+ {}, {}, {}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@@ -611,7 +617,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
- {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
+ {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -637,7 +643,8 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {},
+ {}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@@ -668,7 +675,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {},
+ {});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@@ -693,7 +701,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@@ -705,7 +713,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@@ -716,7 +724,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@@ -732,7 +740,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@@ -754,7 +762,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
+ InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -806,7 +814,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
// Test when the input shape is wrong.
{
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
+ InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {},
+ {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@@ -816,7 +825,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@@ -852,7 +861,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@@ -866,7 +875,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@@ -878,7 +887,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@@ -892,7 +901,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
- {&t1, &t2}, {});
+ {&t1, &t2}, {}, {}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
@@ -903,7 +912,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {},
+ {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@@ -934,7 +944,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
+ InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@@ -942,7 +952,8 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {},
+ {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1004,7 +1015,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1055,7 +1066,7 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1104,7 +1115,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1157,7 +1168,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@@ -1170,7 +1181,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1218,7 +1229,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1256,7 +1267,7 @@ TEST_F(ShapeInferenceTest, Max) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
- {}, {});
+ {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1269,7 +1280,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
- {});
+ {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1281,8 +1292,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1295,8 +1306,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1309,8 +1320,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1324,7 +1335,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
- {});
+ {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1337,7 +1348,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
- {});
+ {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1350,7 +1361,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
- {});
+ {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1363,7 +1374,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
- {});
+ {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1375,8 +1386,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index ed1d3ec520..a225824f82 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -44,8 +44,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
- in_shapes, op.input_tensors,
- {} /* input_tensors_as_shapes */);
+ in_shapes, op.input_tensors, {}, {}, {});
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 92915685f5..47d74d4def 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -435,6 +435,7 @@ class Tensor {
friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape
friend class TensorTestHelper; // For access to set_shape
+ friend class CreateVariableOp;
// Creates a tensor with the input datatype, shape and buf.
//
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 34954f0066..1954ebdc10 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1031,6 +1031,18 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "resource_variable_ops",
+ srcs = ["resource_variable_ops.cc"],
+ deps = [
+ ":variable_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:resource_variable_ops_op_lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
name = "fact_op",
prefix = "fact_op",
deps = [
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
new file mode 100644
index 0000000000..fbe66e8386
--- /dev/null
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -0,0 +1,49 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+
+class CreateVariableOp : public OpKernel {
+ public:
+ CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ Var* var = new Var(dtype_);
+ var->Ref();
+ core::ScopedUnref ur(var);
+ OP_REQUIRES_OK(c, CreateResource<Var>(c, HandleFromInput(c, 0), var));
+ // TODO(apassos): this currently does not initialize the tensor, so it's
+ // pointless, other than checking construction in tests. Fix this.
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU),
+ CreateVariableOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index f754629a72..f44f94c51b 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -26,6 +26,26 @@ limitations under the License.
namespace tensorflow {
+// Resource stored by variables in the resource manager.
+class Var : public ResourceBase {
+ public:
+ explicit Var(DataType dtype) : tensor_(dtype) {}
+ mutex* mu() { return &mu_; }
+ Tensor* tensor() { return &tensor_; }
+
+ string DebugString() override {
+ return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
+ tensor_.shape().DebugString());
+ }
+
+ private:
+ mutex mu_;
+ Tensor tensor_;
+
+ ~Var() override {}
+ TF_DISALLOW_COPY_AND_ASSIGN(Var);
+};
+
class VariableOp : public OpKernel {
public:
explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -59,25 +79,6 @@ class VariableOp : public OpKernel {
}
private:
- class Var : public ResourceBase {
- public:
- explicit Var(DataType dtype) : tensor_(dtype) {}
- mutex* mu() { return &mu_; }
- Tensor* tensor() { return &tensor_; }
-
- string DebugString() override {
- return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
- tensor_.shape().DebugString());
- }
-
- private:
- mutex mu_;
- Tensor tensor_;
-
- ~Var() override {}
- TF_DISALLOW_COPY_AND_ASSIGN(Var);
- };
-
DataType dtype_;
TensorShape shape_;
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 6e076a092e..58b95fcff1 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1252,7 +1252,12 @@ REGISTER_OP("Identity")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
- .SetShapeFn(shape_inference::UnchangedShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ c->set_output_handle_dtype(0, c->input_handle_dtype(0));
+ c->set_output_handle_shape(0, c->input_handle_shape(0));
+ return Status::OK();
+ })
.Doc(R"Doc(
Return a tensor with the same shape and contents as the input tensor or value.
)Doc");
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 8679739b70..90f3c60b64 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -153,6 +155,21 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
}
+TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
+ const char* op_name = "Identity";
+ ShapeInferenceTestOp op(op_name);
+ // Check that handle dtypes are preserved.
+ const OpRegistrationData* op_reg_data;
+ TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
+ shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
+ {TensorShapeProto()}, {}, {}, {},
+ {DT_BOOL});
+ TF_ASSERT_OK(c.construction_status());
+ ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
+ TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
+ EXPECT_TRUE(c.output_handle_dtype(0) == DT_BOOL);
+}
+
TEST(ArrayOpsTest, Diag_ShapeFn) {
ShapeInferenceTestOp op("Diag");
INFER_OK(op, "?", "?");
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 8d3d9310a4..ff00214da3 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -911,6 +911,19 @@ REGISTER_OP("Select")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
+ // Merge handle shape and dtype if applicable.
+ if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
+ // TODO(apassos) resolve this in the manner of b/32476923
+ return errors::InvalidArgument(
+ "Trying to merge handles pointing to different dtypes.");
+ }
+ c->set_output_handle_dtype(0, c->input_handle_dtype(1));
+ ShapeHandle output_handle_shape;
+ TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
+ c->input_handle_shape(2),
+ &output_handle_shape));
+ c->set_output_handle_shape(0, output_handle_shape);
+
// The inputs 'then' and 'else' must have the same shape.
ShapeHandle data = c->input(1);
ShapeHandle other = c->input(2);
@@ -961,8 +974,9 @@ REGISTER_OP("Select")
}
c->set_output(0, data);
+
return Status::OK();
- })
+ })
.Doc(R"doc(
Selects elements from `t` or `e`, depending on `condition`.
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 79ae187342..6a1cc8e7eb 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -216,6 +216,37 @@ TEST(MathOpsTest, Select_ShapeFn) {
INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
"[2,?,5];[?,?,3];[?,2,?]");
+
+ // Test that handle shapes were merged.
+ const OpRegistrationData* op_reg_data;
+ TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
+ TensorShapeProto i0;
+ i0.add_dim()->set_size(1);
+ i0.add_dim()->set_size(-1);
+ TensorShapeProto i1;
+ i1.add_dim()->set_size(-1);
+ i1.add_dim()->set_size(2);
+
+ ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
+ shape_inference::InferenceContext c(
+ &op.node_def, op_reg_data->op_def,
+ {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
+ {TensorShapeProto(), i0, i1}, {});
+ TF_ASSERT_OK(c.construction_status());
+ TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
+ EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
+ EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
+
+ // Expect an error when the shapes can't be merged.
+ TensorShapeProto i2;
+ i1.add_dim()->set_size(2);
+ i1.add_dim()->set_size(2);
+ shape_inference::InferenceContext c2(
+ &op.node_def, op_reg_data->op_def,
+ {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
+ {TensorShapeProto(), i0, i2}, {});
+ TF_ASSERT_OK(c.construction_status());
+ EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
}
TEST(MathOpsTest, Range_ShapeFn) {
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
new file mode 100644
index 0000000000..6211b07ac5
--- /dev/null
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -0,0 +1,80 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("VarHandleOp")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .Output("resource: resource")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ DataType t;
+ c->GetAttr("dtype", &t);
+ c->set_output_handle_dtype(0, t);
+ TensorShapeProto p;
+ c->GetAttr("shape", &p);
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
+ c->set_output_handle_shape(0, s);
+ return Status::OK();
+ })
+ .Doc(R"(
+Creates a handle to a Variable resource.
+
+container: the container this variable is placed in.
+shared_name: the name by which this variable is referred to.
+dtype: the type of this variable. Must agree with the dtypes
+ of all ops using this variable.
+shape: The (possibly partially specified) shape of this variable.
+)");
+
+REGISTER_OP("CreateVariableOp")
+ .Input("resource: resource")
+ .Input("value: dtype")
+ .Attr("dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType handle_dtype = c->input_handle_dtype(0);
+ DataType value_dtype;
+ c->GetAttr("dtype", &value_dtype);
+ if (handle_dtype != value_dtype) {
+ return errors::InvalidArgument(
+ "Trying to initialize handle for variable with wrong dtype. "
+ "Expected ",
+ handle_dtype, " got ", value_dtype);
+ }
+ shape_inference::ShapeHandle s = c->input_handle_shape(0);
+ shape_inference::ShapeHandle value_shape = c->input(1);
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
+ return Status::OK();
+ })
+ .Doc(R"(
+Creates a variable resource.
+
+resource: handle to the resource in which to store the variable.
+value: the value to set the new tensor to use.
+dtype: the dtype of the value.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index dcb6f52605..da875c081a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -218,23 +218,6 @@ cc_library(
)
cc_library(
- name = "cpp_shape_inference",
- srcs = ["framework/cpp_shape_inference.cc"],
- hdrs = ["framework/cpp_shape_inference.h"],
- copts = ["-Wno-sign-compare"],
- visibility = ["//visibility:public"],
- deps = [
- ":numpy_lib",
- ":py_func_lib",
- "//tensorflow/c:tf_status_helper",
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_cc",
- "//third_party/py/numpy:headers",
- "//util/python:python_headers",
- ],
-)
-
-cc_library(
name = "python_op_gen_main",
srcs = [
"framework/python_op_gen_main.cc",
@@ -284,6 +267,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":cpp_shape_inference_proto_py",
":framework_for_generated_wrappers",
":pywrap_tensorflow",
],
@@ -661,6 +645,11 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "resource_variable_ops_gen",
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_private_py(
name = "script_ops_gen",
require_shape_functions = True,
)
@@ -990,6 +979,16 @@ py_library(
)
py_library(
+ name = "resource_variable_ops",
+ srcs = ["ops/resource_variable_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework",
+ ":resource_variable_ops_gen",
+ ],
+)
+
+py_library(
name = "nn",
srcs = ["ops/nn.py"],
srcs_version = "PY2AND3",
@@ -1409,6 +1408,7 @@ py_library(
":partitioned_variables",
":random_ops",
":random_ops_gen",
+ ":resource_variable_ops",
":resources",
":rnn",
":rnn_cell",
@@ -1690,6 +1690,7 @@ tf_proto_library(
["**/*.proto"],
exclude = [
"util/protobuf/compare_test.proto",
+ "framework/cpp_shape_inference.proto",
],
),
go_api_version = 2,
@@ -1701,6 +1702,13 @@ tf_proto_library_py(
srcs = ["util/protobuf/compare_test.proto"],
)
+tf_proto_library(
+ name = "cpp_shape_inference_proto",
+ srcs = ["framework/cpp_shape_inference.proto"],
+ cc_api_version = 2,
+ cc_libs = ["//tensorflow/core:protos_all_cc"],
+)
+
py_test(
name = "protobuf_compare_test",
size = "small",
@@ -1767,6 +1775,24 @@ py_library(
],
)
+cc_library(
+ name = "cpp_shape_inference",
+ srcs = ["framework/cpp_shape_inference.cc"],
+ hdrs = ["framework/cpp_shape_inference.h"],
+ copts = ["-Wno-sign-compare"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cpp_shape_inference_proto_cc",
+ ":numpy_lib",
+ ":py_func_lib",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_cc",
+ "//third_party/py/numpy:headers",
+ "//util/python:python_headers",
+ ],
+)
+
cuda_py_tests(
name = "device_lib_test",
size = "small",
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index a0c97f5f8f..09afe56b19 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
import six.moves
-from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -567,8 +567,12 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
the C++ shape function.
Returns:
- A TensorShape list of the output shapes of the op, as computed using the
- C++ shape inference function registered for the op.
+ A dictionary with the following keys:
+ shapes: A TensorShape list of the output shapes of the op, as computed
+ using the C++ shape inference function registered for the op.
+ handle_shapes: A TensorShape list of the shapes for handle outputs, if
+ any.
+ handle_dtypes: A list of DataType enums for the handle outputs, if any.
Raises:
ValueError: If the C++ shape function returned an error (e.g. because the
@@ -576,8 +580,16 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
according to the shape function).
"""
node_def_str = op.node_def.SerializeToString()
- input_shapes = [i.get_shape().as_proto().SerializeToString() for i in
- op.inputs]
+
+ def tensor_to_inference_result(t):
+ r = cpp_shape_inference_pb2.CppShapeInferenceResult()
+ r.shape.CopyFrom(t.get_shape().as_proto())
+ # pylint: disable=protected-access
+ r.handle_shape.CopyFrom(t._handle_shape)
+ r.handle_dtype = t._handle_dtype
+ # pylint: enable=protected-access
+ return r.SerializeToString()
+ input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
input_tensors = [None for i in input_shapes]
if input_tensors_needed:
@@ -596,10 +608,13 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
raise ValueError(err.message)
# Convert TensorShapeProto values in output_shapes.
- result = [
- tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s))
+ result_protos = [
+ cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
for s in output_shapes
]
+ result = [r.shape for r in result_protos]
+ result_handle_shapes = [r.handle_shape for r in result_protos]
+ result_handle_dtypes = [r.handle_dtype for r in result_protos]
if debug_python_shape_fn:
try:
@@ -616,4 +631,6 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
str(result), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
- return result
+ return {"shapes": result,
+ "handle_shapes": result_handle_shapes,
+ "handle_dtypes": result_handle_dtypes}
diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc
index bb5a57e617..57b85e8118 100644
--- a/tensorflow/python/framework/cpp_shape_inference.cc
+++ b/tensorflow/python/framework/cpp_shape_inference.cc
@@ -20,12 +20,32 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
#include "tensorflow/python/lib/core/py_func.h"
namespace tensorflow {
namespace swig {
namespace {
+void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
+ tensorflow::shape_inference::InferenceContext* c,
+ TensorShapeProto* out) {
+ if (c->RankKnown(s)) {
+ const int32 rank = c->Rank(s);
+ for (int i = 0; i < rank; ++i) {
+ shape_inference::DimensionHandle d = c->Dim(s, i);
+ auto* out_dim = out->add_dim();
+ if (c->ValueKnown(d)) {
+ out_dim->set_size(c->Value(d));
+ } else {
+ out_dim->set_size(-1);
+ }
+ }
+ } else {
+ out->set_unknown_rank(true);
+ }
+}
+
Status RunCppShapeInferenceImpl(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
@@ -49,12 +69,21 @@ Status RunCppShapeInferenceImpl(
// Convert input shapes.
std::vector<TensorShapeProto> input_shapes;
+ std::vector<TensorShapeProto> input_handle_shapes;
+ std::vector<DataType> input_handle_dtypes;
input_shapes.resize(input_serialized_shapes.size());
+ input_handle_shapes.resize(input_serialized_shapes.size());
+ input_handle_dtypes.resize(input_serialized_shapes.size());
+ CppShapeInferenceResult tmp;
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
- if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) {
+ tmp.Clear();
+ if (!tmp.ParseFromString(input_serialized_shapes[i])) {
return errors::InvalidArgument(
"Error parsing shape proto during cpp shape inference");
}
+ input_shapes[i].Swap(tmp.mutable_shape());
+ input_handle_dtypes[i] = tmp.handle_dtype();
+ input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
}
// Convert input tensor values;
@@ -73,34 +102,23 @@ Status RunCppShapeInferenceImpl(
}
// Run shape inference.
- // TODO(cwhipkey): pass a value for input_tensors_as_shapes.
tensorflow::shape_inference::InferenceContext c(
&node, op_reg_data->op_def, input_shapes, input_tensors,
- {} /* input_tensors_as_shapes */);
+ {} /* input_tensors_as_shapes */, input_handle_shapes,
+ input_handle_dtypes);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
// Convert output shapes.
output_tensor_shape_protos->resize(c.num_outputs());
- TensorShapeProto out;
+ CppShapeInferenceResult out;
for (int i = 0; i < c.num_outputs(); ++i) {
- shape_inference::ShapeHandle s = c.output(i);
out.Clear();
- if (c.RankKnown(s)) {
- const int32 rank = c.Rank(s);
- for (int i = 0; i < rank; ++i) {
- shape_inference::DimensionHandle d = c.Dim(s, i);
- auto* out_dim = out.add_dim();
- if (c.ValueKnown(d)) {
- out_dim->set_size(c.Value(d));
- } else {
- out_dim->set_size(-1);
- }
- }
- } else {
- out.set_unknown_rank(true);
- }
+ ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
+ ProtoFromShapeHandle(c.output_handle_shape(i), &c,
+ out.mutable_handle_shape());
+ out.set_handle_dtype(c.output_handle_dtype(i));
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
}
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
index a2f52227d6..f91af8e1a8 100644
--- a/tensorflow/python/framework/cpp_shape_inference.h
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -36,7 +36,7 @@ namespace swig {
// inference was successful.
//
// On success, <*output_shapes> is populated with the inferred output shapes (as
-// serialized TensorShapeProtos).
+// serialized CppShapeInferenceResult protos).
// <*output_shapes> must be empty when this function is called.
//
// This is temporary code to be used during the migration
diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto
new file mode 100644
index 0000000000..c7d59cc55e
--- /dev/null
+++ b/tensorflow/python/framework/cpp_shape_inference.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+
+message CppShapeInferenceResult {
+ TensorShapeProto shape = 1;
+ TensorShapeProto handle_shape = 2;
+ DataType handle_dtype = 3;
+} \ No newline at end of file
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 6d80da1300..be9de08ba3 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -32,6 +32,8 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@@ -297,6 +299,10 @@ class Tensor(object):
# to easily navigate a computation graph.
self._consumers = []
+ # Attributes used for C++ shape inference. Not inspected, only forwarded.
+ self._handle_shape = tensor_shape_pb2.TensorShapeProto()
+ self._handle_dtype = types_pb2.DT_INVALID
+
@property
def op(self):
"""The `Operation` that produces this tensor as an output."""
@@ -1791,10 +1797,22 @@ def set_shapes_for_outputs(op):
if shapes is None:
raise RuntimeError(
"Shape function for op %s did not return any shapes" % op)
+ elif isinstance(shapes, dict):
+ # Returned by call_cpp_shape_fn
+ shapes_dict = shapes
+ shapes = shapes_dict["shapes"]
+ handle_shapes = shapes_dict["handle_shapes"]
+ handle_dtypes = shapes_dict["handle_dtypes"]
+ for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
+ # pylint: disable=protected-access
+ output._handle_shape = handle_shape
+ output._handle_dtype = handle_dtype
+ # pylint: enable=protected-access
+
if len(op.outputs) != len(shapes):
raise RuntimeError(
- "Shape function for op %s returned %d shapes but expected %d" %
- (op, len(shapes), len(op.outputs)))
+ "Shape function for op %s returned %d shapes but expected %d %s %s" %
+ (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
for output, s in zip(op.outputs, shapes):
output.set_shape(s)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index baa48ec8a4..fe74b3426c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -256,6 +256,16 @@ tf_py_test(
)
tf_py_test(
+ name = "resource_variable_ops_test",
+ size = "small",
+ srcs = ["resource_variable_ops_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:resource_variable_ops",
+ ],
+)
+
+tf_py_test(
name = "save_restore_ops_test",
size = "small",
srcs = ["save_restore_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
new file mode 100644
index 0000000000..cb4375ce91
--- /dev/null
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.resource_variable_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
+
+ def testHandleDtypeShapeMatch(self):
+ with self.test_session():
+ handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
+ with self.assertRaises(ValueError):
+ resource_variable_ops.create_variable_op(
+ handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
+ with self.assertRaises(ValueError):
+ resource_variable_ops.create_variable_op(
+ handle, constant_op.constant([0], dtype=dtypes.int32)).run()
+ resource_variable_ops.create_variable_op(
+ handle, constant_op.constant(0, dtype=dtypes.int32)).run()
+
+ def testDtypeSurvivesIdentity(self):
+ with self.test_session():
+ handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
+ id_handle = array_ops.identity(handle)
+ resource_variable_ops.create_variable_op(
+ id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
new file mode 100644
index 0000000000..7db9731e19
--- /dev/null
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops to use variables as resources."""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import ops
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_resource_variable_ops import *
+# pylint: enable=wildcard-import
+
+ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn)