diff options
author | 2017-12-01 16:24:27 -0800 | |
---|---|---|
committer | 2017-12-01 16:29:19 -0800 | |
commit | da105bfabc311840024b40d484dd1cd234697e23 (patch) | |
tree | 8f24d631b31d3282ff7a89a8bd23ccdd3f6d7060 | |
parent | 1ee6d7ccbcc20ac3051fd69d7377306e49f5b6dd (diff) |
Adds a GuaranteeConstOp.
- Acts as indicator for the TF runtime to make possible optimizations by treating the input tensor as a constant.
PiperOrigin-RevId: 177656212
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt | 12 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/core/kernels/guarantee_const_op.cc | 47 | ||||
-rw-r--r-- | tensorflow/core/kernels/guarantee_const_op_test.cc | 75 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops_test.cc | 7 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/array_ops_test.py | 34 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 1 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.pbtxt | 4 |
9 files changed, 226 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt b/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt new file mode 100644 index 0000000000..b2a2e1aaef --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt @@ -0,0 +1,12 @@ +op { + graph_op_name: "GuaranteeConst" + summary: "Gives a guarantee to the TF runtime that the input tensor is a constant." + description: <<END +The runtime is then free to make optimizations based on this. + +Only accepts value typed tensors as inputs and rejects resource variable handles +as input. + +Returns the input tensor without modification. +END +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8d87915658..a46fbbfc8e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -589,6 +589,7 @@ cc_library( ":extract_image_patches_op", ":gather_nd_op", ":gather_op", + ":guarantee_const_op", ":identity_n_op", ":identity_op", ":inplace_ops", @@ -636,6 +637,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "guarantee_const_op", + prefix = "guarantee_const_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( name = "constant_op", prefix = "constant_op", deps = ARRAY_DEPS, @@ -1194,6 +1201,25 @@ tf_cuda_cc_test( ) tf_cc_test( + name = "guarantee_const_op_test", + size = "small", + srcs = ["guarantee_const_op_test.cc"], + deps = [ + ":guarantee_const_op", + ":ops_testutil", + ":ops_util", + ":variable_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( name = "identity_op_test", size = "small", srcs = ["identity_op_test.cc"], diff --git a/tensorflow/core/kernels/guarantee_const_op.cc b/tensorflow/core/kernels/guarantee_const_op.cc new file mode 100644 index 0000000000..de3a2a1148 --- /dev/null +++ b/tensorflow/core/kernels/guarantee_const_op.cc @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +// Refer to the Op description for detailed comments. +class GuaranteeConstOp : public OpKernel { + public: + explicit GuaranteeConstOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const DataType input_dtype = ctx->input_dtype(0); + OP_REQUIRES(ctx, input_dtype != DT_RESOURCE, + errors::InvalidArgument( + "Input tensor cannot be a resource variable handle.")); + const Tensor& input_tensor = ctx->input(0); + Tensor* output = nullptr; + if (!ctx->forward_input_to_output_with_shape(0, 0, input_tensor.shape(), + &output)) { + ctx->set_output(0, input_tensor); + } + } + + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("GuaranteeConst").Device(DEVICE_CPU), + GuaranteeConstOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/guarantee_const_op_test.cc b/tensorflow/core/kernels/guarantee_const_op_test.cc new file mode 100644 index 0000000000..01461fbb8c --- /dev/null +++ b/tensorflow/core/kernels/guarantee_const_op_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class GuaranteeConstOpTest : public OpsTestBase { + protected: + Status Init(DataType input_type) { + TF_CHECK_OK(NodeDefBuilder("op", "GuaranteeConst") + .Input(FakeInput(input_type)) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(GuaranteeConstOpTest, Int32Success_6) { + TF_ASSERT_OK(Init(DT_INT32)); + AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); +} + +TEST_F(GuaranteeConstOpTest, Int32Success_2_3) { + TF_ASSERT_OK(Init(DT_INT32)); + AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({2, 3})); + test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6}); + test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); +} + +TEST_F(GuaranteeConstOpTest, StringSuccess) { + TF_ASSERT_OK(Init(DT_STRING)); + AddInputFromArray<string>(TensorShape({6}), {"A", "b", "C", "d", "E", "f"}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({6})); + test::FillValues<string>(&expected, {"A", "b", "C", "d", "E", "f"}); + test::ExpectTensorEqual<string>(expected, *GetOutput(0)); +} + +TEST_F(GuaranteeConstOpTest, ResourceInputError) { + TF_ASSERT_OK(Init(DT_RESOURCE)); + AddResourceInput("", "resource", new Var(DT_INT32)); + const auto status = RunOpKernel(); + ASSERT_EQ(error::INVALID_ARGUMENT, status.code()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 6f4ea09206..36d27ea110 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -706,6 +706,26 @@ memory_region_name: Name of readonly memory region used by the tensor, see NewReadOnlyMemoryRegionFromFile in tensorflow::Env. )doc"); +REGISTER_OP("GuaranteeConst") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + return UnchangedShape(c); + }) + // We don't want this to be optimized away. + .SetIsStateful() + .Doc(R"( +Gives a guarantee to the TF runtime that the input tensor is a constant. + +The runtime is then free to make optimizations based on this. + +Only accepts value typed tensors as inputs and rejects resource variable handles +as input. + +Returns the input tensor without modification. +)"); + // -------------------------------------------------------------------------- REGISTER_OP("ZerosLike") .Input("x: T") diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 94eb120175..e010ecda8e 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -158,6 +158,13 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) { INFER_OK(op, "[1,2,?,4,5];?;?", "in0"); } +TEST(ArrayOpsTest, GuaranteeConst_ShapeFn) { + ShapeInferenceTestOp op("GuaranteeConst"); + INFER_OK(op, "?", "in0"); + INFER_OK(op, "[]", "in0"); + INFER_OK(op, "[1,2,?,4,5]", "in0"); +} + TEST(ArrayOpsTest, Identity_ShapeFnHandles) { const char* op_name = "Identity"; ShapeInferenceTestOp op(op_name); diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 1bf2b70c1b..6d649b1cac 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -34,9 +34,11 @@ from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test as test_lib @@ -1090,5 +1092,37 @@ class InvertPermutationTest(test_util.TensorFlowTestCase): self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1]) +class GuaranteeConstOpTest(test_util.TensorFlowTestCase): + + def testSimple(self): + with self.test_session(): + a = array_ops.constant(10) + guarantee_a = array_ops.guarantee_const(a) + self.assertEqual(10, guarantee_a.eval()) + + def testVariables(self): + with self.test_session() as sess: + for use_resource in [False, True]: + a = variable_scope.get_variable( + "var_{}".format(use_resource), [], + initializer=init_ops.constant_initializer(10.0), + use_resource=use_resource) + guarantee_a = array_ops.guarantee_const(a) + sess.run(variables.global_variables_initializer()) + self.assertEqual(10.0, guarantee_a.eval()) + + def testResourceRejection(self): + with self.test_session() as sess: + a = variable_scope.get_variable( + "resource_var", [], + initializer=init_ops.constant_initializer(10.0), + use_resource=True) + guarantee_a = array_ops.guarantee_const(a.handle) + sess.run(variables.global_variables_initializer()) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "cannot be a resource variable"): + guarantee_a.eval() + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 38eff54c69..23aa74c027 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -70,6 +70,7 @@ See the @{$python/array_ops} guide. @@quantize_v2 @@quantized_concat @@setdiff1d +@@guarantee_const @@fake_quant_with_min_max_args @@fake_quant_with_min_max_args_gradient @@fake_quant_with_min_max_vars diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 57573d5024..e79f2a56f5 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1141,6 +1141,10 @@ tf_module { argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None" } member_method { + name: "guarantee_const" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { name: "hessians" argspec: "args=[\'ys\', \'xs\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'hessians\', \'False\', \'False\', \'None\'], " } |