aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/shape_util.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-05 10:10:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-05 10:14:25 -0700
commit19220809123803af8fdcbbe628ee9f80fb4521c8 (patch)
tree674109a5d013d044c31c5b265307f4a7d862b2c2 /tensorflow/compiler/tf2xla/shape_util.cc
parent94d52acdc0087d5829f220c4d46eea67e0d30305 (diff)
[TF:XLA] Refactor XLAShapeToTensorShape so it returns an error Status if passed an XLA tuple shape, rather than CHECK-failing.
PiperOrigin-RevId: 160971216
Diffstat (limited to 'tensorflow/compiler/tf2xla/shape_util.cc')
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index f5ecb51a5b..9d1992205b 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -24,12 +24,18 @@ limitations under the License.
namespace tensorflow {
// Convert an XLA Shape into the equivalent TensorFlow shape.
-TensorShape XLAShapeToTensorShape(const xla::Shape& shape) {
- TensorShape tensor_shape;
+Status XLAShapeToTensorShape(const xla::Shape& shape,
+ TensorShape* tensor_shape) {
+ if (xla::ShapeUtil::IsTuple(shape)) {
+ return errors::InvalidArgument("XLA shape ",
+ xla::ShapeUtil::HumanString(shape),
+ " cannot be converted to a TensorShape");
+ }
+ *tensor_shape = TensorShape();
for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) {
- tensor_shape.AddDim(shape.dimensions(i));
+ tensor_shape->AddDim(shape.dimensions(i));
}
- return tensor_shape;
+ return Status::OK();
}
// Convert a TensorShape into the equivalent XLA Shape proto.