author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-11 10:36:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-11 10:40:27 -0700
commit a887cbfaf9027d89fa8784a1de127d1770382bdb (patch)
tree   d9a2466733e9818aad594e395cb6b6efc9d997e9
parent eb3cbce2819c619a900287fc2cae47d9b0b99bd2 (diff)
[TF:XLA] Add implementation of ResourceApplyCenteredRMSProp.
Add support to XlaOpKernelContext for accessing inputs by name.

PiperOrigin-RevId: 204148428
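For reference, the update rule the new kernel implements (mirrored by the _rmsprop_update_numpy helper in the test), written in the same notation as the comments in training_ops.cc:

    ms  <- rho * ms_{t-1} + (1 - rho) * grad * grad
    mg  <- rho * mg_{t-1} + (1 - rho) * grad            (centered only)
    mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
    var <- var - mom

In the non-centered case the mg * mg term is simply dropped from the denominator.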
-rw-r--r--  tensorflow/compiler/tests/rmsprop_test.py           | 117
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/training_ops.cc  |  81
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc         |  71
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h          |  42
4 files changed, 237 insertions(+), 74 deletions(-)
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index 9489fded32..ff8bbac911 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -30,31 +30,102 @@ from tensorflow.python.training import rmsprop
class RmspropTest(xla_test.XLATestCase):
+ def _rmsprop_update_numpy(self,
+ var,
+ g,
+ mg,
+ rms,
+ mom,
+ lr,
+ decay=0.9,
+ momentum=0.0,
+ epsilon=1e-10,
+ centered=False):
+ rms_t = rms * decay + (1 - decay) * g * g
+ denom_t = rms_t + epsilon
+ if centered:
+ mg_t = mg * decay + (1 - decay) * g
+ denom_t -= mg_t * mg_t
+ else:
+ mg_t = mg
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- rms_opt = rmsprop.RMSPropOptimizer(3.0)
- rms_update = rms_opt.apply_gradients(
- zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
-
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
-
- # Run 3 steps of RMSProp
- for _ in range(3):
- rms_update.run()
-
- # Validate updated params
- self.assertAllCloseAccordingToType(
- np.array([2.91705132e-04, 1.00029182e+00]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([2.89990854, 3.89990854]), var1.eval())
+ for centered in [False, True]:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+ mg0_np = np.array([0.0, 0.0], dtype=dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype)
+ rms0_np = np.array([1.0, 1.0], dtype=dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ learning_rate = 3.0
+ rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
+ rms_update = rms_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = rms_opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = rms_opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = rms_opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = rms_opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = rms_opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = rms_opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of RMSProp
+ for _ in range(3):
+ rms_update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np,
+ grads0_np,
+ mg0_np,
+ rms0_np,
+ mom0_np,
+ learning_rate,
+ centered=centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np,
+ grads1_np,
+ mg1_np,
+ rms1_np,
+ mom1_np,
+ learning_rate,
+ centered=centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
if __name__ == "__main__":
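As a concrete check of the numpy model above: one non-centered step with the test's initial values (var0 = [1.0, 2.0], grads0 = [0.1, 0.1], rms = [1.0, 1.0], lr = 3.0, decay = 0.9, momentum = 0.0, epsilon = 1e-10) gives rms_t = 0.9 * 1.0 + 0.1 * 0.01 = 0.901 and mom_t = 3.0 * 0.1 / sqrt(0.901) ≈ 0.3161, so var0 steps to roughly [0.6839, 1.6839].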
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 03902f012c..98df730249 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -503,34 +503,39 @@ REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
class ResourceApplyRMSProp : public XlaOpKernel {
public:
- explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
- DataType type = ctx->input_type(3);
-
- TensorShape var_shape, ms_shape, mom_shape;
- xla::XlaOp var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
+ TensorShape var_shape, ms_shape, mom_shape, mg_shape;
+ xla::XlaOp var, ms, mom, mg;
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));
- TensorShape lr_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape("lr");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
- TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape rho_shape = ctx->InputShape("rho");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
errors::InvalidArgument("rho is not a scalar: ",
rho_shape.DebugString()));
- TensorShape momentum_shape = ctx->InputShape(5);
+ TensorShape momentum_shape = ctx->InputShape("momentum");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- TensorShape epsilon_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape("epsilon");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon_shape.DebugString()));
- TensorShape grad_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape("grad");
// var should be the same shape as mom and ms.
OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
@@ -546,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::XlaOp lr = ctx->Input(3);
- xla::XlaOp rho = ctx->Input(4);
- xla::XlaOp momentum = ctx->Input(5);
- xla::XlaOp epsilon = ctx->Input(6);
- xla::XlaOp grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input("lr");
+ xla::XlaOp rho = ctx->Input("rho");
+ xla::XlaOp momentum = ctx->Input("momentum");
+ xla::XlaOp epsilon = ctx->Input("epsilon");
+ xla::XlaOp grad = ctx->Input("grad");
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -569,20 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms =
- ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho);
- xla::XlaOp new_mom =
- mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon);
+ xla::XlaOp one = xla::ScalarLike(ms, 1.0);
+ xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
+ xla::XlaOp denominator;
+ if (centered_) {
+ mg = grad * (one - rho) + mg * rho;
+ denominator = new_ms - xla::Square(mg) + epsilon;
+ } else {
+ denominator = new_ms + epsilon;
+ }
+ xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
xla::XlaOp new_var = var - new_mom;
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
}
+
+ protected:
+ bool centered_ = false;
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
ResourceApplyRMSProp);
+class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
+ public:
+ explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
+ : ResourceApplyRMSProp(ctx) {
+ centered_ = true;
+ }
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
+ ResourceApplyCenteredRMSProp);
+
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::XlaBuilder* b = ctx->builder();
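Fetching inputs by name rather than position is what lets ResourceApplyCenteredRMSProp reuse the ResourceApplyRMSProp::Compile body unchanged: the centered op's signature inserts mg between var and ms, shifting every later input index by one, so the old positional lookups (Input(3) for lr through Input(7) for grad) would be off by one for the centered variant.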
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 359cb4c467..e8eafb3819 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -66,10 +66,18 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
+const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+ return GetComputationFromTensor(GetInputTensorByName(name));
+}
+
TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
+TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+ return GetInputTensorByName(name).shape();
+}
+
DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
@@ -332,10 +340,11 @@ Status XlaOpKernelContext::ConstantInputList(
return Status::OK();
}
-Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
- TensorShape* shape,
- xla::XlaOp* value) {
- const Tensor& tensor = context_->input(index);
+namespace {
+
+Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, TensorShape* shape,
+ xla::XlaOp* value) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
@@ -353,7 +362,7 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
*shape = variable->shape();
}
- XlaContext& xla_context = XlaContext::Get(context_);
+ XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(
TensorShape representation_shape,
xla_context.RepresentationShape(variable->shape(), variable->type()));
@@ -365,6 +374,22 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
return Status::OK();
}
+} // namespace
+
+Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(context_->input(index), type, context_, shape,
+ value);
+}
+
+Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
+ shape, value);
+}
+
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
@@ -455,17 +480,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
return Status::OK();
}
-Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
- xla::XlaOp handle) {
- TF_RET_CHECK(handle.valid());
+namespace {
- const XlaExpression* expression =
- CastExpressionFromTensor(context_->input(input_index));
+Status AssignVariableTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, xla::XlaOp handle,
+ xla::XlaBuilder* builder) {
+ const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- auto shape_or_status = builder()->GetShape(handle);
+ auto shape_or_status = builder->GetShape(handle);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -475,7 +500,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
- XlaContext& xla_context = XlaContext::Get(context_);
+ XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
@@ -484,6 +509,22 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
return variable->SetValue(handle);
}
+} // namespace
+
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(context_->input(input_index), type, context_,
+ handle, builder());
+}
+
+Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(GetInputTensorByName(name), type, context_,
+ handle, builder());
+}
+
XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
@@ -523,6 +564,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}
+const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+ const Tensor* tensor;
+ CHECK(context_->input(name, &tensor).ok());
+ return *tensor;
+}
+
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
void XlaOpKernel::Compute(OpKernelContext* context) {
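Note that ReadVariableInput and AssignVariable are refactored the same way: each body moves into a file-local helper that operates on the already-resolved const Tensor& (plus the OpKernelContext and, for assignment, the XlaBuilder), so the index-based and name-based public overloads become thin wrappers that differ only in how they fetch the tensor.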
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 2bde2c983d..6203cffd80 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -67,21 +67,26 @@ class XlaOpKernelContext {
// Returns the number of inputs to the operator.
int num_inputs() const { return context_->num_inputs(); }
- // Returns the type of input 'index'.
+ // Returns the type of input `index`.
DataType input_type(int index) const;
- // Returns the type of input 'index' as an xla::PrimitiveType. If the type
+ // Returns the type of input `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType input_xla_type(int index);
- // Returns the shape of input 'index'.
+ // Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns input 'index' as a XlaOp. Unlike
+ // Returns the shape of input `name`.
+ TensorShape InputShape(StringPiece name);
+
+ // Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
+ // Returns input `name` as a XlaOp.
+ const xla::XlaOp& Input(StringPiece name);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@@ -96,13 +101,13 @@ class XlaOpKernelContext {
// Helper methods for constant inputs.
- // Evaluates input 'index' and stores it in '*constant_literal'. If the
+ // Evaluates input `index` and stores it in `*constant_literal`. If the
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
- // Evaluates input 'index', reshapes it to 'new_shape' if new_shape !=
- // InputShape(index), and stores it in '*constant_literal'. If the input
+ // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
+ // InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
@@ -137,17 +142,17 @@ class XlaOpKernelContext {
return context_->expected_output_dtype(index);
}
- // Sets output 'index' to the XlaOp 'handle'.
+ // Sets output `index` to the XlaOp `handle`.
// All outputs should be set using SetOutput and SetConstantOutput, not
// via the underlying OpKernelContext.
void SetOutput(int index, const xla::XlaOp& handle);
- // Sets output 'index' to compile-time constant 'host_tensor', where
- // 'host_tensor' is a tensor in host memory. It is preferable to use
+ // Sets output `index` to compile-time constant `host_tensor`, where
+ // `host_tensor` is a tensor in host memory. It is preferable to use
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);
- // Sets output 'index' to an invalid value.
+ // Sets output `index` to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
@@ -157,10 +162,10 @@ class XlaOpKernelContext {
// Variables
- // Sets '*resource' to the resource associated with input `index`.
+ // Sets `*resource` to the resource associated with input `index`.
Status GetResourceInput(int index, XlaResource** resource);
- // Sets output 'index' to be a reference to resource 'resource'.
+ // Sets output `index` to be a reference to resource `resource`.
void SetResourceOutput(int index, XlaResource* resource);
// Sets `*type` and `*shape` to the current type and shape of a variable's
@@ -169,17 +174,23 @@ class XlaOpKernelContext {
TensorShape* shape) const;
  // Reads the current value of the resource variable referred to by input
- // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+ // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if
// its type does not match `type`.
Status ReadVariableInput(int index, DataType type, TensorShape* shape,
xla::XlaOp* value);
+  // Reads the current value of the resource variable referred to by input
+ // `name`.
+ Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
+ xla::XlaOp* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
// variable has been initialized with a different type or with a
// different shape.
Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
+ // Assigns the value `handle` to the variable referenced by input `name`.
+ Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@@ -227,6 +238,9 @@ class XlaOpKernelContext {
const xla::XlaComputation* GetOrCreateMul(const DataType type);
private:
+ // Returns the tensor of input `name`.
+ const Tensor& GetInputTensorByName(StringPiece name);
+
OpKernelContext* const context_;
};
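
As a usage sketch of the new name-based accessors (a hypothetical kernel: ApplySomethingOp, "accum", "alpha", and "delta" are made-up names, but the four accessors are exactly the ones declared above):

// Hypothetical XLA kernel fragment illustrating the name-based API.
// Input names must match those declared in the op's REGISTER_OP signature.
class ApplySomethingOp : public XlaOpKernel {
 public:
  explicit ApplySomethingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Read the resource variable named "accum", wherever it sits in the
    // op's input list.
    TensorShape accum_shape;
    xla::XlaOp accum;
    OP_REQUIRES_OK(
        ctx, ctx->ReadVariableInput("accum", dtype_, &accum_shape, &accum));

    // Shape-check a scalar input by name.
    TensorShape alpha_shape = ctx->InputShape("alpha");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    xla::XlaOp alpha = ctx->Input("alpha");
    xla::XlaOp delta = ctx->Input("delta");

    // Write the updated value back through the same name.
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("accum", dtype_,
                                            accum - alpha * delta));
  }

 private:
  DataType dtype_;
};

A body written this way keeps working even when an op variant inserts or reorders inputs, which is exactly the situation ResourceApplyCenteredRMSProp creates.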