aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/tutorials
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-07-15 14:28:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-15 15:33:32 -0700
commit25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch)
tree06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/cc/tutorials
parent194efde51895e0251d39c72c969dff1a50b67d35 (diff)
Improvements to the C++ graph building API.
TESTED: - passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/ Change: 127585603
Diffstat (limited to 'tensorflow/cc/tutorials')
-rw-r--r--tensorflow/cc/tutorials/example_trainer.cc26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc
index a465d98f88..c6351afe30 100644
--- a/tensorflow/cc/tutorials/example_trainer.cc
+++ b/tensorflow/cc/tutorials/example_trainer.cc
@@ -49,31 +49,33 @@ struct Options {
GraphDef CreateGraphDef() {
// TODO(jeff,opensource): This should really be a more interesting
// computation. Maybe turn this into an mnist model instead?
- GraphDefBuilder b;
+ Scope root = Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- // Store rows [3, 2] and [-1, 0] in row major format.
- Node* a = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts());
- // x is from the feed.
- Node* x = Const({0.f}, {2, 1}, b.opts().WithName("x"));
+ // a = [3 2; -1 0]
+ auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
- // y = A * x
- Node* y = MatMul(a, x, b.opts().WithName("y"));
+ // x = [1.0; 1.0]
+ auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
+
+ // y = a * x
+ auto y = MatMul(root.WithOpName("y"), a, x);
// y2 = y.^2
- Node* y2 = Square(y, b.opts());
+ auto y2 = Square(root, y);
// y2_sum = sum(y2)
- Node* y2_sum = Sum(y2, Const(0, b.opts()), b.opts());
+ auto y2_sum = Sum(root, y2, 0);
// y_norm = sqrt(y2_sum)
- Node* y_norm = Sqrt(y2_sum, b.opts());
+ auto y_norm = Sqrt(root, y2_sum);
// y_normalized = y ./ y_norm
- Div(y, y_norm, b.opts().WithName("y_normalized"));
+ Div(root.WithOpName("y_normalized"), y, y_norm);
GraphDef def;
- TF_CHECK_OK(b.ToGraphDef(&def));
+ TF_CHECK_OK(root.ToGraphDef(&def));
+
return def;
}