diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-27 18:09:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-27 18:12:04 -0700 |
commit | 94b2d2db576a6cce878aee92d6b1f90ded4278b4 (patch) | |
tree | e551c5d591f1c506496f0d6cff5f0faecb6ef1cd | |
parent | 4979e1d55783d05520fda56fd89641f817daf119 (diff) |
[XLA] Remove CheckShape and CheckSameShape in ComputationBuilder, they are not/rarely used.
PiperOrigin-RevId: 190706088
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/batch_normalization_test.cc | 26 |
3 files changed, 20 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 39d02f0863..4d3b0ee0d6 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -253,26 +253,6 @@ StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() { return std::move(*response.mutable_program_shape()); } -ComputationDataHandle ComputationBuilder::CheckShape( - const ComputationDataHandle& operand, const Shape& expected_shape) { - std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie(); - CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) - << "want " << ShapeUtil::HumanString(expected_shape) << " got " - << ShapeUtil::HumanString(*actual_shape); - return operand; -} - -void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { - std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie(); - std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie(); - VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " - << ShapeUtil::HumanString(*rhs_shape); - CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) - << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " - << ShapeUtil::HumanString(*rhs_shape); -} - ComputationDataHandle ComputationBuilder::Slice( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> start_indices, diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 2141ebc206..019c6f3afb 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -104,15 +104,6 @@ class ComputationBuilder { // Retrieves the (inferred) result for the current computation's shape. StatusOr<ProgramShape> GetProgramShape(); - // Checks that the operand has the given expected shape. Returns the operand - // if yes, fails with a CHECK error if no. - ComputationDataHandle CheckShape(const ComputationDataHandle& operand, - const Shape& expected_shape); - - // Checks that the lhs and rhs results have the same shape. - void CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - // Enqueues a constant with the value of the given literal onto the // computation. ComputationDataHandle ConstantLiteral(const Literal& literal); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 28ab965499..af8af99c79 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -69,6 +69,17 @@ class BatchNormalizationTest CHECK_EQ(kY, input_array_.width()); } + ComputationDataHandle CheckShape(ComputationBuilder* b, + const ComputationDataHandle& operand, + const Shape& expected_shape) const { + std::unique_ptr<Shape> actual_shape = + b->GetShape(operand).ConsumeValueOrDie(); + CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) + << "want " << ShapeUtil::HumanString(expected_shape) << " got " + << ShapeUtil::HumanString(*actual_shape); + return operand; + } + static constexpr int64 kSamples = 3; static constexpr int64 kX = 1; static constexpr int64 kY = 1; @@ -164,14 +175,15 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { ComputationBuilder builder(client_, "batch_normalize_per_spec"); auto input_activations = - builder.CheckShape(builder.ConstantLiteral(input_literal_), - ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + CheckShape(&builder, builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); auto gamma = builder.ConstantR1<float>({1.0, 1.0}); auto beta = builder.ConstantR1<float>({0.0, 0.0}); Computation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); - auto sum = builder.CheckShape( + auto sum = CheckShape( + &builder, builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); @@ -187,14 +199,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.CheckShape( + auto sum_of_squares = CheckShape( + &builder, builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = builder.Div(sum_of_squares, count); auto standard_deviation = builder.SqrtF32(variance); - auto standard_deviation_above_epsilon = builder.CheckShape( - builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto standard_deviation_above_epsilon = + CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = builder.Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); auto normalization_factors = builder.ReciprocalF32(gt_eps); |