diff options
author | 2018-09-26 13:48:21 -0700 | |
---|---|---|
committer | 2018-09-26 13:51:50 -0700 | |
commit | 1736e0bbbfdeeba178dff37c970b5a0180ee013f (patch) | |
tree | 390c309b5997a752644d2c50bb4ee5bf8fc1654d /tensorflow/core/ops | |
parent | 652ce1aaefdadd04a9905a0788ab26c6fff93658 (diff) |
[TF] Add new internal ops _VarHandlesOp and _ReadVariablesOp.
The purpose of these ops is to fix a latency problem observed for an inference benchmark. Often an inference step starts by reading the values of many (hundreds of) weights. For a resource variable, this requires a VarHandleOp and a ReadVariableOp per variable. Running hundreds of trivial ops can add hundreds of microseconds of latency to the critical path of an inference step. The inter-op latency of the executor can be hundreds of nanoseconds, which rapidly adds up.
This change introduces two fused ops _VarHandlesOp and _ReadVariablesOp that allow us to read many variables in a pair of larger ops, rather than many tiny ops.
PiperOrigin-RevId: 214662338
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/resource_variable_ops.cc | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 using ::tensorflow::shape_inference::InferenceContext;
 using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,46 @@ Status ReadVariableShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
+// Shape function for _ReadVariablesOp, which reads N resource variables in
+// one op. Output i takes the shape stored in the first handle-data entry of
+// input i, or an unknown shape when input i carries no handle data.
+//
+// Returns InvalidArgument if the "dtypes" attr does not have N entries, or if
+// the dtype recorded in a handle differs from the corresponding "dtypes"
+// entry. Errors from reading the attrs are propagated.
+Status ReadVariablesShapeFn(InferenceContext* c) {
+  int n;
+  TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+  DataTypeVector value_dtypes;
+  TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+  // _ReadVariablesOp declares N as `int >= 0`, so this cast cannot wrap; it
+  // keeps the comparison unsigned on both sides.
+  if (static_cast<size_t>(n) != value_dtypes.size()) {
+    return errors::InvalidArgument(
+        "Mismatched number of arguments to ReadVariablesOp");
+  }
+  for (int i = 0; i < n; ++i) {
+    ShapeAndType shape_and_type;
+    auto* handle_data = c->input_handle_shapes_and_types(i);
+    if (handle_data == nullptr || handle_data->empty()) {
+      // No handle data for this input: fall back to an unknown shape.
+      shape_and_type.shape = c->UnknownShape();
+      shape_and_type.dtype = DT_INVALID;
+    } else {
+      shape_and_type = (*handle_data)[0];
+      if (shape_and_type.dtype != value_dtypes[i]) {
+        return errors::InvalidArgument(
+            "Trying to read variable with wrong dtype. "
+            "Expected ",
+            DataTypeString(shape_and_type.dtype), " got ",
+            DataTypeString(value_dtypes[i]));
+      }
+    }
+    c->set_output(i, shape_and_type.shape);
+  }
+  return Status::OK();
+}
+
 }  // namespace
 
 REGISTER_OP("VarHandleOp")
@@ -79,12 +120,62 @@ REGISTER_OP("VarHandleOp")
       return Status::OK();
     });
 
+// Fused form of VarHandleOp: produces N resource handles from a single op.
+// The shape function emits a scalar for each output and attaches the
+// matching (shapes[i], dtypes[i]) pair as that output's handle data.
+// NOTE(review): only `dtypes` and `shapes` are length-checked against N here;
+// `containers` and `shared_names` are not -- confirm the kernel validates them.
+REGISTER_OP("_VarHandlesOp")
+    .Attr("containers: list(string)")
+    .Attr("shared_names: list(string)")
+    .Attr("N: int >= 0")
+    .Attr("dtypes: list(type)")
+    .Attr("shapes: list(shape)")
+    .Output("resources: N * resource")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      int n;
+      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+      DataTypeVector dtypes;
+      TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+      std::vector<PartialTensorShape> shapes;
+      TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+      // N is declared `int >= 0` above, so the casts below cannot wrap.
+      if (dtypes.size() != static_cast<size_t>(n)) {
+        return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+                                       ", num dtypes=", dtypes.size(), ")");
+      }
+      if (shapes.size() != static_cast<size_t>(n)) {
+        return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+                                       ", num shapes=", shapes.size(), ")");
+      }
+      for (int i = 0; i < n; ++i) {
+        // Every output is a scalar; the variable's shape goes in handle data.
+        c->set_output(i, c->Scalar());
+        ShapeHandle s;
+        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+        c->set_output_handle_shapes_and_types(
+            i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+      }
+
+      return Status::OK();
+    });
+
 REGISTER_OP("ReadVariableOp")
     .Input("resource: resource")
     .Output("value: dtype")
     .Attr("dtype: type")
     .SetShapeFn(ReadVariableShapeFn);
 
+// Fused form of ReadVariableOp: takes N resource inputs in a single op and
+// declares one output per entry of `dtypes`.
+REGISTER_OP("_ReadVariablesOp")
+    .Attr("N: int >= 0")
+    .Input("resources: N * resource")
+    .Output("values: dtypes")
+    .Attr("dtypes: list(type)")
+    .SetShapeFn(ReadVariablesShapeFn);
+
 Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
   *g = FunctionDefHelper::Define(