aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-09-26 13:48:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 13:51:50 -0700
commit1736e0bbbfdeeba178dff37c970b5a0180ee013f (patch)
tree390c309b5997a752644d2c50bb4ee5bf8fc1654d /tensorflow/core/ops
parent652ce1aaefdadd04a9905a0788ab26c6fff93658 (diff)
[TF] Add new internal ops _VarHandlesOp and _ReadVariablesOp.
The purpose of these ops is to fix a latency problem observed for an inference benchmark. Often a inference step starts by reading the value of many (hundreds) of weights. For a resource variable, this requires a VarHandleOp and a ReadVariableOp per variable. Running hundreds of trivial ops can add hundreds of microseconds of latency to the critical path of an inference step. The inter-op latency of the executor can be hundreds of nanoseconds, which rapidly adds up. This change introduces two fused ops _VarHandlesOp and _ReadVariablesOp that allow us to read many variables in a pair of larger ops, rather than many tiny ops. PiperOrigin-RevId: 214662338
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc72
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ReadVariablesShapeFn(InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector value_dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+ if (n != value_dtypes.size()) {
+ return errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp");
+ }
+ for (int i = 0; i < n; ++i) {
+ ShapeAndType shape_and_type;
+ auto* handle_data = c->input_handle_shapes_and_types(i);
+ if (handle_data == nullptr || handle_data->empty()) {
+ shape_and_type.shape = c->UnknownShape();
+ shape_and_type.dtype = DT_INVALID;
+ } else {
+ shape_and_type = (*handle_data)[0];
+ if (shape_and_type.dtype != value_dtypes[i]) {
+ return errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. "
+ "Expected ",
+ DataTypeString(shape_and_type.dtype), " got ",
+ DataTypeString(value_dtypes[i]));
+ }
+ }
+ c->set_output(i, shape_and_type.shape);
+ }
+ return Status::OK();
+}
+
} // namespace
REGISTER_OP("VarHandleOp")
@@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp")
return Status::OK();
});
+REGISTER_OP("_VarHandlesOp")
+ .Attr("containers: list(string)")
+ .Attr("shared_names: list(string)")
+ .Attr("N: int >= 0")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .Output("resources: N * resource")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (dtypes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+ ", num dtypes=", dtypes.size(), ")");
+ }
+ if (shapes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+ ", num shapes=", shapes.size(), ")");
+ }
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+ c->set_output_handle_shapes_and_types(
+ i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+REGISTER_OP("_ReadVariablesOp")
+ .Attr("N: int >= 0")
+ .Input("resources: N * resource")
+ .Output("values: dtypes")
+ .Attr("dtypes: list(type)")
+ .SetShapeFn(ReadVariablesShapeFn);
+
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FunctionDefHelper::Define(