author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-09-07 06:35:58 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-09-07 07:46:12 -0700
commit     2db9a1edc1695877b3f90181e74cfbd8b1a8cdc7
tree       a55f4be70b13dc700bb6aeb2ebe9ef2f1e18e2bc /tensorflow/cc/gradients/math_grad_test.cc
parent     d1e2c2d2868acadee0abd83837458a3e49b067ef
C++ Gradients: Adds gradient functions and tests for Pack/Unpack, and takes care of multiple TODOs:
*) Adds support and unit test for returning dependent gradient outputs.
*) Adds support and unit test for stopping backprop at the frontier of requested inputs.
*) Adds support and unit test for returning gradients for nodes with multiple outputs.
*) Moves common unit test code out into a testlib.
*) Moves common gradient-specific unit test code out into a separate testlib.
Change: 132434513
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc | 55
1 file changed, 9 insertions, 46 deletions
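
The refactor replaces this test's hand-rolled gradient-lookup and session plumbing with shared helpers from the new test libraries. Below is a minimal sketch of what those helpers are assumed to declare, inferred only from the call sites in the diff that follows; the actual testutil.h / grad_testutil.h interfaces may differ.

// Sketch of the shared test helpers as used by this patch; signatures are
// inferred from the call sites below and are not authoritative.
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace test {

// Assumed to live in tensorflow/cc/gradients/grad_testutil.h: looks up the
// registered gradient function for 'op' and invokes it on 'grad_inputs',
// filling 'grad_outputs'. Replaces the removed CallGradFunction member,
// taking an explicit Scope and reporting failure via Status.
Status CallGradFunction(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs);

// Assumed to live in tensorflow/cc/framework/testutil.h: runs the graph held
// by 'scope' and returns the evaluated 'tensors' in 'out'.
void GetTensors(const Scope& scope, OutputList tensors,
                std::vector<Tensor>* out);

// Convenience wrapper for evaluating a single output.
void GetTensor(const Scope& scope, Output tensor, Tensor* out);

}  // namespace test
}  // namespace tensorflow
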
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 6961c584a5..d10a96a4ab 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -14,14 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
@@ -33,31 +31,22 @@ namespace {
// to a testutil library.
class MathGradTest : public ::testing::Test {
protected:
- MathGradTest() : root_(Scope::NewRootScope()) {}
+ MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
const bool t_y, const Output& dz,
std::vector<Tensor>* out) {
// Compute forward MatMul: z = MatMul(x, y).
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
- TF_EXPECT_OK(root_.status());
+ TF_ASSERT_OK(root_.status());
CHECK_NOTNULL(z.node());
std::vector<Output> grad_outputs;
// Call MatMulGrad which populates 'grad_outputs'.
- CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
- EXPECT_EQ(2, grad_outputs.size());
+ TF_ASSERT_OK(test::CallGradFunction(root_, Operation(z.node()), {dz},
+ &grad_outputs));
+ ASSERT_EQ(2, grad_outputs.size());
// Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
- GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
- }
-
- void CallGradFunction(const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
- GradFunc grad_fn;
- TF_EXPECT_OK(
- GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn));
- TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
- TF_EXPECT_OK(root_.status());
+ test::GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
}
Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
@@ -65,7 +54,7 @@ class MathGradTest : public ::testing::Test {
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
TF_EXPECT_OK(root_.status());
Tensor out;
- GetTensor(root_, z, &out);
+ test::GetTensor(root_, z, &out);
return out;
}
@@ -95,32 +84,6 @@ class MathGradTest : public ::testing::Test {
int Rand() { return 1 + (random::New64() % 10); }
- // TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
- // Note: they should be moved to a general/non-grad specific testutil class.
- void GetTensors(const Scope& scope, OutputList tensors,
- std::vector<Tensor>* out) {
- SessionOptions options;
- std::unique_ptr<Session> session(NewSession(options));
- GraphDef def;
- scope.graph()->ToGraphDef(&def);
-
- graph::SetDefaultDevice("/cpu:0", &def);
-
- TF_CHECK_OK(session->Create(def));
- std::vector<string> names;
- for (const auto& t : tensors) {
- names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
- }
- TF_CHECK_OK(session->Run({}, names, {}, out));
- TF_CHECK_OK(session->Close());
- }
-
- void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
- std::vector<Tensor> outputs;
- GetTensors(scope, {tensor}, &outputs);
- *out = outputs[0];
- }
-
Scope root_;
};
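
For reference, a plausible implementation of the assumed test::GetTensors / test::GetTensor helpers is essentially the session code deleted above, relocated behind the shared testlib interface. This is a sketch under that assumption, not the actual tensorflow/cc/framework/testutil.cc. Note that the /cpu:0 pinning now comes from the caller's Scope (WithDevice("/cpu:0")), so the SetDefaultDevice call from the removed code is no longer needed.

// Sketch only: mirrors the session plumbing deleted from this test file.
// The real testutil implementation may differ (e.g. it could use
// ClientSession rather than a raw Session).
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace test {

void GetTensors(const Scope& scope, OutputList tensors,
                std::vector<Tensor>* out) {
  // Serialize the graph built under 'scope' and run it in a fresh session.
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  GraphDef def;
  scope.graph()->ToGraphDef(&def);
  TF_CHECK_OK(session->Create(def));
  // Fetch each output by "<node name>:<output index>".
  std::vector<string> names;
  for (const auto& t : tensors) {
    names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
  }
  TF_CHECK_OK(session->Run({}, names, {}, out));
  TF_CHECK_OK(session->Close());
}

void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
  std::vector<Tensor> outputs;
  GetTensors(scope, {tensor}, &outputs);
  *out = outputs[0];
}

}  // namespace test
}  // namespace tensorflow
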