author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-27 18:09:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-27 18:12:04 -0700
commit    94b2d2db576a6cce878aee92d6b1f90ded4278b4 (patch)
tree      e551c5d591f1c506496f0d6cff5f0faecb6ef1cd
parent    4979e1d55783d05520fda56fd89641f817daf119 (diff)
[XLA] Remove CheckShape and CheckSameShape from ComputationBuilder; they are unused or only rarely used.
PiperOrigin-RevId: 190706088
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc     | 20
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h      |  9
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc | 26
3 files changed, 20 insertions(+), 35 deletions(-)
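
For context, the assertion that CheckShape used to provide can be written by callers themselves against ComputationBuilder's public GetShape API, which is exactly what this change does inside batch_normalization_test.cc. The following is a minimal free-standing sketch of that pattern; the helper name CheckShapeIs is hypothetical and not part of this change, and the include paths are assumed from the repository layout of the files touched by this diff.

    // A minimal sketch, assuming the 2018-era XLA client API shown in this diff.
    // CheckShapeIs is a hypothetical free function equivalent to the removed
    // ComputationBuilder::CheckShape: it asks the builder for the shape it
    // inferred for `operand` and CHECK-fails if it differs from `expected_shape`.
    #include <memory>

    #include "tensorflow/compiler/xla/client/computation_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    ComputationDataHandle CheckShapeIs(ComputationBuilder* builder,
                                       const ComputationDataHandle& operand,
                                       const Shape& expected_shape) {
      std::unique_ptr<Shape> actual_shape =
          builder->GetShape(operand).ConsumeValueOrDie();
      CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
          << "want " << ShapeUtil::HumanString(expected_shape) << " got "
          << ShapeUtil::HumanString(*actual_shape);
      return operand;  // Pass the handle through so calls can be nested inline.
    }

    }  // namespace xla

This is the same pattern the test below adopts as a member helper, so the builder itself no longer needs to carry the method.
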
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 39d02f0863..4d3b0ee0d6 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -253,26 +253,6 @@ StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
return std::move(*response.mutable_program_shape());
}
-ComputationDataHandle ComputationBuilder::CheckShape(
- const ComputationDataHandle& operand, const Shape& expected_shape) {
- std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
- CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
- << "want " << ShapeUtil::HumanString(expected_shape) << " got "
- << ShapeUtil::HumanString(*actual_shape);
- return operand;
-}
-
-void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) {
- std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
- VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
- << ShapeUtil::HumanString(*rhs_shape);
- CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
- << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
- << ShapeUtil::HumanString(*rhs_shape);
-}
-
ComputationDataHandle ComputationBuilder::Slice(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 2141ebc206..019c6f3afb 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -104,15 +104,6 @@ class ComputationBuilder {
// Retrieves the (inferred) result for the current computation's shape.
StatusOr<ProgramShape> GetProgramShape();
- // Checks that the operand has the given expected shape. Returns the operand
- // if yes, fails with a CHECK error if no.
- ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
- const Shape& expected_shape);
-
- // Checks that the lhs and rhs results have the same shape.
- void CheckSameShape(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs);
-
// Enqueues a constant with the value of the given literal onto the
// computation.
ComputationDataHandle ConstantLiteral(const Literal& literal);
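
CheckSameShape is removed without an in-tree replacement; a caller that still wants that assertion can express it with two GetShape queries, as in the sketch below. The helper name CheckSameShapes is hypothetical, and the same headers as in the previous sketch are assumed; the body simply mirrors the removed method, rewritten over the public API.

    // A minimal sketch, mirroring the body of the removed
    // ComputationBuilder::CheckSameShape but written as a free function.
    // CheckSameShapes is an illustrative name, not part of this change.
    void CheckSameShapes(ComputationBuilder* builder,
                         const ComputationDataHandle& lhs,
                         const ComputationDataHandle& rhs) {
      std::unique_ptr<Shape> lhs_shape =
          builder->GetShape(lhs).ConsumeValueOrDie();
      std::unique_ptr<Shape> rhs_shape =
          builder->GetShape(rhs).ConsumeValueOrDie();
      CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
          << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
          << ShapeUtil::HumanString(*rhs_shape);
    }
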
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 28ab965499..af8af99c79 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -69,6 +69,17 @@ class BatchNormalizationTest
CHECK_EQ(kY, input_array_.width());
}
+ ComputationDataHandle CheckShape(ComputationBuilder* b,
+ const ComputationDataHandle& operand,
+ const Shape& expected_shape) const {
+ std::unique_ptr<Shape> actual_shape =
+ b->GetShape(operand).ConsumeValueOrDie();
+ CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
+ << "want " << ShapeUtil::HumanString(expected_shape) << " got "
+ << ShapeUtil::HumanString(*actual_shape);
+ return operand;
+ }
+
static constexpr int64 kSamples = 3;
static constexpr int64 kX = 1;
static constexpr int64 kY = 1;
@@ -164,14 +175,15 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
ComputationBuilder builder(client_, "batch_normalize_per_spec");
auto input_activations =
- builder.CheckShape(builder.ConstantLiteral(input_literal_),
- ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
+ CheckShape(&builder, builder.ConstantLiteral(input_literal_),
+ ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
auto gamma = builder.ConstantR1<float>({1.0, 1.0});
auto beta = builder.ConstantR1<float>({0.0, 0.0});
Computation add = CreateScalarAddComputation(F32, &builder);
// Reduce all dimensions except dimension 1.
Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
- auto sum = builder.CheckShape(
+ auto sum = CheckShape(
+ &builder,
builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
@@ -187,14 +199,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
auto activation_deviations = builder.Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
auto dev_squares = builder.SquareF32(activation_deviations);
- auto sum_of_squares = builder.CheckShape(
+ auto sum_of_squares = CheckShape(
+ &builder,
builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto variance = builder.Div(sum_of_squares, count);
auto standard_deviation = builder.SqrtF32(variance);
- auto standard_deviation_above_epsilon = builder.CheckShape(
- builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
+ auto standard_deviation_above_epsilon =
+ CheckShape(&builder, builder.Gt(standard_deviation, epsilon),
+ ShapeUtil::MakeShape(PRED, {2}));
auto gt_eps = builder.Select(standard_deviation_above_epsilon,
standard_deviation, epsilon2);
auto normalization_factors = builder.ReciprocalF32(gt_eps);