aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 2a8b9f9bee..88fc03826a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -429,6 +429,58 @@ REGISTER_OP("UnravelIndex")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+REGISTER_OP("BroadcastTo")
+ .Input("input: T")
+ .Input("shape: Tidx")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle in = c->input(0);
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
+
+ if (!c->RankKnown(out)) {
+ // We have no information about the shape of the output.
+ c->set_output(0, out);
+ return Status::OK();
+ }
+
+ if (!c->RankKnown(in)) {
+ // We have no information about the shape of the input,
+ // nothing to do here.
+ c->set_output(0, out);
+ return Status::OK();
+ }
+ if (c->Rank(out) < c->Rank(in)) {
+ return errors::InvalidArgument("Cannot broadcast a tensor with shape ",
+ c->DebugString(in), " shape ",
+ c->DebugString(out));
+ }
+
+ int32 in_offset = c->Rank(out) - c->Rank(in);
+ for (int32 i = 0; i < c->Rank(out); ++i) {
+ DimensionHandle dim = c->Dim(out, i);
+ if (c->ValueKnown(dim)) {
+ // The first in_offset dimensions for input will be expanded with 1,
+ // so no check needed.
+ if (i >= in_offset) {
+ DimensionHandle in_dim = c->Dim(in, i - in_offset);
+ if (c->ValueKnown(in_dim)) {
+ if (c->Value(dim) % c->Value(in_dim) != 0) {
+ return errors::InvalidArgument(
+ "Cannot broadcast a tensor with shape ", c->DebugString(in),
+ " shape ", c->DebugString(out));
+ }
+ }
+ }
+ }
+ }
+
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
// in the N == 1 case to remove the node.