diff options
author | Manjunath Kudlur <keveman@google.com> | 2016-07-15 14:28:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-15 15:33:32 -0700 |
commit | 25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch) | |
tree | 06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/cc/tutorials | |
parent | 194efde51895e0251d39c72c969dff1a50b67d35 (diff) |
Improvements to the C++ graph building API.
TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Diffstat (limited to 'tensorflow/cc/tutorials')
-rw-r--r-- | tensorflow/cc/tutorials/example_trainer.cc | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index a465d98f88..c6351afe30 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -49,31 +49,33 @@ struct Options { GraphDef CreateGraphDef() { // TODO(jeff,opensource): This should really be a more interesting // computation. Maybe turn this into an mnist model instead? - GraphDefBuilder b; + Scope root = Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - // Store rows [3, 2] and [-1, 0] in row major format. - Node* a = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts()); - // x is from the feed. - Node* x = Const({0.f}, {2, 1}, b.opts().WithName("x")); + // a = [3 2; -1 0] + auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}}); - // y = A * x - Node* y = MatMul(a, x, b.opts().WithName("y")); + // x = [1.0; 1.0] + auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}}); + + // y = a * x + auto y = MatMul(root.WithOpName("y"), a, x); // y2 = y.^2 - Node* y2 = Square(y, b.opts()); + auto y2 = Square(root, y); // y2_sum = sum(y2) - Node* y2_sum = Sum(y2, Const(0, b.opts()), b.opts()); + auto y2_sum = Sum(root, y2, 0); // y_norm = sqrt(y2_sum) - Node* y_norm = Sqrt(y2_sum, b.opts()); + auto y_norm = Sqrt(root, y2_sum); // y_normalized = y ./ y_norm - Div(y, y_norm, b.opts().WithName("y_normalized")); + Div(root.WithOpName("y_normalized"), y, y_norm); GraphDef def; - TF_CHECK_OK(b.ToGraphDef(&def)); + TF_CHECK_OK(root.ToGraphDef(&def)); + return def; } |