aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-04-10 15:29:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-10 16:50:29 -0700
commit24a95ae389e1c76e771ac33d66e0ec40a236260f (patch)
treeaf9c67c7a984af7bae38fa096457cc76e7a3fda3 /tensorflow/cc/framework
parent31559fd5317abd10c457575a1718b88b3917446c (diff)
Change Placeholder to support partial shapes and enforce scalar shapes.
Adds tests; testScalar failed before with the original placeholder because it treated [] as "?" instead of scalar. Now you can actually specify [] and it means 'scalar'. Added a backwards compatibility test using a graph_def generated from a previous tf version. RELNOTES: tf.placeholder can represent scalar shapes and partially known shapes accurately. Note, this change can break already buggy programs because it makes placeholder shape handling more consistent across graph serializations. Note: There are some buggy cases where this change can break a buggy pipeline: namely those that serialize a graph using an unknown shape (e.g., [None, 10] in a tf.placeholder, but then reload the graph using import_graph_def and feed it a different shape. Prior to this change, serializing the graph_def loses the [None, 10] shape requirement, so you can feed anything you want. This change makes it so that you serialize the graph with [None, 10], and so when you reload it, it would fail if you fed it a different shape. In these cases, the fix is to correct the original placeholder shape to match what you feed it, which is not a bug in TF but in the user's program. Note 2: A python user that did tf.placeholder(shape=[]) would get scalar checking in the same process due to python shape inference code. However, a C++ user that did Placeholder(shape=[]) would not have gotten scalar shape checking; a C++ program that passed Placeholder(shape=[]) that expects to interpret this as "UnknownShape" would break -- however, that user could have already used an {unknown_shape: true} proto, and should not have expected the legacy behavior. Backwards compatibility: Old graphs that have shape = {} in the proto will also have a graph_def_version <= 21, so the default value of shape prior to this change will be interpreted by new binaries as "UnknownShape" just as before. Forwards compatibility: new graphs will produce, by default, shape={ unknown rank: true}; old binaries will use PartialTensorShape's parsing code to parse that proto into an object whose shape.dims() <= 0, and so these binaries will continue to interpret the default shape as "unknown shape" without crashing and without producing new errors. Fixes #9103 Change: 152751019
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc33
1 files changed, 23 insertions, 10 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 22cd7fb0d4..26f15975c1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -126,7 +126,11 @@ string PrintString(const string& str) {
return strings::StrCat("\"", str_util::CEscape(str), "\"");
}
-string PrintTensorShape(const TensorShape& shape) {
+string PrintTensorShape(const TensorShapeProto& shape_proto) {
+ PartialTensorShape shape(shape_proto);
+ if (shape.IsIdenticalTo(PartialTensorShape())) {
+ return "::tensorflow::PartialTensorShape() /* unknown */";
+ }
string ret = "{";
for (int d = 0; d < shape.dims(); ++d) {
if (d > 0) strings::StrAppend(&ret, ", ");
@@ -188,6 +192,12 @@ string PrintTensor(const TensorProto& tensor_proto) {
}
}
+string PrintTensorProto(const TensorProto& proto) {
+ return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
+ PrintTensorShape(proto.tensor_shape()),
+ ").AsTensorProto()");
+}
+
string PrintAttrValue(string op, const AttrValue& attr_value) {
switch (attr_value.value_case()) {
case AttrValue::kS:
@@ -203,12 +213,9 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
case AttrValue::kType:
return EnumName_DataType(attr_value.type());
case AttrValue::kShape:
- return PrintTensorShape(TensorShape(attr_value.shape()));
+ return PrintTensorShape(attr_value.shape());
case AttrValue::kTensor:
- return strings::StrCat(
- "Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ",
- PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())),
- ").AsTensorProto()");
+ return PrintTensorProto(attr_value.tensor());
case AttrValue::kList: {
string ret = "{";
if (attr_value.list().s_size() > 0) {
@@ -241,8 +248,14 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
} else if (attr_value.list().shape_size() > 0) {
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(
- &ret, PrintTensorShape(TensorShape(attr_value.list().shape(i))));
+ strings::StrAppend(&ret,
+ PrintTensorShape(attr_value.list().shape(i)));
+ }
+ } else if (attr_value.list().tensor_size() > 0) {
+ for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret,
+ PrintTensorProto(attr_value.list().tensor(i)));
}
}
strings::StrAppend(&ret, "}");
@@ -292,8 +305,8 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"list(bool)", {"gtl::ArraySlice<bool>", true}},
{"type", {"DataType", false}},
{"list(type)", {"DataTypeSlice", true}},
- {"shape", {"TensorShape", false}},
- {"list(shape)", {"gtl::ArraySlice<TensorShape>", true}},
+ {"shape", {"PartialTensorShape", false}},
+ {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
{"tensor", {"TensorProto", true}},
{"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
{"func", {"NameAttrList", true}},