author    Vijay Vasudevan <vrv@google.com>  2017-04-10 15:29:15 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-10 16:50:29 -0700
commit    24a95ae389e1c76e771ac33d66e0ec40a236260f (patch)
tree      af9c67c7a984af7bae38fa096457cc76e7a3fda3 /tensorflow/cc/framework
parent    31559fd5317abd10c457575a1718b88b3917446c (diff)
Change Placeholder to support partial shapes and enforce scalar shapes.
Adds tests; testScalar failed before with the original placeholder because
it treated [] as "?" instead of scalar. Now you can actually specify
[] and it means 'scalar'.
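To make the `[]`-versus-`"?"` distinction concrete, here is a minimal sketch in plain Python (not TensorFlow code; `is_compatible` is a hypothetical helper) of a shape-checking rule that treats `[]` as rank 0 rather than as "unknown":

```python
# Illustrative sketch only -- not TensorFlow code. `is_compatible` is a
# hypothetical helper modeling the rule this commit enforces:
# None means "unknown rank" (anything is accepted), [] means rank 0 (scalar).
def is_compatible(declared, fed):
    if declared is None:              # unknown rank: any fed shape is accepted
        return True
    if len(declared) != len(fed):     # [] is rank 0, so only a scalar matches
        return False
    # a None dimension (e.g. [None, 10]) matches any size in that position
    return all(d is None or d == f for d, f in zip(declared, fed))
```

Under this rule, `is_compatible([], [2])` is False (a scalar placeholder rejects a vector), which is exactly the check the original placeholder skipped.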
Added a backwards compatibility test using a graph_def generated
from a previous tf version.
RELNOTES: tf.placeholder can now represent scalar shapes and partially known
shapes accurately. Note: this change can break already-buggy programs, because
it makes placeholder shape handling more consistent across graph serializations.
Note: this change can break an already-buggy pipeline, namely one that serializes a graph using a partially unknown shape (e.g., [None, 10] in a tf.placeholder) but then reloads the graph using import_graph_def and feeds it a different shape. Prior to this change, serializing the graph_def lost the [None, 10] shape requirement, so you could feed anything you wanted. With this change, the graph is serialized with [None, 10], so reloading it and feeding a different shape fails. In these cases, the fix is to correct the original placeholder shape to match what is fed; that is a bug in the user's program, not in TF.
Note 2: A Python user who wrote tf.placeholder(shape=[]) already got scalar checking
in the same process, via the Python shape inference code. However, a C++ user who wrote
Placeholder(shape=[]) did not get scalar shape checking; a C++ program that passed
Placeholder(shape=[]) expecting it to be interpreted as "UnknownShape" will break --
however, that user could already have used an {unknown_rank: true} proto, and should
not have expected the legacy behavior.
Backwards compatibility: old graphs that have shape = {} in the proto also have a
graph_def_version <= 21, so new binaries will interpret the pre-change default value
of shape as "UnknownShape", just as before.
Forwards compatibility: new graphs will produce, by default, shape = { unknown_rank: true };
old binaries will use PartialTensorShape's parsing code to parse that proto into an
object whose shape.dims() <= 0, so these binaries will continue to interpret the
default shape as "unknown shape" without crashing and without producing new errors.
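The version-gated interpretation described above can be sketched as follows (a simplified stand-in: `interpret_shape` is hypothetical, and `dims`/`unknown_rank` only loosely mirror TensorShapeProto fields):

```python
# Simplified, hypothetical stand-in for the compatibility rules described
# above; not a TensorFlow API.
def interpret_shape(dims, unknown_rank, graph_def_version):
    if unknown_rank:                   # new graphs mark unknown rank explicitly
        return "unknown"
    if not dims and graph_def_version <= 21:
        return "unknown"               # legacy default: an empty proto meant "?"
    return tuple(dims)                 # an empty shape from a new graph is a scalar
```

The key point is the version gate: an empty shape from an old graph still reads as "unknown", while an empty shape from a new graph is a genuine scalar.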
Fixes #9103
Change: 152751019
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--  tensorflow/cc/framework/cc_op_gen.cc | 33
 1 file changed, 23 insertions(+), 10 deletions(-)
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 22cd7fb0d4..26f15975c1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -126,7 +126,11 @@ string PrintString(const string& str) {
   return strings::StrCat("\"", str_util::CEscape(str), "\"");
 }
 
-string PrintTensorShape(const TensorShape& shape) {
+string PrintTensorShape(const TensorShapeProto& shape_proto) {
+  PartialTensorShape shape(shape_proto);
+  if (shape.IsIdenticalTo(PartialTensorShape())) {
+    return "::tensorflow::PartialTensorShape() /* unknown */";
+  }
   string ret = "{";
   for (int d = 0; d < shape.dims(); ++d) {
     if (d > 0) strings::StrAppend(&ret, ", ");
@@ -188,6 +192,12 @@ string PrintTensor(const TensorProto& tensor_proto) {
   }
 }
 
+string PrintTensorProto(const TensorProto& proto) {
+  return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
+                         PrintTensorShape(proto.tensor_shape()),
+                         ").AsTensorProto()");
+}
+
 string PrintAttrValue(string op, const AttrValue& attr_value) {
   switch (attr_value.value_case()) {
     case AttrValue::kS:
@@ -203,12 +213,9 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
     case AttrValue::kType:
       return EnumName_DataType(attr_value.type());
     case AttrValue::kShape:
-      return PrintTensorShape(TensorShape(attr_value.shape()));
+      return PrintTensorShape(attr_value.shape());
     case AttrValue::kTensor:
-      return strings::StrCat(
-          "Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ",
-          PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())),
-          ").AsTensorProto()");
+      return PrintTensorProto(attr_value.tensor());
     case AttrValue::kList: {
       string ret = "{";
       if (attr_value.list().s_size() > 0) {
@@ -241,8 +248,14 @@ string PrintAttrValue(string op, const AttrValue& attr_value) {
       } else if (attr_value.list().shape_size() > 0) {
         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
           if (i > 0) strings::StrAppend(&ret, ", ");
-          strings::StrAppend(
-              &ret, PrintTensorShape(TensorShape(attr_value.list().shape(i))));
+          strings::StrAppend(&ret,
+                             PrintTensorShape(attr_value.list().shape(i)));
+        }
+      } else if (attr_value.list().tensor_size() > 0) {
+        for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
+          if (i > 0) strings::StrAppend(&ret, ", ");
+          strings::StrAppend(&ret,
+                             PrintTensorProto(attr_value.list().tensor(i)));
         }
       }
       strings::StrAppend(&ret, "}");
@@ -292,8 +305,8 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
       {"list(bool)", {"gtl::ArraySlice<bool>", true}},
       {"type", {"DataType", false}},
      {"list(type)", {"DataTypeSlice", true}},
-      {"shape", {"TensorShape", false}},
-      {"list(shape)", {"gtl::ArraySlice<TensorShape>", true}},
+      {"shape", {"PartialTensorShape", false}},
+      {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
       {"tensor", {"TensorProto", true}},
       {"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
       {"func", {"NameAttrList", true}},
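The core of the PrintTensorShape change can be modeled in a few lines of Python (illustrative only; the real generator operates on a TensorShapeProto, not a Python list):

```python
# Python model of the C++ PrintTensorShape logic in the diff above
# (illustrative sketch, not generated code or a TensorFlow API).
def print_tensor_shape(dims, unknown_rank=False):
    if unknown_rank:
        # an entirely unknown shape prints as a default PartialTensorShape
        return "::tensorflow::PartialTensorShape() /* unknown */"
    # known shapes, including the rank-0 scalar shape, print as a brace list
    return "{" + ", ".join(str(d) for d in dims) + "}"
```

Note that an empty known shape yields `{}` (a scalar), distinct from the unknown-rank case, which is precisely the distinction the old TensorShape-based code could not express.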