aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/label_image
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-07-15 14:28:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-15 15:33:32 -0700
commit25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch)
tree06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/examples/label_image
parent194efde51895e0251d39c72c969dff1a50b67d35 (diff)
Improvements to the C++ graph building API.
TESTED: - passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/ Change: 127585603
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r--tensorflow/examples/label_image/main.cc55
1 files changed, 25 insertions, 30 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index 9251d06684..e4cc11dbe0 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -87,50 +87,44 @@ Status ReadTensorFromImageFile(string file_name, const int input_height,
const int input_width, const float input_mean,
const float input_std,
std::vector<Tensor>* out_tensors) {
- tensorflow::GraphDefBuilder b;
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
string input_name = "file_reader";
string output_name = "normalized";
- tensorflow::Node* file_reader =
- tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
- b.opts().WithName(input_name));
+ auto file_reader = ReadFile(root.WithOpName(input_name), file_name);
// Now try to figure out what kind of file it is and decode it.
const int wanted_channels = 3;
- tensorflow::Node* image_reader;
+ Output image_reader;
if (tensorflow::StringPiece(file_name).ends_with(".png")) {
- image_reader = tensorflow::ops::DecodePng(
- file_reader,
- b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
+ image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
+ DecodePng::Channels(wanted_channels));
} else {
// Assume if it's not a PNG then it must be a JPEG.
- image_reader = tensorflow::ops::DecodeJpeg(
- file_reader,
- b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
+ image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
+ DecodeJpeg::Channels(wanted_channels));
}
// Now cast the image data to float so we can do normal math on it.
- tensorflow::Node* float_caster = tensorflow::ops::Cast(
- image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
+ auto float_caster =
+ Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
// The convention for image ops in TensorFlow is that all images are expected
// to be in batches, so that they're four-dimensional arrays with indices of
// [batch, height, width, channel]. Because we only have a single image, we
// have to add a batch dimension of 1 to the start with ExpandDims().
- tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
- float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
+ auto dims_expander = ExpandDims(root, float_caster, 0);
// Bilinearly resize the image to fit the required dimensions.
- tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
- dims_expander, tensorflow::ops::Const({input_height, input_width},
- b.opts().WithName("size")),
- b.opts());
+ auto resized = ResizeBilinear(
+ root, dims_expander,
+ Const(root.WithOpName("size"), {input_height, input_width}));
// Subtract the mean and divide by the scale.
- tensorflow::ops::Div(
- tensorflow::ops::Sub(
- resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
- tensorflow::ops::Const({input_std}, b.opts()),
- b.opts().WithName(output_name));
+ Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
+ {input_std});
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensor.
tensorflow::GraphDef graph;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+ TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
+
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
@@ -161,15 +155,16 @@ Status LoadGraph(string graph_file_name,
// their positions in the tensor, which correspond to categories.
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
Tensor* indices, Tensor* scores) {
- tensorflow::GraphDefBuilder b;
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
string output_name = "top_k";
- tensorflow::ops::TopKV2(tensorflow::ops::Const(outputs[0], b.opts()),
- tensorflow::ops::Const(how_many_labels, b.opts()),
- b.opts().WithName(output_name));
+ TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensors.
tensorflow::GraphDef graph;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+ TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
+
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));