diff options
author | 2016-09-07 06:35:58 -0800 | |
---|---|---|
committer | 2016-09-07 07:46:12 -0700 | |
commit | 2db9a1edc1695877b3f90181e74cfbd8b1a8cdc7 (patch) | |
tree | a55f4be70b13dc700bb6aeb2ebe9ef2f1e18e2bc /tensorflow/cc/gradients/math_grad_test.cc | |
parent | d1e2c2d2868acadee0abd83837458a3e49b067ef (diff) |
C++ Gradients: Adds gradient functions and tests for Pack/Unpack, and takes care of multiple TODOS:
*) Adds support and unit test for returning dependent gradient outputs.
*) Adds support and unit test for stopping backprop at frontier of requested inputs.
*) Adds support and unit test for returning gradients for nodes with multiple outputs.
*) Moves common unit test code out into a testlib.
*) Moves common gradient-specific unit test code out into a separate testlib.
Change: 132434513
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 55 |
1 files changed, 9 insertions, 46 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 6961c584a5..d10a96a4ab 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -14,14 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) @@ -33,31 +31,22 @@ namespace { // to a testutil library. class MathGradTest : public ::testing::Test { protected: - MathGradTest() : root_(Scope::NewRootScope()) {} + MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y, const bool t_y, const Output& dz, std::vector<Tensor>* out) { // Compute forward MatMul: z = MatMul(x, y). auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); - TF_EXPECT_OK(root_.status()); + TF_ASSERT_OK(root_.status()); CHECK_NOTNULL(z.node()); std::vector<Output> grad_outputs; // Call MatMulGrad which populates 'grad_outputs'. - CallGradFunction(Operation(z.node()), {dz}, &grad_outputs); - EXPECT_EQ(2, grad_outputs.size()); + TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz}, + &grad_outputs)); + ASSERT_EQ(2, grad_outputs.size()); // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'. - GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); - } - - void CallGradFunction(const Operation& op, - const std::vector<Output>& grad_inputs, - std::vector<Output>* grad_outputs) { - GradFunc grad_fn; - TF_EXPECT_OK( - GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn)); - TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs)); - TF_EXPECT_OK(root_.status()); + test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out); } Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y, @@ -65,7 +54,7 @@ class MathGradTest : public ::testing::Test { auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); TF_EXPECT_OK(root_.status()); Tensor out; - GetTensor(root_, z, &out); + test::GetTensor(root_, z, &out); return out; } @@ -95,32 +84,6 @@ class MathGradTest : public ::testing::Test { int Rand() { return 1 + (random::New64() % 10); } - // TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class. - // Note: they should be moved to a general/non-grad specific testutil class. - void GetTensors(const Scope& scope, OutputList tensors, - std::vector<Tensor>* out) { - SessionOptions options; - std::unique_ptr<Session> session(NewSession(options)); - GraphDef def; - scope.graph()->ToGraphDef(&def); - - graph::SetDefaultDevice("/cpu:0", &def); - - TF_CHECK_OK(session->Create(def)); - std::vector<string> names; - for (const auto& t : tensors) { - names.push_back(strings::StrCat(t.node()->name(), ":", t.index())); - } - TF_CHECK_OK(session->Run({}, names, {}, out)); - TF_CHECK_OK(session->Close()); - } - - void GetTensor(const Scope& scope, Output tensor, Tensor* out) { - std::vector<Tensor> outputs; - GetTensors(scope, {tensor}, &outputs); - *out = outputs[0]; - } - Scope root_; }; |