aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/examples/label_image/main.cc15
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc27
-rw-r--r--tensorflow/tools/graph_transforms/README.md8
-rw-r--r--tensorflow/tools/graph_transforms/quantize_weights.cc12
-rw-r--r--tensorflow/tools/graph_transforms/quantize_weights_test.cc14
-rw-r--r--tensorflow/tools/graph_transforms/summarize_graph_main.cc12
6 files changed, 66 insertions, 22 deletions
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index a98c0817e3..63bc39de6c 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -89,7 +89,6 @@ Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
Tensor* output) {
-
tensorflow::uint64 file_size = 0;
TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
@@ -124,15 +123,15 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height,
// read file_name into a tensor named input
Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
- TF_RETURN_IF_ERROR(ReadEntireFile(tensorflow::Env::Default(), file_name,
- &input));
+ TF_RETURN_IF_ERROR(
+ ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
// use a placeholder to read input data
- auto file_reader = Placeholder(root.WithOpName("input"),
- tensorflow::DataType::DT_STRING);
+ auto file_reader =
+ Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
- {"input", input},
+ {"input", input},
};
// Now try to figure out what kind of file it is and decode it.
@@ -285,8 +284,8 @@ int main(int argc, char* argv[]) {
"tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
int32 input_width = 299;
int32 input_height = 299;
- int32 input_mean = 0;
- int32 input_std = 255;
+ float input_mean = 0;
+ float input_std = 255;
string input_layer = "input";
string output_layer = "InceptionV3/Predictions/Reshape_1";
bool self_test = false;
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 81892aef96..38c1bd3fb5 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -138,8 +138,21 @@ Status GetOutputShapes(const std::vector<InputLayerInfo>& inputs,
std::vector<std::pair<string, tensorflow::Tensor> > input_tensors;
CreateTensorsFromInputInfo(inputs, &input_tensors);
std::vector<tensorflow::Tensor> output_tensors;
- std::vector<string> output_tensor_names(wanted_shapes.begin(),
- wanted_shapes.end());
+ std::vector<string> output_tensor_names;
+ for (const string& wanted_shape : wanted_shapes) {
+ bool is_input = false;
+ for (const std::pair<string, tensorflow::Tensor>& input_tensor :
+ input_tensors) {
+ if (input_tensor.first == wanted_shape) {
+ (*node_shapes)[wanted_shape] = input_tensor.second.shape();
+ is_input = true;
+ break;
+ }
+ }
+ if (!is_input) {
+ output_tensor_names.push_back(wanted_shape);
+ }
+ }
TF_RETURN_IF_ERROR(
session->Run(input_tensors, output_tensor_names, {}, &output_tensors));
CHECK_EQ(output_tensors.size(), output_tensor_names.size());
@@ -156,7 +169,8 @@ Status CalculateFlops(const GraphDef& graph,
Session* session, int64* total_flops,
std::unordered_map<string, int64>* flops_by_op) {
std::unordered_set<string> floppable_ops = {
- "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul"};
+ "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul",
+ "DepthwiseConv2dNative"};
std::set<string> wanted_shapes;
for (const NodeDef& node : graph.node()) {
@@ -201,6 +215,13 @@ Status CalculateFlops(const GraphDef& graph,
}
int64 output_count = output_shape.num_elements();
current_flops = k * output_count * 2;
+ } else if (node.op() == "DepthwiseConv2dNative") {
+ const TensorShape& filter_shape = found_shapes[node.input(1)];
+ const TensorShape& output_shape = found_shapes[node.name()];
+ int64 filter_height = filter_shape.dim_size(0);
+ int64 filter_width = filter_shape.dim_size(1);
+ int64 output_count = output_shape.num_elements();
+ current_flops = output_count * filter_height * filter_width * 2;
}
(*flops_by_op)[node.op()] += current_flops;
*total_flops += current_flops;
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index b4274e67df..18d1523070 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -578,10 +578,14 @@ eight-bit form.
### quantize_weights
-Args: None \
+Args:
+
+* minimum_size: Tensors with fewer elements than this won't be quantized
+(defaults to 1024)
+
Prerequisites: None
-Converts any large (more than 15 element) float Const op into an eight-bit
+Converts any large (more than minimum_size) float Const op into an eight-bit
equivalent, followed by a float conversion op so that the result is usable by
subsequent nodes. This is mostly useful for [shrinking file
sizes](#shrinking-file-size), but also helps with the more advanced
diff --git a/tensorflow/tools/graph_transforms/quantize_weights.cc b/tensorflow/tools/graph_transforms/quantize_weights.cc
index 66d800f0da..cccae8a992 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights.cc
@@ -35,11 +35,15 @@ namespace graph_transforms {
Status QuantizeWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
+ int32 minimum_size;
+ TF_RETURN_IF_ERROR(
+ context.GetOneInt32Parameter("minimum_size", 1024, &minimum_size));
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"},
- [](const NodeMatch& match, const std::set<string>& input_nodes,
- const std::set<string>& output_nodes,
- std::vector<NodeDef>* new_nodes) {
+ [minimum_size](const NodeMatch& match,
+ const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
const NodeDef& old_const_node = match.node;
if (!old_const_node.attr().count("dtype")) {
return errors::InvalidArgument("No 'dtype' attribute for Const node ",
@@ -58,7 +62,7 @@ Status QuantizeWeights(const GraphDef& input_graph_def,
const size_t num_elements = old_tensor.NumElements();
// If this isn't a float constant, or it's too small, then reuse the
// same node with no changes.
- if ((old_dtype != DT_FLOAT) || (num_elements < 16)) {
+ if ((old_dtype != DT_FLOAT) || (num_elements < minimum_size)) {
new_nodes->push_back(old_const_node);
return Status::OK();
}
diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
index 63c5b5a64d..e1828831db 100644
--- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc
@@ -70,9 +70,12 @@ class QuantizeWeightsTest : public ::testing::Test {
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
&original_graph_def);
+ TransformFuncContext context;
+ context.output_names = {"output"};
+ context.params["minimum_size"] = {"16"};
GraphDef quantized_graph_def;
- TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
- &quantized_graph_def));
+ TF_ASSERT_OK(
+ QuantizeWeights(original_graph_def, context, &quantized_graph_def));
// Verify the structure of the quantized graph.
std::map<string, const NodeDef*> node_lookup;
@@ -122,9 +125,12 @@ TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) {
0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
&original_graph_def);
+ TransformFuncContext context;
+ context.output_names = {"output"};
+ context.params["minimum_size"] = {"16"};
GraphDef quantized_graph_def;
- TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
- &quantized_graph_def));
+ TF_ASSERT_OK(
+ QuantizeWeights(original_graph_def, context, &quantized_graph_def));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(quantized_graph_def, &node_lookup);
diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
index e79e7ba121..50353c3654 100644
--- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc
+++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
@@ -118,7 +118,17 @@ Status PrintStructure(const GraphDef& graph) {
TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
for (const NodeDef& node : sorted_graph.node()) {
std::cout << node.name() << " (" << node.op() << "): ["
- << str_util::Join(node.input(), ", ") << "]" << std::endl;
+ << str_util::Join(node.input(), ", ") << "]";
+ if (node.op() == "Const") {
+ Tensor tensor;
+ if (node.attr().count("value") &&
+ tensor.FromProto(node.attr().at("value").tensor())) {
+ std::cout << ", value=" << tensor.DebugString();
+ } else {
+ LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
+ }
+ }
+ std::cout << std::endl;
}
return Status::OK();
}