diff options
author | 2016-07-15 14:28:59 -0800 | |
---|---|---|
committer | 2016-07-15 15:33:32 -0700 | |
commit | 25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch) | |
tree | 06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/examples/label_image | |
parent | 194efde51895e0251d39c72c969dff1a50b67d35 (diff) |
Improvements to the C++ graph building API.
TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r-- | tensorflow/examples/label_image/main.cc | 55 |
1 files changed, 25 insertions, 30 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 9251d06684..e4cc11dbe0 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -87,50 +87,44 @@ Status ReadTensorFromImageFile(string file_name, const int input_height, const int input_width, const float input_mean, const float input_std, std::vector<Tensor>* out_tensors) { - tensorflow::GraphDefBuilder b; + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + string input_name = "file_reader"; string output_name = "normalized"; - tensorflow::Node* file_reader = - tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()), - b.opts().WithName(input_name)); + auto file_reader = ReadFile(root.WithOpName(input_name), file_name); // Now try to figure out what kind of file it is and decode it. const int wanted_channels = 3; - tensorflow::Node* image_reader; + Output image_reader; if (tensorflow::StringPiece(file_name).ends_with(".png")) { - image_reader = tensorflow::ops::DecodePng( - file_reader, - b.opts().WithAttr("channels", wanted_channels).WithName("png_reader")); + image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, + DecodePng::Channels(wanted_channels)); } else { // Assume if it's not a PNG then it must be a JPEG. - image_reader = tensorflow::ops::DecodeJpeg( - file_reader, - b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader")); + image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader, + DecodeJpeg::Channels(wanted_channels)); } // Now cast the image data to float so we can do normal math on it. - tensorflow::Node* float_caster = tensorflow::ops::Cast( - image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster")); + auto float_caster = + Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT); // The convention for image ops in TensorFlow is that all images are expected // to be in batches, so that they're four-dimensional arrays with indices of // [batch, height, width, channel]. Because we only have a single image, we // have to add a batch dimension of 1 to the start with ExpandDims(). - tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims( - float_caster, tensorflow::ops::Const(0, b.opts()), b.opts()); + auto dims_expander = ExpandDims(root, float_caster, 0); // Bilinearly resize the image to fit the required dimensions. - tensorflow::Node* resized = tensorflow::ops::ResizeBilinear( - dims_expander, tensorflow::ops::Const({input_height, input_width}, - b.opts().WithName("size")), - b.opts()); + auto resized = ResizeBilinear( + root, dims_expander, + Const(root.WithOpName("size"), {input_height, input_width})); // Subtract the mean and divide by the scale. - tensorflow::ops::Div( - tensorflow::ops::Sub( - resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()), - tensorflow::ops::Const({input_std}, b.opts()), - b.opts().WithName(output_name)); + Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}), + {input_std}); // This runs the GraphDef network definition that we've just constructed, and // returns the results in the output tensor. tensorflow::GraphDef graph; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph)); + TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); + std::unique_ptr<tensorflow::Session> session( tensorflow::NewSession(tensorflow::SessionOptions())); TF_RETURN_IF_ERROR(session->Create(graph)); @@ -161,15 +155,16 @@ Status LoadGraph(string graph_file_name, // their positions in the tensor, which correspond to categories. Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels, Tensor* indices, Tensor* scores) { - tensorflow::GraphDefBuilder b; + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + string output_name = "top_k"; - tensorflow::ops::TopKV2(tensorflow::ops::Const(outputs[0], b.opts()), - tensorflow::ops::Const(how_many_labels, b.opts()), - b.opts().WithName(output_name)); + TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels); // This runs the GraphDef network definition that we've just constructed, and // returns the results in the output tensors. tensorflow::GraphDef graph; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph)); + TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); + std::unique_ptr<tensorflow::Session> session( tensorflow::NewSession(tensorflow::SessionOptions())); TF_RETURN_IF_ERROR(session->Create(graph)); |