diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-11-11 13:05:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-11 13:24:28 -0800 |
commit | 20df37f40296662519b89fd6658e43fce7c000b7 (patch) | |
tree | a3a61939e364b0d1205f558c6229a1b6c9f5c831 | |
parent | b772f470d5901684c6a125020e36faa3d1d91744 (diff) |
Switch moreops to using c++ shape functions instead of python.
Change: 138909864
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 63 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 49 | ||||
-rw-r--r-- | tensorflow/python/ops/data_flow_ops.py | 41 | ||||
-rw-r--r-- | tensorflow/python/ops/random_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 34 |
5 files changed, 71 insertions, 118 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index f6d21b909a..b300fbbe26 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -18,6 +18,7 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; @@ -445,6 +446,63 @@ use_locking: If True, the operation will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); +namespace { + +Status ScatterNdUpdateShape(InferenceContext* c) { + ShapeHandle ref_shape = c->input(0); + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); + + if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { + const int64 outer_dims = c->Rank(indices_shape) - 1; + const DimensionHandle ixdim = c->Dim(indices_shape, -1); + + // We can only do more validation if the last dimension of indices + // is a known value. + if (c->ValueKnown(ixdim)) { + int64 ix = c->Value(ixdim); + ShapeHandle unused; + ShapeHandle prefix_indices; + TF_RETURN_IF_ERROR( + c->Subshape(indices_shape, 0, outer_dims, &prefix_indices)); + ShapeHandle prefix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); + + Status s = c->Merge(prefix_indices, prefix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The outer ", outer_dims, " dimensions of indices.shape=", + c->DebugString(indices_shape), "must match the outer ", outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); + } + + ShapeHandle suffix_ref; + TF_RETURN_IF_ERROR(c->Subshape(ref_shape, ix, &suffix_ref)); + ShapeHandle suffix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, outer_dims, &suffix_updates)); + s = c->Merge(suffix_ref, suffix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The inner ", c->Rank(ref_shape) - ix, " dimensions of ref.shape=", + c->DebugString(ref_shape), "must match the inner ", + c->Rank(updates_shape) - outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); + } + } + } + + c->set_output(0, ref_shape); + return Status::OK(); +} + +} // namespace + REGISTER_OP("ScatterNdUpdate") .Input("ref: Ref(T)") .Input("indices: Tindices") @@ -453,6 +511,7 @@ REGISTER_OP("ScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") + .SetShapeFn(ScatterNdUpdateShape) .Doc(R"doc( Applies sparse `updates` to individual values or slices within a given variable according to `indices`. @@ -509,6 +568,7 @@ REGISTER_OP("ScatterNdAdd") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .SetShapeFn(ScatterNdUpdateShape) .Doc(R"doc( Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`. @@ -565,6 +625,7 @@ REGISTER_OP("ScatterNdSub") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") + .SetShapeFn(ScatterNdUpdateShape) .Doc(R"doc( Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. @@ -624,6 +685,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want // .Attr("T: numbertype") // .Attr("Tindices: {int32, int64}") // .Attr("use_locking: bool = false") +// .SetShapeFn(ScatterNdUpdateShape) // .Doc( // R"doc(Applies sparse subtraction between `updates` and individual // values or slices within a given variable according to `indices`. @@ -679,6 +741,7 @@ output_ref: Same as ref. Returned as a convenience for operations that want // .Attr("T: numbertype") // .Attr("Tindices: {int32, int64}") // .Attr("use_locking: bool = false") +// .SetShapeFn(ScatterNdUpdateShape) // .Doc( // R"doc(Applies sparse subtraction between `updates` and individual // values or slices within a given variable according to `indices`. diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 1efc2d5380..e4155fadd4 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -3018,54 +3018,9 @@ ops.RegisterShape("RefNextIteration")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ControlTrigger")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("NoOp")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("Abort")(common_shapes.call_cpp_shape_fn) - - -@ops.RegisterShape("LoopCond") -def _LoopCondShape(op): - """Shape function for the LoopCond op.""" - return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())] - - +ops.RegisterShape("LoopCond")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("Merge")(common_shapes.call_cpp_shape_fn) - - -def _MergeShape(op): - """Shape function for the Merge op. - - The Merge op takes many inputs of arbitrary shapes, and produces a - first output that is one of those inputs, and a second scalar - output. - - If all input shapes are known and have the same rank, the output - shape must have that rank, otherwise the output shape is unknown. - Each output dimension is specified only if that dimension in all - inputs are the same. - - Args: - op: A Merge Operation. - - Returns: - A single-element list containing the Shape of the Merge op. - - """ - output_shape = op.inputs[0].get_shape() - if output_shape.dims is None: - return [tensor_shape.unknown_shape(), tensor_shape.scalar()] - else: - for input_ in op.inputs[1:]: - input_shape = input_.get_shape() - if input_shape.dims is None or input_shape.ndims != output_shape.ndims: - return [tensor_shape.unknown_shape(), tensor_shape.scalar()] - else: - output_shape = tensor_shape.TensorShape( - [input_dim.value if input_dim.value == output_dim.value else None - for input_dim, output_dim in zip(input_shape.dims, - output_shape.dims)]) - return [output_shape, tensor_shape.scalar()] - -ops.RegisterShape("RefMerge")(_MergeShape) - - +ops.RegisterShape("RefMerge")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("RefSelect")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("RefSwitch")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("Switch")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index d2de88a9ca..ecadfa62a6 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1097,45 +1097,8 @@ ops.RegisterShape("BarrierInsertMany")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("GetSessionHandle")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("GetSessionTensor")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("DeleteSessionTensor")(common_shapes.call_cpp_shape_fn) - - -@ops.RegisterShape("DynamicPartition") -def _DynamicPartitionShape(op): - """Shape function for data_flow_ops.dynamic_partition.""" - data_shape = op.inputs[0].get_shape() - partitions_shape = op.inputs[1].get_shape() - # If we don't know the rank of partitions, we don't know anything - mid = partitions_shape.ndims - if mid is None: - result_shape = tensor_shape.unknown_shape() - else: - # data_shape must start with partitions_shape - partitions_shape.assert_is_compatible_with(data_shape[:mid]) - # The partition shape is dynamic in the 0th dimension, and matches - # data_shape in the remaining dimensions. - result_shape = tensor_shape.TensorShape([None]).concatenate( - data_shape[mid:]) - return [result_shape] * op.get_attr("num_partitions") - - -@ops.RegisterShape("DynamicStitch") -def _DynamicStitchShape(op): - """Shape function for data_flow_ops.dynamic_stitch.""" - num_partitions = op.get_attr("N") - indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]] - data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]] - output_shape = tensor_shape.unknown_shape() - extra_shape = tensor_shape.TensorShape(None) - for indices_shape, data_shape in zip(indices_shapes, data_shapes): - indices_ndims = indices_shape.ndims - if indices_ndims is not None: - # Assert that data_shape starts with indices_shape - indices_shape.merge_with(data_shape[:indices_ndims]) - # The rest belongs to output - extra_shape = extra_shape.merge_with(data_shape[indices_ndims:]) - return [tensor_shape.TensorShape([None]).concatenate(extra_shape)] - - +ops.RegisterShape("DynamicPartition")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("DynamicStitch")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("LookupTableFind")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("LookupTableInsert")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("LookupTableImport")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index f257acae0c..1755fb57a8 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -460,4 +460,4 @@ def _RandomShape(op): return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0]) -ops.RegisterShape("RandomShuffle")(common_shapes.unchanged_shape) +ops.RegisterShape("RandomShuffle")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index e196bdd3ff..775342a82c 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -219,34 +219,6 @@ ops.RegisterShape("ScatterDiv")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterMul")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterSub")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ScatterUpdate")(common_shapes.call_cpp_shape_fn) - - -@ops.RegisterShape("ScatterNdAdd") -@ops.RegisterShape("ScatterNdSub") -@ops.RegisterShape("ScatterNdMul") -@ops.RegisterShape("ScatterNdDiv") -@ops.RegisterShape("ScatterNdUpdate") -def scatter_nd_update_shape(op): - """Shape function for the ScatterNd update ops.""" - ref_shape = op.inputs[0].get_shape() - indices_shape = op.inputs[1].get_shape() - updates_shape = op.inputs[2].get_shape() - - if indices_shape.ndims is not None and ref_shape.ndims is not None: - outer_dims = len(indices_shape) - 1 - ixdim = indices_shape[-1].value or 0 - - if not indices_shape[:outer_dims].is_compatible_with( - updates_shape[:outer_dims]): - raise ValueError("The outer %d dimensions of indices.shape=%s must " - "match the outer %d dimensions of updates.shape=%s" % ( - outer_dims, indices_shape, outer_dims, - updates_shape)) - - if not ref_shape[ixdim:].is_compatible_with(updates_shape[outer_dims:]): - raise ValueError("The inner %d dimensions of ref.shape=%s must match " - "the inner %d dimensions of updates.shape=%s" % ( - len(ref_shape)-ixdim, ref_shape, - len(updates_shape)-outer_dims, updates_shape)) - - return [ref_shape] +ops.RegisterShape("ScatterNdAdd")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ScatterNdSub")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ScatterNdUpdate")(common_shapes.call_cpp_shape_fn) |