aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vinu Rajashekhar <vinuraja@google.com>2017-12-01 16:24:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 16:29:19 -0800
commitda105bfabc311840024b40d484dd1cd234697e23 (patch)
tree8f24d631b31d3282ff7a89a8bd23ccdd3f6d7060
parent1ee6d7ccbcc20ac3051fd69d7377306e49f5b6dd (diff)
Adds a GuaranteeConstOp.
- Acts as indicator for the TF runtime to make possible optimizations by treating the input tensor as a constant. PiperOrigin-RevId: 177656212
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt12
-rw-r--r--tensorflow/core/kernels/BUILD26
-rw-r--r--tensorflow/core/kernels/guarantee_const_op.cc47
-rw-r--r--tensorflow/core/kernels/guarantee_const_op_test.cc75
-rw-r--r--tensorflow/core/ops/array_ops.cc20
-rw-r--r--tensorflow/core/ops/array_ops_test.cc7
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py34
-rw-r--r--tensorflow/python/ops/array_ops.py1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
9 files changed, 226 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt b/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt
new file mode 100644
index 0000000000..b2a2e1aaef
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GuaranteeConst.pbtxt
@@ -0,0 +1,12 @@
+op {
+ graph_op_name: "GuaranteeConst"
+ summary: "Gives a guarantee to the TF runtime that the input tensor is a constant."
+ description: <<END
+The runtime is then free to make optimizations based on this.
+
+Only accepts value typed tensors as inputs and rejects resource variable handles
+as input.
+
+Returns the input tensor without modification.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8d87915658..a46fbbfc8e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -589,6 +589,7 @@ cc_library(
":extract_image_patches_op",
":gather_nd_op",
":gather_op",
+ ":guarantee_const_op",
":identity_n_op",
":identity_op",
":inplace_ops",
@@ -636,6 +637,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "guarantee_const_op",
+ prefix = "guarantee_const_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "constant_op",
prefix = "constant_op",
deps = ARRAY_DEPS,
@@ -1194,6 +1201,25 @@ tf_cuda_cc_test(
)
tf_cc_test(
+ name = "guarantee_const_op_test",
+ size = "small",
+ srcs = ["guarantee_const_op_test.cc"],
+ deps = [
+ ":guarantee_const_op",
+ ":ops_testutil",
+ ":ops_util",
+ ":variable_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cc_test(
name = "identity_op_test",
size = "small",
srcs = ["identity_op_test.cc"],
diff --git a/tensorflow/core/kernels/guarantee_const_op.cc b/tensorflow/core/kernels/guarantee_const_op.cc
new file mode 100644
index 0000000000..de3a2a1148
--- /dev/null
+++ b/tensorflow/core/kernels/guarantee_const_op.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+// Refer to the Op description for detailed comments.
+class GuaranteeConstOp : public OpKernel {
+ public:
+ explicit GuaranteeConstOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const DataType input_dtype = ctx->input_dtype(0);
+ OP_REQUIRES(ctx, input_dtype != DT_RESOURCE,
+ errors::InvalidArgument(
+ "Input tensor cannot be a resource variable handle."));
+ const Tensor& input_tensor = ctx->input(0);
+ Tensor* output = nullptr;
+ if (!ctx->forward_input_to_output_with_shape(0, 0, input_tensor.shape(),
+ &output)) {
+ ctx->set_output(0, input_tensor);
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GuaranteeConst").Device(DEVICE_CPU),
+ GuaranteeConstOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/guarantee_const_op_test.cc b/tensorflow/core/kernels/guarantee_const_op_test.cc
new file mode 100644
index 0000000000..01461fbb8c
--- /dev/null
+++ b/tensorflow/core/kernels/guarantee_const_op_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class GuaranteeConstOpTest : public OpsTestBase {
+ protected:
+ Status Init(DataType input_type) {
+ TF_CHECK_OK(NodeDefBuilder("op", "GuaranteeConst")
+ .Input(FakeInput(input_type))
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(GuaranteeConstOpTest, Int32Success_6) {
+ TF_ASSERT_OK(Init(DT_INT32));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(GuaranteeConstOpTest, Int32Success_2_3) {
+ TF_ASSERT_OK(Init(DT_INT32));
+ AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({2, 3}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(GuaranteeConstOpTest, StringSuccess) {
+ TF_ASSERT_OK(Init(DT_STRING));
+ AddInputFromArray<string>(TensorShape({6}), {"A", "b", "C", "d", "E", "f"});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({6}));
+ test::FillValues<string>(&expected, {"A", "b", "C", "d", "E", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(GuaranteeConstOpTest, ResourceInputError) {
+ TF_ASSERT_OK(Init(DT_RESOURCE));
+ AddResourceInput("", "resource", new Var(DT_INT32));
+ const auto status = RunOpKernel();
+ ASSERT_EQ(error::INVALID_ARGUMENT, status.code());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 6f4ea09206..36d27ea110 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -706,6 +706,26 @@ memory_region_name: Name of readonly memory region used by the tensor, see
NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
)doc");
+REGISTER_OP("GuaranteeConst")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ return UnchangedShape(c);
+ })
+ // We don't want this to be optimized away.
+ .SetIsStateful()
+ .Doc(R"(
+Gives a guarantee to the TF runtime that the input tensor is a constant.
+
+The runtime is then free to make optimizations based on this.
+
+Only accepts value typed tensors as inputs and rejects resource variable handles
+as input.
+
+Returns the input tensor without modification.
+)");
+
// --------------------------------------------------------------------------
REGISTER_OP("ZerosLike")
.Input("x: T")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 94eb120175..e010ecda8e 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -158,6 +158,13 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
}
+TEST(ArrayOpsTest, GuaranteeConst_ShapeFn) {
+ ShapeInferenceTestOp op("GuaranteeConst");
+ INFER_OK(op, "?", "in0");
+ INFER_OK(op, "[]", "in0");
+ INFER_OK(op, "[1,2,?,4,5]", "in0");
+}
+
TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
const char* op_name = "Identity";
ShapeInferenceTestOp op(op_name);
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 1bf2b70c1b..6d649b1cac 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -34,9 +34,11 @@ from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
@@ -1090,5 +1092,37 @@ class InvertPermutationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
+class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
+
+ def testSimple(self):
+ with self.test_session():
+ a = array_ops.constant(10)
+ guarantee_a = array_ops.guarantee_const(a)
+ self.assertEqual(10, guarantee_a.eval())
+
+ def testVariables(self):
+ with self.test_session() as sess:
+ for use_resource in [False, True]:
+ a = variable_scope.get_variable(
+ "var_{}".format(use_resource), [],
+ initializer=init_ops.constant_initializer(10.0),
+ use_resource=use_resource)
+ guarantee_a = array_ops.guarantee_const(a)
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(10.0, guarantee_a.eval())
+
+ def testResourceRejection(self):
+ with self.test_session() as sess:
+ a = variable_scope.get_variable(
+ "resource_var", [],
+ initializer=init_ops.constant_initializer(10.0),
+ use_resource=True)
+ guarantee_a = array_ops.guarantee_const(a.handle)
+ sess.run(variables.global_variables_initializer())
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ "cannot be a resource variable"):
+ guarantee_a.eval()
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 38eff54c69..23aa74c027 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -70,6 +70,7 @@ See the @{$python/array_ops} guide.
@@quantize_v2
@@quantized_concat
@@setdiff1d
+@@guarantee_const
@@fake_quant_with_min_max_args
@@fake_quant_with_min_max_args_gradient
@@fake_quant_with_min_max_vars
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 57573d5024..e79f2a56f5 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1141,6 +1141,10 @@ tf_module {
argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
}
member_method {
+ name: "guarantee_const"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "hessians"
argspec: "args=[\'ys\', \'xs\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'hessians\', \'False\', \'False\', \'None\'], "
}