author    Vijay Vasudevan <vrv@google.com>  2016-11-11 13:05:33 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-11-11 13:24:28 -0800
commit    20df37f40296662519b89fd6658e43fce7c000b7 (patch)
tree      a3a61939e364b0d1205f558c6229a1b6c9f5c831
parent    b772f470d5901684c6a125020e36faa3d1d91744 (diff)
Switch more ops to using C++ shape functions instead of Python.
Change: 138909864
-rw-r--r--  tensorflow/core/ops/state_ops.cc          | 63
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 49
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py    | 41
-rw-r--r--  tensorflow/python/ops/random_ops.py       |  2
-rw-r--r--  tensorflow/python/ops/state_ops.py        | 34
5 files changed, 71 insertions(+), 118 deletions(-)
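
For context (not part of the commit), the before/after registration pattern
looks roughly like this; SomeOp and _SomeOpShape are hypothetical names used
only for illustration:

    # Before: shape inference written in Python and registered per op.
    @ops.RegisterShape("SomeOp")
    def _SomeOpShape(op):
      return [op.inputs[0].get_shape()]

    # After: the op's C++ registration carries a .SetShapeFn(...), and the
    # Python side just delegates to it.
    ops.RegisterShape("SomeOp")(common_shapes.call_cpp_shape_fn)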
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index f6d21b909a..b300fbbe26 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
@@ -445,6 +446,63 @@ use_locking: If True, the operation will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
+namespace {
+
+Status ScatterNdUpdateShape(InferenceContext* c) {
+  ShapeHandle ref_shape = c->input(0);
+  ShapeHandle indices_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
+  ShapeHandle updates_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
+
+  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
+    const int64 outer_dims = c->Rank(indices_shape) - 1;
+    const DimensionHandle ixdim = c->Dim(indices_shape, -1);
+
+    // We can only do more validation if the last dimension of indices
+    // is a known value.
+    if (c->ValueKnown(ixdim)) {
+      int64 ix = c->Value(ixdim);
+      ShapeHandle unused;
+      ShapeHandle prefix_indices;
+      TF_RETURN_IF_ERROR(
+          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
+      ShapeHandle prefix_updates;
+      TF_RETURN_IF_ERROR(
+          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
+
+      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
+      if (!s.ok()) {
+        return errors::InvalidArgument(
+            "The outer ", outer_dims, " dimensions of indices.shape=",
+            c->DebugString(indices_shape), " must match the outer ",
+            outer_dims, " dimensions of updates.shape=",
+            c->DebugString(updates_shape), ": ", s.error_message());
+      }
+
+      ShapeHandle suffix_ref;
+      TF_RETURN_IF_ERROR(c->Subshape(ref_shape, ix, &suffix_ref));
+      ShapeHandle suffix_updates;
+      TF_RETURN_IF_ERROR(
+          c->Subshape(updates_shape, outer_dims, &suffix_updates));
+      s = c->Merge(suffix_ref, suffix_updates, &unused);
+      if (!s.ok()) {
+        return errors::InvalidArgument(
+            "The inner ", c->Rank(ref_shape) - ix, " dimensions of ref.shape=",
+            c->DebugString(ref_shape), " must match the inner ",
+            c->Rank(updates_shape) - outer_dims,
+            " dimensions of updates.shape=", c->DebugString(updates_shape),
+            ": ", s.error_message());
+      }
+    }
+  }
+
+  c->set_output(0, ref_shape);
+  return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("ScatterNdUpdate")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
@@ -453,6 +511,7 @@ REGISTER_OP("ScatterNdUpdate")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
+ .SetShapeFn(ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse `updates` to individual values or slices within a given
variable according to `indices`.
@@ -509,6 +568,7 @@ REGISTER_OP("ScatterNdAdd")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse addition between `updates` and individual values or slices
within a given variable according to `indices`.
@@ -565,6 +625,7 @@ REGISTER_OP("ScatterNdSub")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterNdUpdateShape)
.Doc(R"doc(
Applies sparse subtraction between `updates` and individual values or slices
within a given variable according to `indices`.
@@ -624,6 +685,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want
// .Attr("T: numbertype")
// .Attr("Tindices: {int32, int64}")
// .Attr("use_locking: bool = false")
+// .SetShapeFn(ScatterNdUpdateShape)
// .Doc(
// R"doc(Applies sparse subtraction between `updates` and individual
// values or slices within a given variable according to `indices`.
@@ -679,6 +741,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want
// .Attr("T: numbertype")
// .Attr("Tindices: {int32, int64}")
// .Attr("use_locking: bool = false")
+// .SetShapeFn(ScatterNdUpdateShape)
// .Doc(
// R"doc(Applies sparse subtraction between `updates` and individual
// values or slices within a given variable according to `indices`.
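
As a worked example of the constraints ScatterNdUpdateShape enforces (an
editorial sketch, not part of the commit; it assumes a scatter_nd_update
Python wrapper around the ScatterNdUpdate op):

    # ref:     shape [8]     -> the variable being updated
    # indices: shape [4, 1]  -> outer_dims = rank(indices) - 1 = 1, ix = 1
    # updates: shape [4]
    #
    # Check 1: indices.shape[:outer_dims] = [4] merges with
    #          updates.shape[:outer_dims] = [4]        -> OK
    # Check 2: ref.shape[ix:] = [] merges with
    #          updates.shape[outer_dims:] = []         -> OK
    # Output 0 takes ref's shape, [8].
    #
    # With updates of shape [5], check 1 fails and graph construction now
    # reports InvalidArgument from the C++ shape function.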
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 1efc2d5380..e4155fadd4 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3018,54 +3018,9 @@ ops.RegisterShape("RefNextIteration")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ControlTrigger")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("NoOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Abort")(common_shapes.call_cpp_shape_fn)
-
-
-@ops.RegisterShape("LoopCond")
-def _LoopCondShape(op):
- """Shape function for the LoopCond op."""
- return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
-
-
+ops.RegisterShape("LoopCond")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Merge")(common_shapes.call_cpp_shape_fn)
-
-
-def _MergeShape(op):
-  """Shape function for the Merge op.
-
-  The Merge op takes many inputs of arbitrary shapes, and produces a
-  first output that is one of those inputs, and a second scalar
-  output.
-
-  If all input shapes are known and have the same rank, the output
-  shape must have that rank, otherwise the output shape is unknown.
-  Each output dimension is specified only if that dimension in all
-  inputs are the same.
-
-  Args:
-    op: A Merge Operation.
-
-  Returns:
-    A single-element list containing the Shape of the Merge op.
-
-  """
-  output_shape = op.inputs[0].get_shape()
-  if output_shape.dims is None:
-    return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
-  else:
-    for input_ in op.inputs[1:]:
-      input_shape = input_.get_shape()
-      if input_shape.dims is None or input_shape.ndims != output_shape.ndims:
-        return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
-      else:
-        output_shape = tensor_shape.TensorShape(
-            [input_dim.value if input_dim.value == output_dim.value else None
-             for input_dim, output_dim in zip(input_shape.dims,
-                                              output_shape.dims)])
-    return [output_shape, tensor_shape.scalar()]
-
-ops.RegisterShape("RefMerge")(_MergeShape)
-
-
+ops.RegisterShape("RefMerge")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("RefSelect")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("RefSwitch")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Switch")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index d2de88a9ca..ecadfa62a6 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1097,45 +1097,8 @@ ops.RegisterShape("BarrierInsertMany")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("GetSessionHandle")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("GetSessionTensor")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("DeleteSessionTensor")(common_shapes.call_cpp_shape_fn)
-
-
-@ops.RegisterShape("DynamicPartition")
-def _DynamicPartitionShape(op):
- """Shape function for data_flow_ops.dynamic_partition."""
- data_shape = op.inputs[0].get_shape()
- partitions_shape = op.inputs[1].get_shape()
- # If we don't know the rank of partitions, we don't know anything
- mid = partitions_shape.ndims
- if mid is None:
- result_shape = tensor_shape.unknown_shape()
- else:
- # data_shape must start with partitions_shape
- partitions_shape.assert_is_compatible_with(data_shape[:mid])
- # The partition shape is dynamic in the 0th dimension, and matches
- # data_shape in the remaining dimensions.
- result_shape = tensor_shape.TensorShape([None]).concatenate(
- data_shape[mid:])
- return [result_shape] * op.get_attr("num_partitions")
-
-
-@ops.RegisterShape("DynamicStitch")
-def _DynamicStitchShape(op):
- """Shape function for data_flow_ops.dynamic_stitch."""
- num_partitions = op.get_attr("N")
- indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]]
- data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]]
- output_shape = tensor_shape.unknown_shape()
- extra_shape = tensor_shape.TensorShape(None)
- for indices_shape, data_shape in zip(indices_shapes, data_shapes):
- indices_ndims = indices_shape.ndims
- if indices_ndims is not None:
- # Assert that data_shape starts with indices_shape
- indices_shape.merge_with(data_shape[:indices_ndims])
- # The rest belongs to output
- extra_shape = extra_shape.merge_with(data_shape[indices_ndims:])
- return [tensor_shape.TensorShape([None]).concatenate(extra_shape)]
-
-
+ops.RegisterShape("DynamicPartition")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("DynamicStitch")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("LookupTableFind")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("LookupTableInsert")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("LookupTableImport")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index f257acae0c..1755fb57a8 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -460,4 +460,4 @@ def _RandomShape(op):
  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
-ops.RegisterShape("RandomShuffle")(common_shapes.unchanged_shape)
+ops.RegisterShape("RandomShuffle")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index e196bdd3ff..775342a82c 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -219,34 +219,6 @@ ops.RegisterShape("ScatterDiv")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterMul")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterSub")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ScatterUpdate")(common_shapes.call_cpp_shape_fn)
-
-
-@ops.RegisterShape("ScatterNdAdd")
-@ops.RegisterShape("ScatterNdSub")
-@ops.RegisterShape("ScatterNdMul")
-@ops.RegisterShape("ScatterNdDiv")
-@ops.RegisterShape("ScatterNdUpdate")
-def scatter_nd_update_shape(op):
- """Shape function for the ScatterNd update ops."""
- ref_shape = op.inputs[0].get_shape()
- indices_shape = op.inputs[1].get_shape()
- updates_shape = op.inputs[2].get_shape()
-
- if indices_shape.ndims is not None and ref_shape.ndims is not None:
- outer_dims = len(indices_shape) - 1
- ixdim = indices_shape[-1].value or 0
-
- if not indices_shape[:outer_dims].is_compatible_with(
- updates_shape[:outer_dims]):
- raise ValueError("The outer %d dimensions of indices.shape=%s must "
- "match the outer %d dimensions of updates.shape=%s" % (
- outer_dims, indices_shape, outer_dims,
- updates_shape))
-
- if not ref_shape[ixdim:].is_compatible_with(updates_shape[outer_dims:]):
- raise ValueError("The inner %d dimensions of ref.shape=%s must match "
- "the inner %d dimensions of updates.shape=%s" % (
- len(ref_shape)-ixdim, ref_shape,
- len(updates_shape)-outer_dims, updates_shape))
-
- return [ref_shape]
+ops.RegisterShape("ScatterNdAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ScatterNdSub")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ScatterNdUpdate")(common_shapes.call_cpp_shape_fn)