diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 2a8b9f9bee..88fc03826a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -429,6 +429,58 @@ REGISTER_OP("UnravelIndex") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }); +REGISTER_OP("BroadcastTo") + .Input("input: T") + .Input("shape: Tidx") + .Output("output: T") + .Attr("T: type") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle in = c->input(0); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); + + if (!c->RankKnown(out)) { + // We have no information about the shape of the output. + c->set_output(0, out); + return Status::OK(); + } + + if (!c->RankKnown(in)) { + // We have no information about the shape of the input, + // nothing to do here. + c->set_output(0, out); + return Status::OK(); + } + if (c->Rank(out) < c->Rank(in)) { + return errors::InvalidArgument("Cannot broadcast a tensor with shape ", + c->DebugString(in), " shape ", + c->DebugString(out)); + } + + int32 in_offset = c->Rank(out) - c->Rank(in); + for (int32 i = 0; i < c->Rank(out); ++i) { + DimensionHandle dim = c->Dim(out, i); + if (c->ValueKnown(dim)) { + // The first in_offset dimensions for input will be expanded with 1, + // so no check needed. + if (i >= in_offset) { + DimensionHandle in_dim = c->Dim(in, i - in_offset); + if (c->ValueKnown(in_dim)) { + if (c->Value(dim) % c->Value(in_dim) != 0) { + return errors::InvalidArgument( + "Cannot broadcast a tensor with shape ", c->DebugString(in), + " shape ", c->DebugString(out)); + } + } + } + } + } + + c->set_output(0, out); + return Status::OK(); + }); + // -------------------------------------------------------------------------- // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph // in the N == 1 case to remove the node. |