aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 12:24:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 13:38:39 -0700
commit2940b6c9ac3518b27633d90880b9092157496ee8 (patch)
treecba06a8dabb4c1cf79faba8267b8eea4e16b2a6d /tensorflow
parent32906b8c26608185eb7062b9bb32108b1b416d8a (diff)
Automated rollback of change 137731142
Change: 137740850
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/cmake/tf_python.cmake40
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc17
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc58
-rw-r--r--tensorflow/core/framework/node_def_util.h8
-rw-r--r--tensorflow/core/framework/resource_mgr.h1
-rw-r--r--tensorflow/core/framework/shape_inference.cc59
-rw-r--r--tensorflow/core/framework/shape_inference.h48
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc111
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.cc3
-rw-r--r--tensorflow/core/framework/tensor.h1
-rw-r--r--tensorflow/core/kernels/BUILD12
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc49
-rw-r--r--tensorflow/core/kernels/variable_ops.h39
-rw-r--r--tensorflow/core/ops/array_ops.cc7
-rw-r--r--tensorflow/core/ops/array_ops_test.cc17
-rw-r--r--tensorflow/core/ops/math_ops.cc16
-rw-r--r--tensorflow/core/ops/math_ops_test.cc31
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc80
-rw-r--r--tensorflow/python/BUILD60
-rw-r--r--tensorflow/python/framework/common_shapes.py33
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.cc56
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h2
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.proto13
-rw-r--r--tensorflow/python/framework/ops.py22
-rw-r--r--tensorflow/python/kernel_tests/BUILD10
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py51
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py30
28 files changed, 162 insertions, 714 deletions
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 907158b646..8cdecf706a 100644
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -94,38 +94,6 @@ function(RELATIVE_PROTOBUF_GENERATE_PYTHON ROOT_DIR SRCS)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
-function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
- if(NOT ARGN)
- message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files")
- return()
- endif()
-
- set(${SRCS})
- set(${HDRS})
- foreach(FIL ${ARGN})
- set(ABS_FIL ${ROOT_DIR}/${FIL})
- get_filename_component(FIL_WE ${FIL} NAME_WE)
- get_filename_component(FIL_DIR ${ABS_FIL} PATH)
- file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
-
- list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
- list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
-
- add_custom_command(
- OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
- COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
- ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
- DEPENDS ${ABS_FIL} protobuf
- COMMENT "Running C++ protocol buffer compiler on ${FIL}"
- VERBATIM )
- endforeach()
-
- set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
- set(${SRCS} ${${SRCS}} PARENT_SCOPE)
- set(${HDRS} ${${HDRS}} PARENT_SCOPE)
-endfunction()
-
file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
"${tensorflow_source_dir}/tensorflow/python/*.proto"
@@ -134,12 +102,6 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs}
)
-RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
- ${tensorflow_source_dir} ${tf_protos_python_srcs}
-)
-
-add_library(tf_python_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
-
# tf_python_touchup_modules adds empty __init__.py files to all
# directories containing Python code, so that Python will recognize
# them as modules.
@@ -239,7 +201,6 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
)
target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE
tf_protos_cc
- tf_python_protos_cc
${tensorflow_EXTERNAL_LIBRARIES}
)
@@ -351,7 +312,6 @@ target_link_libraries(pywrap_tensorflow
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
- tf_python_protos_cc
${PYTHON_LIBRARIES}
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 79546ccd20..1c37921afc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -379,7 +379,6 @@ tf_gen_op_libs(
"no_op",
"parsing_ops",
"random_ops",
- "resource_variable_ops",
"sdca_ops",
"script_ops",
"sendrecv_ops",
@@ -543,7 +542,6 @@ cc_library(
"//tensorflow/core/kernels:parsing",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:required",
- "//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index e1500ed1ad..1ddd483076 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -42,8 +42,6 @@ Status ShapeRefiner::AddNode(const Node* node) {
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
- std::vector<DataType> input_handle_dtypes(node->num_inputs());
- std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
@@ -59,15 +57,6 @@ Status ShapeRefiner::AddNode(const Node* node) {
DCHECK_GE(e->dst_input(), 0);
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
-
- // Only propagate handle xshape and dtype of edges which are carrying
- // resource handles.
- if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
- input_handle_dtypes[e->dst_input()] =
- c->output_handle_dtype(e->src_output());
- input_handle_shapes[e->dst_input()] =
- c->output_handle_shape(e->src_output());
- }
}
// Get the shape function for this node
@@ -87,9 +76,9 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
- std::unique_ptr<InferenceContext> c(new InferenceContext(
- &node->def(), node->op_def(), input_shapes, input_tensors,
- input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
+ std::unique_ptr<InferenceContext> c(
+ new InferenceContext(&node->def(), node->op_def(), input_shapes,
+ input_tensors, input_tensors_as_shapes));
if (!c->construction_status().ok()) {
return c->construction_status();
}
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index ca1326844c..7196bc8304 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
- InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
- InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
- InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -156,8 +156,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {},
- {});
+ InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -175,7 +174,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -192,7 +191,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -216,7 +215,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -225,7 +224,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
- InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@@ -233,8 +232,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {},
- {});
+ InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@@ -247,7 +245,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@@ -260,8 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
+ {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@@ -274,8 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {},
- {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@@ -283,7 +280,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@@ -295,7 +292,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -314,7 +311,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -322,7 +319,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -334,7 +331,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -346,8 +343,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {},
- {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -359,7 +355,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -367,7 +363,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@@ -378,7 +374,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {});
+ InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 85b83c4d74..2d16760625 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -66,14 +66,6 @@ void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
AttrValueMap::value_type(name.ToString(), attr_value));
}
-// Adds an attr to an attr value map.
-template <class T>
-void AddAttr(StringPiece name, T&& value, AttrValueMap* map) {
- AttrValue attr_value;
- SetAttrValue(value, &attr_value);
- map->insert(AttrValueMap::value_type(name.ToString(), attr_value));
-}
-
class AttrSlice {
public:
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 30bb2cb708..f823462079 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <typeinfo>
#include <unordered_map>
-#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4aa32f6a84..da88b6a7ca 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -31,9 +31,7 @@ InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<TensorShapeProto>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
@@ -45,30 +43,19 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
- std::vector<ShapeHandle> handle_shapes;
- for (const auto& p : input_handle_shapes) {
- ShapeHandle shape;
- construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
- if (!construction_status_.ok()) {
- return;
- }
- handle_shapes.push_back(shape);
- }
- PostInputInit(handle_shapes, input_handle_dtypes);
+ PostInputInit();
}
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
- PostInputInit(input_handle_shapes, input_handle_dtypes);
+ PostInputInit();
}
InferenceContext::~InferenceContext() {
@@ -137,44 +124,15 @@ void InferenceContext::PreInputInit(
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(nullptr);
}
- output_handle_shape_.reserve(num_outputs);
- for (int i = 0; i < num_outputs; ++i) {
- output_handle_shape_.push_back(UnknownShape());
- }
- output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
}
-void InferenceContext::PostInputInit(
- const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes) {
+void InferenceContext::PostInputInit() {
int num_inputs_from_node_def = 0;
for (const auto& e : input_name_map_) {
num_inputs_from_node_def =
std::max(num_inputs_from_node_def, e.second.second);
}
- // Allow passing empty shapes/dtypes to avoid changing every single test.
- if (input_handle_shapes.empty()) {
- input_handle_shape_.resize(inputs_.size());
- } else {
- input_handle_shape_ = input_handle_shapes;
- if (input_handle_shape_.size() != inputs_.size()) {
- construction_status_ = errors::InvalidArgument(
- "Wrong number of handle shapes passed; expected ", inputs_.size(),
- " got ", input_handle_shape_.size());
- }
- }
- if (input_handle_dtypes.empty()) {
- input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
- } else {
- input_handle_dtype_ = input_handle_dtypes;
- if (input_handle_dtype_.size() != inputs_.size()) {
- construction_status_ = errors::InvalidArgument(
- "Wrong number of handle dtypes passed; expected ", inputs_.size(),
- " got ", input_handle_dtype_.size());
- }
- }
-
if (inputs_.size() != num_inputs_from_node_def) {
construction_status_ = errors::InvalidArgument(
"Wrong number of inputs passed: ", inputs_.size(), " while ",
@@ -779,13 +737,6 @@ Status InferenceContext::AttachContext(const Status& status) {
strings::StrCat(status.error_message(), error_context));
}
-ShapeHandle InferenceContext::input_handle_shape(int idx) {
- if (!input_handle_shape_[idx].IsSet()) {
- input_handle_shape_[idx] = UnknownShape();
- }
- return input_handle_shape_[idx];
-}
-
// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index e02490efd9..f5befc15a1 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -147,9 +147,7 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes);
+ const std::vector<ShapeHandle>& input_tensors_as_shapes);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
@@ -164,9 +162,7 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<TensorShapeProto>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes);
+ const std::vector<ShapeHandle>& input_tensors_as_shapes);
~InferenceContext();
@@ -235,12 +231,12 @@ class InferenceContext {
}
return s->dims_[idx];
}
- int32 Rank(ShapeHandle s) const { return s->rank_; }
- bool RankKnown(ShapeHandle s) const { return Rank(s) != kUnknownRank; }
- inline int64 Value(DimensionOrConstant d) const {
+ int32 Rank(ShapeHandle s) { return s->rank_; }
+ bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; }
+ inline int64 Value(DimensionOrConstant d) {
return d.dim.IsSet() ? d.dim->value_ : d.val;
}
- inline bool ValueKnown(DimensionOrConstant d) const {
+ inline bool ValueKnown(DimensionOrConstant d) {
return Value(d) != kUnknownDim;
}
@@ -395,30 +391,6 @@ class InferenceContext {
Status construction_status() const { return construction_status_; }
- // Methods to propagate shape and dtype on edges of handles. Handles are the
- // dtype DT_RESOURCE which can be used to access state stored in a
- // ResourceManager. When ops (such as variables) consume these handles to
- // produce tensors they might need to know side-information about the shapes
- // and dtypes of tensors which can be accessed via the handle. These methods
- // propagate that information. Output handle dtypes and shapes are ignored if
- // the output tensor is not of type DT_RESOURCE.
- ShapeHandle input_handle_shape(int idx);
- DataType input_handle_dtype(int idx) const {
- return input_handle_dtype_[idx];
- }
- void set_output_handle_shape(int idx, ShapeHandle shape) {
- output_handle_shape_[idx] = shape;
- }
- void set_output_handle_dtype(int idx, DataType dtype) {
- output_handle_dtype_[idx] = dtype;
- }
- ShapeHandle output_handle_shape(int idx) const {
- return output_handle_shape_[idx];
- }
- DataType output_handle_dtype(int idx) const {
- return output_handle_dtype_[idx];
- }
-
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(ShapeHandle indices_shape,
@@ -509,8 +481,7 @@ class InferenceContext {
void PreInputInit(const OpDef& op_def,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
- void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes);
+ void PostInputInit();
DimensionHandle GetDimension(const DimensionOrConstant& d);
@@ -539,11 +510,6 @@ class InferenceContext {
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
- std::vector<ShapeHandle> input_handle_shape_;
- std::vector<DataType> input_handle_dtype_;
- std::vector<ShapeHandle> output_handle_shape_;
- std::vector<DataType> output_handle_dtype_;
-
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 8d6b4ac021..06096bfdcc 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -71,8 +71,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
- InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@@ -108,7 +107,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@@ -123,7 +122,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
- InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
{
auto fn = [](InferenceContext* c) {
@@ -155,7 +154,7 @@ TEST_F(ShapeInferenceTest, Run) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
- {}, {}, {}, {});
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -196,8 +195,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
- {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
- {});
+ {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
@@ -210,8 +208,7 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -249,8 +246,7 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -288,8 +284,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -327,7 +322,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -368,8 +363,7 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -418,7 +412,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
InferenceContext c(&def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
- {}, {}, {}, {});
+ {}, {});
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
@@ -489,7 +483,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
- {}, {}, {}, {});
+ {}, {});
auto s_unknown = c.input(0);
auto s_u_2 = c.input(1);
@@ -542,7 +536,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
- {}, {}, {}, {});
+ {}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@@ -617,7 +611,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
- {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
+ {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -643,8 +637,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@@ -675,8 +668,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@@ -701,7 +693,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@@ -713,7 +705,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@@ -724,7 +716,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@@ -740,7 +732,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@@ -762,7 +754,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -814,8 +806,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
// Test when the input shape is wrong.
{
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@@ -825,7 +816,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@@ -861,7 +852,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@@ -875,7 +866,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@@ -887,7 +878,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@@ -901,7 +892,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
- {&t1, &t2}, {}, {}, {});
+ {&t1, &t2}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
@@ -912,8 +903,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@@ -944,7 +934,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {});
+ InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@@ -952,8 +942,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {},
- {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1015,7 +1004,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1066,7 +1055,7 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1115,7 +1104,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1168,7 +1157,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@@ -1181,7 +1170,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1229,7 +1218,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1267,7 +1256,7 @@ TEST_F(ShapeInferenceTest, Max) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
- {}, {}, {}, {});
+ {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1280,7 +1269,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
- {}, {}, {});
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1292,8 +1281,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1306,8 +1295,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1320,8 +1309,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1335,7 +1324,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
- {}, {}, {});
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1348,7 +1337,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
- {}, {}, {});
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1361,7 +1350,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
- {}, {}, {});
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1374,7 +1363,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
- {}, {}, {});
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -1386,8 +1375,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
+ {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index a225824f82..ed1d3ec520 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -44,7 +44,8 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
- in_shapes, op.input_tensors, {}, {}, {});
+ in_shapes, op.input_tensors,
+ {} /* input_tensors_as_shapes */);
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 47d74d4def..92915685f5 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -435,7 +435,6 @@ class Tensor {
friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape
friend class TensorTestHelper; // For access to set_shape
- friend class CreateVariableOp;
// Creates a tensor with the input datatype, shape and buf.
//
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1954ebdc10..34954f0066 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1031,18 +1031,6 @@ tf_kernel_library(
)
tf_kernel_library(
- name = "resource_variable_ops",
- srcs = ["resource_variable_ops.cc"],
- deps = [
- ":variable_ops",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
- ],
-)
-
-tf_kernel_library(
name = "fact_op",
prefix = "fact_op",
deps = [
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
deleted file mode 100644
index fbe66e8386..0000000000
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-REGISTER_RESOURCE_HANDLE_KERNEL(Var);
-
-class CreateVariableOp : public OpKernel {
- public:
- CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
- }
-
- void Compute(OpKernelContext* c) override {
- Var* var = new Var(dtype_);
- var->Ref();
- core::ScopedUnref ur(var);
- OP_REQUIRES_OK(c, CreateResource<Var>(c, HandleFromInput(c, 0), var));
- // TODO(apassos): this currently does not initialize the tensor, so it's
- // pointless, other than checking construction in tests. Fix this.
- }
-
- private:
- DataType dtype_;
-};
-REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU),
- CreateVariableOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index f44f94c51b..f754629a72 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -26,26 +26,6 @@ limitations under the License.
namespace tensorflow {
-// Resource stored by variables in the resource manager.
-class Var : public ResourceBase {
- public:
- explicit Var(DataType dtype) : tensor_(dtype) {}
- mutex* mu() { return &mu_; }
- Tensor* tensor() { return &tensor_; }
-
- string DebugString() override {
- return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
- tensor_.shape().DebugString());
- }
-
- private:
- mutex mu_;
- Tensor tensor_;
-
- ~Var() override {}
- TF_DISALLOW_COPY_AND_ASSIGN(Var);
-};
-
class VariableOp : public OpKernel {
public:
explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -79,6 +59,25 @@ class VariableOp : public OpKernel {
}
private:
+ class Var : public ResourceBase {
+ public:
+ explicit Var(DataType dtype) : tensor_(dtype) {}
+ mutex* mu() { return &mu_; }
+ Tensor* tensor() { return &tensor_; }
+
+ string DebugString() override {
+ return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
+ tensor_.shape().DebugString());
+ }
+
+ private:
+ mutex mu_;
+ Tensor tensor_;
+
+ ~Var() override {}
+ TF_DISALLOW_COPY_AND_ASSIGN(Var);
+ };
+
DataType dtype_;
TensorShape shape_;
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 58b95fcff1..6e076a092e 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1252,12 +1252,7 @@ REGISTER_OP("Identity")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->input(0));
- c->set_output_handle_dtype(0, c->input_handle_dtype(0));
- c->set_output_handle_shape(0, c->input_handle_shape(0));
- return Status::OK();
- })
+ .SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"Doc(
Return a tensor with the same shape and contents as the input tensor or value.
)Doc");
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 90f3c60b64..8679739b70 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -15,9 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -155,21 +153,6 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
}
-TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
- const char* op_name = "Identity";
- ShapeInferenceTestOp op(op_name);
- // Check that handle dtypes are preserved.
- const OpRegistrationData* op_reg_data;
- TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
- shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
- {TensorShapeProto()}, {}, {}, {},
- {DT_BOOL});
- TF_ASSERT_OK(c.construction_status());
- ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
- TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
- EXPECT_TRUE(c.output_handle_dtype(0) == DT_BOOL);
-}
-
TEST(ArrayOpsTest, Diag_ShapeFn) {
ShapeInferenceTestOp op("Diag");
INFER_OK(op, "?", "?");
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index ff00214da3..8d3d9310a4 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -911,19 +911,6 @@ REGISTER_OP("Select")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- // Merge handle shape and dtype if applicable.
- if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
- // TODO(apassos) resolve this in the manner of b/32476923
- return errors::InvalidArgument(
- "Trying to merge handles pointing to different dtypes.");
- }
- c->set_output_handle_dtype(0, c->input_handle_dtype(1));
- ShapeHandle output_handle_shape;
- TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
- c->input_handle_shape(2),
- &output_handle_shape));
- c->set_output_handle_shape(0, output_handle_shape);
-
// The inputs 'then' and 'else' must have the same shape.
ShapeHandle data = c->input(1);
ShapeHandle other = c->input(2);
@@ -974,9 +961,8 @@ REGISTER_OP("Select")
}
c->set_output(0, data);
-
return Status::OK();
- })
+ })
.Doc(R"doc(
Selects elements from `t` or `e`, depending on `condition`.
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 6a1cc8e7eb..79ae187342 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -216,37 +216,6 @@ TEST(MathOpsTest, Select_ShapeFn) {
INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
"[2,?,5];[?,?,3];[?,2,?]");
-
- // Test that handle shapes were merged.
- const OpRegistrationData* op_reg_data;
- TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
- TensorShapeProto i0;
- i0.add_dim()->set_size(1);
- i0.add_dim()->set_size(-1);
- TensorShapeProto i1;
- i1.add_dim()->set_size(-1);
- i1.add_dim()->set_size(2);
-
- ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
- shape_inference::InferenceContext c(
- &op.node_def, op_reg_data->op_def,
- {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
- {TensorShapeProto(), i0, i1}, {});
- TF_ASSERT_OK(c.construction_status());
- TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
- EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
- EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
-
- // Expect an error when the shapes can't be merged.
- TensorShapeProto i2;
- i1.add_dim()->set_size(2);
- i1.add_dim()->set_size(2);
- shape_inference::InferenceContext c2(
- &op.node_def, op_reg_data->op_def,
- {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
- {TensorShapeProto(), i0, i2}, {});
- TF_ASSERT_OK(c.construction_status());
- EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
}
TEST(MathOpsTest, Range_ShapeFn) {
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
deleted file mode 100644
index 6211b07ac5..0000000000
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// ============================================================================
-
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-REGISTER_OP("VarHandleOp")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("dtype: type")
- .Attr("shape: shape")
- .Output("resource: resource")
- .SetIsStateful()
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->Scalar());
- DataType t;
- c->GetAttr("dtype", &t);
- c->set_output_handle_dtype(0, t);
- TensorShapeProto p;
- c->GetAttr("shape", &p);
- shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
- c->set_output_handle_shape(0, s);
- return Status::OK();
- })
- .Doc(R"(
-Creates a handle to a Variable resource.
-
-container: the container this variable is placed in.
-shared_name: the name by which this variable is referred to.
-dtype: the type of this variable. Must agree with the dtypes
- of all ops using this variable.
-shape: The (possibly partially specified) shape of this variable.
-)");
-
-REGISTER_OP("CreateVariableOp")
- .Input("resource: resource")
- .Input("value: dtype")
- .Attr("dtype: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- DataType handle_dtype = c->input_handle_dtype(0);
- DataType value_dtype;
- c->GetAttr("dtype", &value_dtype);
- if (handle_dtype != value_dtype) {
- return errors::InvalidArgument(
- "Trying to initialize handle for variable with wrong dtype. "
- "Expected ",
- handle_dtype, " got ", value_dtype);
- }
- shape_inference::ShapeHandle s = c->input_handle_shape(0);
- shape_inference::ShapeHandle value_shape = c->input(1);
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
- return Status::OK();
- })
- .Doc(R"(
-Creates a variable resource.
-
-resource: handle to the resource in which to store the variable.
-value: the value to set the new tensor to use.
-dtype: the dtype of the value.
-)");
-
-} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index da875c081a..dcb6f52605 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -218,6 +218,23 @@ cc_library(
)
cc_library(
+ name = "cpp_shape_inference",
+ srcs = ["framework/cpp_shape_inference.cc"],
+ hdrs = ["framework/cpp_shape_inference.h"],
+ copts = ["-Wno-sign-compare"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":numpy_lib",
+ ":py_func_lib",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_cc",
+ "//third_party/py/numpy:headers",
+ "//util/python:python_headers",
+ ],
+)
+
+cc_library(
name = "python_op_gen_main",
srcs = [
"framework/python_op_gen_main.cc",
@@ -267,7 +284,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":cpp_shape_inference_proto_py",
":framework_for_generated_wrappers",
":pywrap_tensorflow",
],
@@ -645,11 +661,6 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
- name = "resource_variable_ops_gen",
- require_shape_functions = True,
-)
-
-tf_gen_op_wrapper_private_py(
name = "script_ops_gen",
require_shape_functions = True,
)
@@ -979,16 +990,6 @@ py_library(
)
py_library(
- name = "resource_variable_ops",
- srcs = ["ops/resource_variable_ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":framework",
- ":resource_variable_ops_gen",
- ],
-)
-
-py_library(
name = "nn",
srcs = ["ops/nn.py"],
srcs_version = "PY2AND3",
@@ -1408,7 +1409,6 @@ py_library(
":partitioned_variables",
":random_ops",
":random_ops_gen",
- ":resource_variable_ops",
":resources",
":rnn",
":rnn_cell",
@@ -1690,7 +1690,6 @@ tf_proto_library(
["**/*.proto"],
exclude = [
"util/protobuf/compare_test.proto",
- "framework/cpp_shape_inference.proto",
],
),
go_api_version = 2,
@@ -1702,13 +1701,6 @@ tf_proto_library_py(
srcs = ["util/protobuf/compare_test.proto"],
)
-tf_proto_library(
- name = "cpp_shape_inference_proto",
- srcs = ["framework/cpp_shape_inference.proto"],
- cc_api_version = 2,
- cc_libs = ["//tensorflow/core:protos_all_cc"],
-)
-
py_test(
name = "protobuf_compare_test",
size = "small",
@@ -1775,24 +1767,6 @@ py_library(
],
)
-cc_library(
- name = "cpp_shape_inference",
- srcs = ["framework/cpp_shape_inference.cc"],
- hdrs = ["framework/cpp_shape_inference.h"],
- copts = ["-Wno-sign-compare"],
- visibility = ["//visibility:public"],
- deps = [
- ":cpp_shape_inference_proto_cc",
- ":numpy_lib",
- ":py_func_lib",
- "//tensorflow/c:tf_status_helper",
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_cc",
- "//third_party/py/numpy:headers",
- "//util/python:python_headers",
- ],
-)
-
cuda_py_tests(
name = "device_lib_test",
size = "small",
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 09afe56b19..a0c97f5f8f 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
import six.moves
+from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -567,12 +567,8 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
the C++ shape function.
Returns:
- A dictionary with the following keys:
- shapes: A TensorShape list of the output shapes of the op, as computed
- using the C++ shape inference function registered for the op.
- handle_shapes: A TensorShape list of the shapes for handle outputs, if
- any.
- handle_dtypes: A list of DataType enums for the handle outputs, if any.
+ A TensorShape list of the output shapes of the op, as computed using the
+ C++ shape inference function registered for the op.
Raises:
ValueError: If the C++ shape function returned an error (e.g. because the
@@ -580,16 +576,8 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
according to the shape function).
"""
node_def_str = op.node_def.SerializeToString()
-
- def tensor_to_inference_result(t):
- r = cpp_shape_inference_pb2.CppShapeInferenceResult()
- r.shape.CopyFrom(t.get_shape().as_proto())
- # pylint: disable=protected-access
- r.handle_shape.CopyFrom(t._handle_shape)
- r.handle_dtype = t._handle_dtype
- # pylint: enable=protected-access
- return r.SerializeToString()
- input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
+ input_shapes = [i.get_shape().as_proto().SerializeToString() for i in
+ op.inputs]
input_tensors = [None for i in input_shapes]
if input_tensors_needed:
@@ -608,13 +596,10 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
raise ValueError(err.message)
# Convert TensorShapeProto values in output_shapes.
- result_protos = [
- cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
+ result = [
+ tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s))
for s in output_shapes
]
- result = [r.shape for r in result_protos]
- result_handle_shapes = [r.handle_shape for r in result_protos]
- result_handle_dtypes = [r.handle_dtype for r in result_protos]
if debug_python_shape_fn:
try:
@@ -631,6 +616,4 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
str(result), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
- return {"shapes": result,
- "handle_shapes": result_handle_shapes,
- "handle_dtypes": result_handle_dtypes}
+ return result
diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc
index 57b85e8118..bb5a57e617 100644
--- a/tensorflow/python/framework/cpp_shape_inference.cc
+++ b/tensorflow/python/framework/cpp_shape_inference.cc
@@ -20,32 +20,12 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
#include "tensorflow/python/lib/core/py_func.h"
namespace tensorflow {
namespace swig {
namespace {
-void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
- tensorflow::shape_inference::InferenceContext* c,
- TensorShapeProto* out) {
- if (c->RankKnown(s)) {
- const int32 rank = c->Rank(s);
- for (int i = 0; i < rank; ++i) {
- shape_inference::DimensionHandle d = c->Dim(s, i);
- auto* out_dim = out->add_dim();
- if (c->ValueKnown(d)) {
- out_dim->set_size(c->Value(d));
- } else {
- out_dim->set_size(-1);
- }
- }
- } else {
- out->set_unknown_rank(true);
- }
-}
-
Status RunCppShapeInferenceImpl(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
@@ -69,21 +49,12 @@ Status RunCppShapeInferenceImpl(
// Convert input shapes.
std::vector<TensorShapeProto> input_shapes;
- std::vector<TensorShapeProto> input_handle_shapes;
- std::vector<DataType> input_handle_dtypes;
input_shapes.resize(input_serialized_shapes.size());
- input_handle_shapes.resize(input_serialized_shapes.size());
- input_handle_dtypes.resize(input_serialized_shapes.size());
- CppShapeInferenceResult tmp;
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
- tmp.Clear();
- if (!tmp.ParseFromString(input_serialized_shapes[i])) {
+ if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) {
return errors::InvalidArgument(
"Error parsing shape proto during cpp shape inference");
}
- input_shapes[i].Swap(tmp.mutable_shape());
- input_handle_dtypes[i] = tmp.handle_dtype();
- input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
}
// Convert input tensor values;
@@ -102,23 +73,34 @@ Status RunCppShapeInferenceImpl(
}
// Run shape inference.
+ // TODO(cwhipkey): pass a value for input_tensors_as_shapes.
tensorflow::shape_inference::InferenceContext c(
&node, op_reg_data->op_def, input_shapes, input_tensors,
- {} /* input_tensors_as_shapes */, input_handle_shapes,
- input_handle_dtypes);
+ {} /* input_tensors_as_shapes */);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
// Convert output shapes.
output_tensor_shape_protos->resize(c.num_outputs());
- CppShapeInferenceResult out;
+ TensorShapeProto out;
for (int i = 0; i < c.num_outputs(); ++i) {
+ shape_inference::ShapeHandle s = c.output(i);
out.Clear();
- ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
- ProtoFromShapeHandle(c.output_handle_shape(i), &c,
- out.mutable_handle_shape());
- out.set_handle_dtype(c.output_handle_dtype(i));
+ if (c.RankKnown(s)) {
+ const int32 rank = c.Rank(s);
+ for (int i = 0; i < rank; ++i) {
+ shape_inference::DimensionHandle d = c.Dim(s, i);
+ auto* out_dim = out.add_dim();
+ if (c.ValueKnown(d)) {
+ out_dim->set_size(c.Value(d));
+ } else {
+ out_dim->set_size(-1);
+ }
+ }
+ } else {
+ out.set_unknown_rank(true);
+ }
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
}
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
index f91af8e1a8..a2f52227d6 100644
--- a/tensorflow/python/framework/cpp_shape_inference.h
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -36,7 +36,7 @@ namespace swig {
// inference was successful.
//
// On success, <*output_shapes> is populated with the inferred output shapes (as
-// serialized CppShapeInferenceResult protos).
+// serialized TensorShapeProtos).
// <*output_shapes> must be empty when this function is called.
//
// This is temporary code to be used during the migration
diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto
deleted file mode 100644
index c7d59cc55e..0000000000
--- a/tensorflow/python/framework/cpp_shape_inference.proto
+++ /dev/null
@@ -1,13 +0,0 @@
-syntax = "proto3";
-
-package tensorflow;
-option cc_enable_arenas = true;
-
-import "tensorflow/core/framework/types.proto";
-import "tensorflow/core/framework/tensor_shape.proto";
-
-message CppShapeInferenceResult {
- TensorShapeProto shape = 1;
- TensorShapeProto handle_shape = 2;
- DataType handle_dtype = 3;
-} \ No newline at end of file
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index be9de08ba3..6d80da1300 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -32,8 +32,6 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
-from tensorflow.core.framework import tensor_shape_pb2
-from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@@ -299,10 +297,6 @@ class Tensor(object):
# to easily navigate a computation graph.
self._consumers = []
- # Attributes used for C++ shape inference. Not inspected, only forwarded.
- self._handle_shape = tensor_shape_pb2.TensorShapeProto()
- self._handle_dtype = types_pb2.DT_INVALID
-
@property
def op(self):
"""The `Operation` that produces this tensor as an output."""
@@ -1797,22 +1791,10 @@ def set_shapes_for_outputs(op):
if shapes is None:
raise RuntimeError(
"Shape function for op %s did not return any shapes" % op)
- elif isinstance(shapes, dict):
- # Returned by call_cpp_shape_fn
- shapes_dict = shapes
- shapes = shapes_dict["shapes"]
- handle_shapes = shapes_dict["handle_shapes"]
- handle_dtypes = shapes_dict["handle_dtypes"]
- for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
- # pylint: disable=protected-access
- output._handle_shape = handle_shape
- output._handle_dtype = handle_dtype
- # pylint: enable=protected-access
-
if len(op.outputs) != len(shapes):
raise RuntimeError(
- "Shape function for op %s returned %d shapes but expected %d %s %s" %
- (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
+ "Shape function for op %s returned %d shapes but expected %d" %
+ (op, len(shapes), len(op.outputs)))
for output, s in zip(op.outputs, shapes):
output.set_shape(s)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index fe74b3426c..baa48ec8a4 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -256,16 +256,6 @@ tf_py_test(
)
tf_py_test(
- name = "resource_variable_ops_test",
- size = "small",
- srcs = ["resource_variable_ops_test.py"],
- additional_deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/python:resource_variable_ops",
- ],
-)
-
-tf_py_test(
name = "save_restore_ops_test",
size = "small",
srcs = ["save_restore_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
deleted file mode 100644
index cb4375ce91..0000000000
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.ops.resource_variable_ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.platform import test
-
-
-class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
-
- def testHandleDtypeShapeMatch(self):
- with self.test_session():
- handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- with self.assertRaises(ValueError):
- resource_variable_ops.create_variable_op(
- handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
- with self.assertRaises(ValueError):
- resource_variable_ops.create_variable_op(
- handle, constant_op.constant([0], dtype=dtypes.int32)).run()
- resource_variable_ops.create_variable_op(
- handle, constant_op.constant(0, dtype=dtypes.int32)).run()
-
- def testDtypeSurvivesIdentity(self):
- with self.test_session():
- handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- id_handle = array_ops.identity(handle)
- resource_variable_ops.create_variable_op(
- id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
deleted file mode 100644
index 7db9731e19..0000000000
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Ops to use variables as resources."""
-
-# pylint: disable=g-bad-name
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import common_shapes
-from tensorflow.python.framework import ops
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_resource_variable_ops import *
-# pylint: enable=wildcard-import
-
-ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn)
-ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn)