diff options
author | (David) Siu-Kei Muk <muksiukei@gmail.com> | 2018-09-14 19:58:26 +0800 |
---|---|---|
committer | (David) Siu-Kei Muk <muksiukei@gmail.com> | 2018-09-14 19:58:26 +0800 |
commit | ae7e8d01372a2df39dc5669b00735529c5cfffb9 (patch) | |
tree | 3183e7729343d1426efce92a06dc4ce886d9844b | |
parent | 51d72a7d7f74784b68916819edd04e890b36f957 (diff) | |
parent | 54cbee5d034af8693aa39cc5877c3dfcd62d3740 (diff) |
Merge branch 'master' of https://github.com/tensorflow/tensorflow into est_spec_metrics_ops_check_tensor
75 files changed, 2249 insertions, 176 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d5db713f6..f4e1bc5e83 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -363,6 +363,7 @@ cc_library( "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -371,6 +372,7 @@ cc_library( "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -397,6 +399,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -475,6 +478,7 @@ tf_cc_test( size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], @@ -490,7 +494,9 @@ tf_cc_test( "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index ae7a22f451..e0632ff7e4 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include <unordered_map> #include <vector> +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546f..90354a801a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structurally equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. 
+void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000..97ef8cd3cb --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,360 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. 
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, + std::unique_ptr<Graph>* graph_ptr, + std::vector<int>* input_permutation, + std::vector<int>* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector<Node*> args; + std::vector<Node*> retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. 
+ StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the retval indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. DeterministicSerialization is a + // best-effort deterministic serialization. 
+ string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique<Graph>((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector<Node*> launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. 
+ for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector<const Edge*> in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector<Node*> nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector<std::pair<Node*, int>> data_inputs(num_inputs); + gtl::FlatSet<Node*> control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. 
+ const int num_outputs = launch->output_types().size(); + gtl::FlatSet<Node*> control_outputs; + std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + def.set_op("XlaLaunch"); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. 
+ control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000..99e9dfd598 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + + namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. 
+ static Status Encapsulate(std::unique_ptr<Graph>* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000..f643fb0cfe --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,346 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr<Graph> MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + 
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. 
+static std::unique_ptr<Graph> MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // (cluster name) generated by TPU encapsulate pass. 
+ auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr<Graph> graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. 
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr<Graph> graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); + auto out1 = 
ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr<Graph> graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map<string, Node*> index = BuildNodeIndex(*graph); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr<Graph> body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. 
Encapsulation should be deterministic to avoid recompilation. + TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr<Graph> body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0"), std::initializer_list<Input>{}, + std::initializer_list<Input>{a, b, c, d}, + std::initializer_list<Input>{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), 
launch.results[1]); + auto consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 5dcf754969..3770eea6d0 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -23,6 +24,11 @@ namespace tensorflow { // PRE_PLACEMENT passes: +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + // from // third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc // FunctionalizeControlFlowPass: 27 @@ -32,7 +38,8 @@ namespace tensorflow { // control flow structure (XlaIf/XlaWhile). Following passes must // handle those FunctionDef correctly. 
-// POST_REWRITE_FOR_EXEC passes: +// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ff..1a29c3caab 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. 
+ .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 2176eaebe4..97ed554171 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -277,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 37e5318bb5..2d225ad226 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. 
+ def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index d549e7bb59..ba1e3b2b4f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -611,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index ab7cac7100..e9f02201cf 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -55,17 +55,17 @@ namespace tensorflow { // op registration infrastructure instead of FunctionLibraryRuntime. 
class GraphCompiler { public: - GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, - Graph* graph, FunctionLibraryRuntime* flib, + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, ScopedStepContainer* step_container) - : xla_context_(xla_context), - device_(device), + : device_(device), graph_(graph), flib_(flib), step_container_(step_container) {} - // Compiles the graph. The results are written in `xla_context` that is passed - // into the compiler. + // Compiles the graph. The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. Status Compile(); private: @@ -82,7 +82,6 @@ class GraphCompiler { // using `compiler_`. Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); - XlaContext* xla_context_; XlaCompilationDevice* device_; Graph* graph_; FunctionLibraryRuntime* flib_; diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index df17da4c1c..0d9a768a6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -66,6 +66,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + if (DataTypeIsUnsigned(dtype)) { + return xla::Div(x, y); + } auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f410605104..0ae23aa6df 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the 
License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector<xla::XlaOp> input_data; + std::vector<xla::XlaOp> partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. 
+ partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readability of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6..f31bfb45a2 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) { + std::unordered_map<string, Node*> index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed..350a868568 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. 
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 105f3b61d5..739e47778a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -325,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, - step_container.get()); + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 76e36f3c46..ef70c1f8ac 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -193,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884..b507a2ef79 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector<std::function<string(string)>>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function<string(string)>& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367..f22fc8b849 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const 
tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function<string(string)>& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index e5a6c28478..96bd2616f5 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModule> hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index c607aea1a8..f528e62b17 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -221,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( allocator = &*se_allocator; } - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. 
- ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(params.input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(params.filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(params.output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); size_t left_over_bytes = buffer.size() % 4; CHECK_EQ(0, left_over_bytes % 2); @@ -257,19 +244,32 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( DeviceMemoryBase left_over( static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - }; - initialize_f16(params.input_buf); - initialize_f16(params.filter_buf); - initialize_f16(params.output_buf); - } else { - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the - // inputs might affect performance if e.g. the inputs contain denormals, and - // this is easy enough. 
- stream.ThenMemZero(&params.input_buf, params.input_buf.size()) - .ThenMemZero(&params.filter_buf, params.filter_buf.size()) - .ThenMemZero(&params.output_buf, params.output_buf.size()); - } + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); + } + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + TF_ASSIGN_OR_RETURN(params.input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(params.filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(params.output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); + + initialize_buffer(params.input_buf); + initialize_buffer(params.filter_buf); + initialize_buffer(params.output_buf); DeviceMemoryBase* result_buf = [&] { switch (params.kind) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 064b86493d..06b6d5b559 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1339,6 +1339,12 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. 
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 16411eb078..01e88566a5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -2570,6 +2570,25 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout<float>( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 7f090a52db..8fb17a0033 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -249,12 +249,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = 
result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -265,11 +260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -2350,8 +2341,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get<ReturnT>(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); + Literal result(shape); TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 922ebdf0e3..b27a92f2a0 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -1160,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = 
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1168,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 44c5248b15..1f62fad4c8 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -271,7 +271,9 @@ class Service : public ServiceInterface { StatusOr<std::vector<se::StreamExecutor*>> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 0b79f718d4..789dab81ed 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -1,6 +1,10 @@ TensorFlow CMake build ====================== +CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all +platforms. For details, see the +[TensorFlow install guide](https://www.tensorflow.org/install/). + This directory contains CMake files for building TensorFlow on Microsoft Windows. 
[CMake](https://cmake.org) is a cross-platform tool that can generate build scripts for multiple build systems, including Microsoft diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index baec238c62..c378b1ce8d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@sloppy_interleave @@unbatch @@unique + +@@AUTOTUNE """ from __future__ import absolute_import @@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator + +# Optimization constant that can be used to enable auto-tuning. +from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE + from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device @@ -113,6 +119,3 @@ from tensorflow.python.data.ops.optional_ops import Optional from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) - -# A constant that can be used to enable auto-tuning. 
-AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 4b45cc7e36..a14781cd93 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -80,6 +80,7 @@ py_library( ":batching", ":gen_dataset_ops", ":interleave_ops", + ":optimization", ":parsing_ops", ":shuffle_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index 4114b62e29..73840452df 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 + # TODO(jsimsa): Support RE matching for both individual transformation (e.g. to # account for indexing) and transformation sequence. 
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4c466781f7..785b395707 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import optimization from tensorflow.contrib.data.python.ops import parsing_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops @@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat( return dataset -def make_tf_record_dataset( - file_pattern, - batch_size, - parser_fn=None, - num_epochs=None, - shuffle=True, - shuffle_buffer_size=None, - shuffle_seed=None, - prefetch_buffer_size=None, - num_parallel_reads=None, - num_parallel_parser_calls=None, - drop_final_batch=False): +def make_tf_record_dataset(file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=optimization.AUTOTUNE, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): """Reads and optionally parses TFRecord files into a dataset. 
Provides common functionality such as batching, optional parsing, shuffling, @@ -300,8 +300,6 @@ def make_tf_record_dataset( parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, drop_remainder=drop_final_batch)) - if prefetch_buffer_size is None: - prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE if prefetch_buffer_size == 0: return dataset else: @@ -323,7 +321,7 @@ def make_csv_dataset( shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, num_parallel_reads=1, sloppy=False, num_rows_for_inference=100, @@ -386,9 +384,10 @@ def make_csv_dataset( shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. - prefetch_buffer_size: An int specifying the number of feature batches to - prefetch for performance improvement. Recommended value is the number of - batches consumed per training step. + prefetch_buffer_size: An int specifying the number of feature + batches to prefetch for performance improvement. Recommended value is the + number of batches consumed per training step. Defaults to auto-tune. + num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. sloppy: If `True`, reading performance will be improved at @@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern, shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, reader_num_threads=1, parser_num_threads=2, sloppy_ordering=False, @@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern, shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: Number of feature batches to prefetch in order to improve performance. Recommended value is the number of batches consumed - per training step (default is 1). 
+ per training step. Defaults to auto-tune. reader_num_threads: Number of threads used to read `Example` records. If >1, the results will be interleaved. parser_num_threads: Number of threads to use for parsing `Example` tensors diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 87f76eaa94..aaecbb0eb1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -485,7 +485,6 @@ py_library( srcs = ["single_loss_example.py"], deps = [ ":step_fn", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 5aa19cf6a9..09b351ffa4 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn, def dataset_fn(): dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() - # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be + # TODO(isaprykin): batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. - return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) + return dataset.batch(1, drop_remainder=True) # An Optimizer instance is created either outside or inside model_fn. 
outer_optimizer = None diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h index fa43e6a024..be9d551ee4 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -25,6 +25,9 @@ extern "C" { // TODO(aselle): Consider using "if this then that" for testing. +// IMPORTANT: All new members of structs must be added at the end to ensure +// backwards compatibility. + // Possible padding types (for convolutions) typedef enum { kTfLitePaddingUnknown = 0, @@ -71,11 +74,15 @@ typedef struct { } TfLitePoolParams; typedef struct { + // Parameters for DepthwiseConv version 1 or above. TfLitePadding padding; int stride_width; int stride_height; int depth_multiplier; TfLiteFusedActivation activation; + // Parameters for DepthwiseConv version 2 or above. + int dilation_width_factor; + int dilation_height_factor; } TfLiteDepthwiseConvParams; typedef struct { diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index eef4b6d831..f4d2839b1b 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -216,6 +216,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->depth_multiplier = conv_params->depth_multiplier(); params->activation = parse_activation(conv_params->fused_activation_function()); + + params->dilation_width_factor = conv_params->dilation_w_factor(); + params->dilation_height_factor = conv_params->dilation_h_factor(); } *builtin_data = reinterpret_cast<void*>(params); break; diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc index 52b17faf82..555a9cc4b0 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc @@ -117,6 +117,8 @@ 
Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators( Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors( FlatBufferBuilder* fbb) { + // Initialized to -1. + // A value of -1 means this tensor will not be exported. tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); std::vector<Offset<Tensor>> tensors; @@ -135,15 +137,17 @@ Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors( int curr_output_index = 0; for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); tensor_index++) { - if (!tensor_is_temporary[tensor_index]) { + // Temporary tensors and unused tensors will not be written. + if (!tensor_is_temporary[tensor_index] && + unused_tensors_.find(tensor_index) == unused_tensors_.end()) { tensor_to_written_tensor_[tensor_index] = curr_output_index++; } } for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); ++tensor_index) { - // Skip temporaries. - if (tensor_is_temporary[tensor_index]) continue; + // Tensor not exported. + if (tensor_to_written_tensor_[tensor_index] == -1) continue; if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { // We only need to convert non temporaries @@ -215,7 +219,9 @@ std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten( std::vector<int> output; output.reserve(input.size()); for (int x : input) { - output.push_back(tensor_to_written_tensor_[x]); + if (tensor_to_written_tensor_[x] != -1) { + output.push_back(tensor_to_written_tensor_[x]); + } } return output; } diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h index a98108b496..a5f14697cf 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h @@ -62,6 +62,10 @@ class InterpreterWriter { // caller to change the custom data. 
TfLiteStatus RegisterCustomWriter(const std::string& custom_name, CustomWriter custom_writer); + // Tensors that are unused and shouldn't be written. + void SetUnusedTensors(const std::set<int>& unused_tensors) { + unused_tensors_ = unused_tensors; + } private: template <class T> @@ -111,8 +115,9 @@ class InterpreterWriter { int builtin; std::string custom; }; + std::set<int> unused_tensors_; // For every tensor index in the interpreter, the index in the written. - // This is different due to temporary tensors not being written. + // This is different due to temporary and unused tensors not being written. std::vector<int> tensor_to_written_tensor_; // List of used opcodes std::vector<OpCode> opcodes_; diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 347515f289..3e1ce60113 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto compute_out_size = [padding](int imageSize, int filterSize, - int stride) -> int { + auto compute_out_size = [padding](int image_size, int filter_size, int stride, + int dilation_rate) -> int { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - filterSize + stride) / stride + ? 
(image_size - effective_filter_size + stride) / stride : 0; }; - int out_width = compute_out_size(width, filter_width, params->stride_width); + int out_width = compute_out_size(width, filter_width, params->stride_width, + params->dilation_width_factor); int out_height = - compute_out_size(height, filter_height, params->stride_height); + compute_out_size(height, filter_height, params->stride_height, + params->dilation_height_factor); - data->padding.height = ComputePadding(params->stride_height, 1, height, - filter_height, out_height); + data->padding.height = + ComputePadding(params->stride_height, params->dilation_height_factor, + height, filter_height, out_height); data->padding.width = - ComputePadding(params->stride_width, 1, width, filter_width, out_width); + ComputePadding(params->stride_width, params->dilation_width_factor, width, + filter_width, out_width); // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. @@ -177,8 +182,19 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, void (*depthwise_conv)(const float*, const Dims<4>&, const float*, const Dims<4>&, const float*, const Dims<4>&, int, int, - int, int, int, float, float, float*, const Dims<4>&); - if (kernel_type == kReference) { + int, int, int, int, int, float, float, float*, + const Dims<4>&); + KernelType effective_kernel_type; + // TODO(suharshs): Currently only the reference implementation supports + // dilations. 
+ if ((params->dilation_width_factor != 1) || + (params->dilation_height_factor != 1)) { + effective_kernel_type = kReference; + } else { + effective_kernel_type = kernel_type; + } + + if (effective_kernel_type == kReference) { depthwise_conv = &reference_ops::DepthwiseConv; } else { depthwise_conv = &optimized_ops::DepthwiseConv; @@ -188,7 +204,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, GetTensorData<float>(input), GetTensorDims(input), GetTensorData<float>(filter), GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, + params->stride_height, params->dilation_width_factor, + params->dilation_height_factor, data->padding.width, data->padding.height, params->depth_multiplier, output_activation_min, output_activation_max, GetTensorData<float>(output), GetTensorDims(output)); } @@ -204,9 +221,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*, const Dims<4>&, int32, const int32*, const Dims<4>&, - int, int, int, int, int, int32, int32, int, int32, - int32, uint8*, const Dims<4>&); - if (kernel_type == kReference) { + int, int, int, int, int, int, int, int32, int32, int, + int32, int32, uint8*, const Dims<4>&); + + KernelType effective_kernel_type; + // TODO(suharshs): Currently only the reference implementation supports + // dilations. 
+ if ((params->dilation_width_factor != 1) || + (params->dilation_height_factor != 1)) { + effective_kernel_type = kReference; + } else { + effective_kernel_type = kernel_type; + } + + if (effective_kernel_type == kReference) { depthwise_conv = &reference_ops::DepthwiseConv; } else { depthwise_conv = &optimized_ops::DepthwiseConv; @@ -216,7 +244,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, + params->stride_height, params->dilation_width_factor, + params->dilation_height_factor, data->padding.width, data->padding.height, params->depth_multiplier, output_offset, data->output_multiplier, data->output_shift, data->output_activation_min, data->output_activation_max, GetTensorData<uint8_t>(output), diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index c00cafb9fb..2af26ab80a 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -30,7 +30,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { // stride values. 
BaseDepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, - const TensorData& output) { + const TensorData& output, + int dilation_factor = 1) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -56,7 +57,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { BuiltinOperator_DEPTHWISE_CONV_2D, BuiltinOptions_DepthwiseConv2DOptions, CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, - ActivationFunctionType_NONE) + ActivationFunctionType_NONE, + dilation_factor, dilation_factor) .Union()); BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); @@ -110,6 +112,58 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) { })); } +TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) { + const int depth = 1; + const int image_width = 9; + const int image_height = 9; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int dilation_factor = 3; + DepthwiseConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, dilation_factor); + + // The image matrix is: + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // clang-format off + m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0}); + // clang-format on + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 
6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.Invoke(); + + // Since the dilation rate is 3 this will reduce the size of the output from + // 10x10 to 3x3 of all 5s. Specifically: + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); +} + class QuantizedDepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { public: @@ -207,6 +261,64 @@ TEST(QuantizedDepthwiseConvolutionOpTest, ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); } +TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) { + const int depth = 1; + const int image_width = 9; + const int image_height = 9; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int dilation_factor = 3; + QuantizedDepthwiseConvolutionOpModel m( + {TensorType_UINT8, + {image_batch_count, image_height, image_width, depth}, + 0, + 255}, + {TensorType_UINT8, + {depth, filter_size, filter_size, filter_count}, + 0, + 255}, + {TensorType_UINT8, {}, 0, 255}, dilation_factor); + + // The image matrix is: + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + // clang-format off + m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0}); + // clang-format on + // The filter matrix is: + // | 1 | 2 | 3 | + // | 4 | 5 | 6 | + // | 7 | 8 | 9 | + m.SetFilter({1, 2, 3, 4, 5, 
6, 7, 8, 9}); + // No bias for this test. + m.SetBias({0}); + m.Invoke(); + + // Since the dilation rate is 3 this will reduce the size of the output from + // 10x10 to 3x3 of all 5s. Specifically: + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + // | 5 | 5 | 5 | + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 7f6eea2d5d..70810ca784 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -1067,6 +1067,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } } +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, + int pad_width, int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + // TODO(suharshs): Optimized implementation of dilation depthwise conv need to + // be implemented. 
+ TFLITE_DCHECK(dilation_width_factor == 1); + TFLITE_DCHECK(dilation_height_factor == 1); + + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 3fd00c8930..f707279600 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1964,6 +1964,30 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } } +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + // TODO(suharshs): Optimized implementation of dilation depthwise is not + // supported yet. 
+ TFLITE_DCHECK(dilation_width_factor == 1); + TFLITE_DCHECK(dilation_height_factor == 1); + + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + // Legacy, for compatibility with old checked-in code. template <FusedActivationFunctionType Ac> void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h index 9aabee5000..bb5d590775 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -25,8 +25,9 @@ namespace reference_ops { inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int depth_multiplier, + int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, + int pad_width, int pad_height, int depth_multiplier, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { @@ -52,8 +53,9 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, float total = 0.f; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int in_x = in_x_origin + filter_x; - const int in_y = in_y_origin + filter_y; + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; // If the location 
is outside the bounds of the input image, // use zero as a default value. if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && @@ -81,6 +83,20 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } } +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, 1, 1, pad_width, + pad_height, depth_multiplier, output_activation_min, + output_activation_max, output_data, output_dims); +} + // Legacy, for compatibility with old checked-in code. template <FusedActivationFunctionType Ac> void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index d57739279f..5e3e8997fc 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -30,8 +30,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, const int32* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int depth_multiplier, + int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, + int pad_width, int pad_height, int depth_multiplier, int32 output_offset, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 
output_activation_max, uint8* output_data, @@ -58,8 +59,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, int32 acc = 0; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int in_x = in_x_origin + filter_x; - const int in_y = in_y_origin + filter_y; + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; // If the location is outside the bounds of the input image, // use zero as a default value. if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && @@ -90,6 +92,24 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } } +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, 1, 1, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + // Legacy, for compatibility with old checked-in code. 
template <FusedActivationFunctionType Ac> void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index d5da4fcccf..f0db22d581 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -276,11 +276,15 @@ table Pool2DOptions { } table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. padding:Padding; stride_w:int; stride_h:int; depth_multiplier:int; fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; } table ConcatEmbeddingsOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 0b9c57480e..8c086a5e67 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -2339,12 +2339,16 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { int32_t stride_h; int32_t depth_multiplier; ActivationFunctionType fused_activation_function; + int32_t dilation_w_factor; + int32_t dilation_h_factor; DepthwiseConv2DOptionsT() : padding(Padding_SAME), stride_w(0), stride_h(0), depth_multiplier(0), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(ActivationFunctionType_NONE), + dilation_w_factor(1), + dilation_h_factor(1) { } }; @@ -2355,7 +2359,9 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab VT_STRIDE_W = 6, VT_STRIDE_H = 8, VT_DEPTH_MULTIPLIER = 10, - VT_FUSED_ACTIVATION_FUNCTION = 12 + VT_FUSED_ACTIVATION_FUNCTION = 12, + VT_DILATION_W_FACTOR = 14, + VT_DILATION_H_FACTOR = 16 }; Padding padding() const { return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0)); @@ -2372,6 +2378,12 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private 
flatbuffers::Tab ActivationFunctionType fused_activation_function() const { return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + int32_t dilation_w_factor() const { + return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); + } + int32_t dilation_h_factor() const { + return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING) && @@ -2379,6 +2391,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab VerifyField<int32_t>(verifier, VT_STRIDE_H) && VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) && + VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) && verifier.EndTable(); } DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2404,6 +2418,12 @@ struct DepthwiseConv2DOptionsBuilder { void add_fused_activation_function(ActivationFunctionType fused_activation_function) { fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); } + void add_dilation_w_factor(int32_t dilation_w_factor) { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) { + fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2422,8 +2442,12 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions( int32_t stride_w = 0, int32_t stride_h = 0, int32_t depth_multiplier = 0, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + 
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + int32_t dilation_w_factor = 1, + int32_t dilation_h_factor = 1) { DepthwiseConv2DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); builder_.add_depth_multiplier(depth_multiplier); builder_.add_stride_h(stride_h); builder_.add_stride_w(stride_w); @@ -7064,6 +7088,8 @@ inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const { auto _e = stride_h(); _o->stride_h = _e; }; { auto _e = depth_multiplier(); _o->depth_multiplier = _e; }; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }; + { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }; } inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -7079,13 +7105,17 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions( auto _stride_h = _o->stride_h; auto _depth_multiplier = _o->depth_multiplier; auto _fused_activation_function = _o->fused_activation_function; + auto _dilation_w_factor = _o->dilation_w_factor; + auto _dilation_h_factor = _o->dilation_h_factor; return tflite::CreateDepthwiseConv2DOptions( _fbb, _padding, _stride_w, _stride_h, _depth_multiplier, - _fused_activation_function); + _fused_activation_function, + _dilation_w_factor, + _dilation_h_factor); } inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 5d0895c72f..3754b58b23 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ 
b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1434,6 +1434,7 @@ def make_depthwiseconv_tests(zip_path): "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], "filter_size": [[1, 1], [1, 2], [3, 3]], "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]], "channel_multiplier": [1, 2], "rate": [[1, 1]], "padding": ["SAME", "VALID"], @@ -1444,6 +1445,7 @@ def make_depthwiseconv_tests(zip_path): "input_shape": [[1, 3, 4, 3]], "filter_size": [[1, 1]], "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] + "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]], "channel_multiplier": [2], "rate": [[2, 2]], # Only [1, 1] is supported "padding": ["SAME"], diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 72c71b2841..96b88b60fc 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -331,7 +331,6 @@ cc_library( "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ] + select({ # Placeholder for internal darwin rule. @@ -348,6 +347,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", ], @@ -408,8 +408,11 @@ tf_cc_binary( ":toco_port", ":toco_tooling", ":types_proto_cc", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "//tensorflow/core:lib", + # We cannot embed the core:ops dependency directly into :toco_tooling as + # it can conflict with downstream deps when toco is used as a library. 
+ "//tensorflow/core:ops", ], ) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index eb36b3411d..efc1007925 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1089,6 +1089,8 @@ tensorflow::Status ConvertUnsupportedOperator( ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); } else { LOG(INFO) << "Op node missing output type attribute: " << node.name(); + op->output_data_types.clear(); + break; } } } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 2e100e37f6..164b70f2df 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator { int stride_height = 0; int stride_width = 0; int depth_multiplier = 0; + // A dilation_rate of 0 is invalid and this field is an optional attribute. + // Thus initializing it to 1 to allow default conv behavior when the + // attribute is not present. + int dilation_width_factor = 1; + int dilation_height_factor = 1; }; // Depth-to-space transform operator. 
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 5486012176..1061e7c7c4 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -107,7 +107,8 @@ class DepthwiseConvolution ActivationFunction::Serialize(op.fused_activation_function); return ::tflite::CreateDepthwiseConv2DOptions( *builder, padding, op.stride_width, op.stride_height, - op.depth_multiplier, activation_function); + op.depth_multiplier, activation_function, op.dilation_width_factor, + op.dilation_height_factor); } void ReadOptions(const TfLiteOptions& options, @@ -118,9 +119,18 @@ class DepthwiseConvolution op->depth_multiplier = options.depth_multiplier(); op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + op->dilation_width_factor = options.dilation_w_factor(); + op->dilation_height_factor = options.dilation_h_factor(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; + } }; class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb new file mode 100644 index 0000000000..a96e2c4e1b --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb @@ -0,0 +1,702 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6Y8E0lw5eYWm" + }, + "source": [ + "# Post Training Quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "CIGrZZPTZVeO" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", 
+ " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BTC1rDAuei_1" + }, + "source": [ + "## Overview\n", + "\n", + "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n", + "converting weights to 8 bit precision as part of model conversion from\n", + "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n", + "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n", + "fly quantization and dequantization of activations to allow for:\n", + "\n", + "1. Using quantized kernels for faster implementation when available.\n", + "\n", + "2. Mixing of floating-point kernels with quantized kernels for different parts\n", + " of the graph.\n", + "\n", + "Note that the activations are always stored in floating point. For ops that\n", + "support quantized kernels, the activations are quantized to 8 bits of precision\n", + "dynamically prior to processing and are de-quantized to float precision after\n", + "processing. 
Depending on the model being converted, this can give a speedup over\n", + "pure floating point computation.\n", + "\n", + "In contrast to\n", + "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)\n", + ", the weights are quantized post training and the activations are quantized dynamically \n", + "at inference in this method.\n", + "Therefore, the model weights are not retrained to compensate for quantization\n", + "induced errors. It is important to check the accuracy of the quantized model to\n", + "ensure that the degradation is acceptable.\n", + "\n", + "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n", + "tensorflow and then convert the saved model into a Tensorflow Lite flatbuffer\n", + "with weight quantization. We finally check the\n", + "accuracy of the converted model and compare it to the original saved model. We\n", + "run the training script mnist.py from\n", + "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2XsEP17Zelz9" + }, + "source": [ + "## Building an MNIST model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "dDqqUIZjZjac" + }, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gyqAw1M9lyab" + }, + "outputs": [], + "source": [ + "! pip uninstall -y tensorflow\n", + "! 
pip install -U tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WsN6s5L1ieNl" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "00U0taBoe-w7" + }, + "outputs": [], + "source": [ + "! git clone --depth 1 https://github.com/tensorflow/models" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4XZPtSh-fUOc" + }, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "\n", + "if sys.version_info.major \u003e= 3:\n", + " import pathlib\n", + "else:\n", + " import pathlib2 as pathlib\n", + "\n", + "# Add `models` to the python path.\n", + "models_path = os.path.join(os.getcwd(), \"models\")\n", + "sys.path.append(models_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eQ6Q0qqKZogR" + }, + "source": [ + "### Train and export the model" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "eMsw_6HujaqM" + }, + "outputs": [], + "source": [ + "saved_models_root = \"/tmp/mnist_saved_model\"" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "hWSAjQWagIHl" + }, + "outputs": [], + "source": [ + "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n", + "# Note: channels_last is required here or the conversion may fail. 
\n", + "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5NMaNZQCkW9X" + }, + "source": [ + "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xl8_fzVAZwOh" + }, + "source": [ + "### Convert to a TFLite model\n", + "\n", + "The `savedmodel` directory is named with a timestamp. Select the most recent one: " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Xp5oClaZkbtn" + }, + "outputs": [], + "source": [ + "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n", + "saved_model_dir" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AT8BgkKmljOy" + }, + "source": [ + "Using the python `TocoConverter`, the saved model can be converted into a TFLite model.\n", + "\n", + "First load the model using the `TocoConverter`:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_i8B2nDZmAgQ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n", + "tflite_model = converter.convert()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "F2o2ZfF0aiCx" + }, + "source": [ + "Write it out to a tflite file:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "vptWZq2xnclo" + }, + "outputs": [], + "source": [ + "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n", + 
"tflite_models_dir.mkdir(exist_ok=True, parents=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ie9pQaQrn5ue" + }, + "outputs": [], + "source": [ + "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n", + "tflite_model_file.write_bytes(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7BONhYtYocQY" + }, + "source": [ + "To quantize the model on export, set the `post_training_quantize` flag:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "g8PUvLWDlmmz" + }, + "outputs": [], + "source": [ + "# Note: If you don't have a recent tf-nightly installed, the\n", + "# \"post_training_quantize\" line will have no effect.\n", + "tf.logging.set_verbosity(tf.logging.INFO)\n", + "converter.post_training_quantize = True\n", + "tflite_quant_model = converter.convert()\n", + "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n", + "tflite_model_quant_file.write_bytes(tflite_quant_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PhMmUTl4sbkz" + }, + "source": [ + "Note how the resulting file, with `post_training_quantize` set, is approximately `1/4` the size." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JExfcfLDscu4" + }, + "outputs": [], + "source": [ + "!ls -lh {tflite_models_dir}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L8lQHMp_asCq" + }, + "source": [ + "## Run the TFLite models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-5l6-ciItvX6" + }, + "source": [ + "We can run the TensorFlow Lite model using the python TensorFlow Lite\n", + "Interpreter. 
\n", + "\n", + "### load the test data\n", + "\n", + "First let's load the mnist test data to feed to it:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "eTIuU07NuKFL" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n", + "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n", + "\n", + "# Note: If you change the batch size, then use \n", + "# `tf.contrib.lite.Interpreter.resize_tensor_input` to also change it for\n", + "# the interpreter.\n", + "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Ap_jE7QRvhPf" + }, + "source": [ + "### Load the model into an interpreter" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Jn16Rc23zTss" + }, + "outputs": [], + "source": [ + "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n", + "interpreter.allocate_tensors()\n", + "input_index = interpreter.get_input_details()[0][\"index\"]\n", + "output_index = interpreter.get_output_details()[0][\"index\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "J8Pztk1mvNVL" + }, + "outputs": [], + "source": [ + "tf.logging.set_verbosity(tf.logging.DEBUG)\n", + "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Afl6yGvWyqAr" + }, + "outputs": [], + "source": [ + "interpreter_quant.allocate_tensors()\n", + "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n", + "output_index = interpreter_quant.get_output_details()[0][\"index\"]\n" + ] + }, + { + "cell_type": 
"markdown", + "metadata": { + "colab_type": "text", + "id": "2opUt_JTdyEu" + }, + "source": [ + "### Test the model on one image" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "AKslvo2kwWac" + }, + "outputs": [], + "source": [ + "for img, label in mnist_ds.take(1):\n", + " break\n", + "\n", + "interpreter.set_tensor(input_index, img)\n", + "interpreter.invoke()\n", + "predictions = interpreter.get_tensor(output_index)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XZClM2vo3_bm" + }, + "outputs": [], + "source": [ + "import matplotlib.pylab as plt\n", + "\n", + "plt.imshow(img[0])\n", + "template = \"True:{true}, predicted:{predict}\"\n", + "_ = plt.title(template.format(true= str(label[0].numpy()),\n", + " predict=str(predictions[0,0])))\n", + "plt.grid(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LwN7uIdCd8Gw" + }, + "source": [ + "### Evaluate the models" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "05aeAuWjvjPx" + }, + "outputs": [], + "source": [ + "def eval_model(interpreter, mnist_ds):\n", + " total_seen = 0\n", + " num_correct = 0\n", + "\n", + " for img, label in mnist_ds:\n", + " total_seen += 1\n", + " interpreter.set_tensor(input_index, img)\n", + " interpreter.invoke()\n", + " predictions = interpreter.get_tensor(output_index)\n", + " if predictions == label.numpy():\n", + " num_correct += 1\n", + "\n", + " if total_seen % 500 == 0:\n", + " print(\"Accuracy after %i images: %f\" %\n", + " (total_seen, float(num_correct) / float(total_seen)))\n", + "\n", + " return float(num_correct) / float(total_seen)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DqXBnDfJ7qxL" + }, + "outputs": [], + "source": 
[ + "print(eval_model(interpreter_quant, mnist_ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Km3cY9ry8ZlG" + }, + "source": [ + "We can repeat the evaluation on the weight quantized model to obtain:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-9cnwiPp6EGm" + }, + "outputs": [], + "source": [ + "print(eval_model(interpreter_quant, mnist_ds))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L7lfxkor8pgv" + }, + "source": [ + "\n", + "In this example, we have compressed model with no difference in the accuracy." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "M0o1FtmWeKZm" + }, + "source": [ + "\n", + "\n", + "## Optimizing an existing model\n", + "\n", + "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n", + " Pre-trained frozen graph for resnet-v2-101 is available at the\n", + " [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n", + "\n", + "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "v5p5VcNPjILQ" + }, + "outputs": [], + "source": [ + "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n", + "archive_path = pathlib.Path(archive_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-sxnXQuC4ThD" + }, + "source": [ + "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "g_Q_OMEJ4LIc" + }, + "outputs": [], + "source": [ + "! cat {archive_path}/resnet_v2_101_299_info.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ujCAFhqm-C6H" + }, + "outputs": [], + "source": [ + "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n", + "input_arrays = [\"input\"] \n", + "output_arrays = [\"output\"]\n", + "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n", + " str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n", + "converter.post_training_quantize = True\n", + "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n", + "resnet_tflite_file.write_bytes(converter.convert())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "vhOjeg1x9Knp" + }, + "outputs": [], + "source": [ + "archive_dir = str(archive_path.parent)\n", + "!ls -lh {archive_dir}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qqHLaqFMCjRZ" + }, + "source": [ + "\n", + "The model size reduces from 171 MB to 43 MB.\n", + "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n", + "\n", + "The optimized model top-1 accuracy is 76.8, the same as the floating point model." 
+ ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "post-training-quant.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8f32bc2844..1a86bff5cd 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1921,6 +1921,13 @@ tf_pyclif_proto_library( ) tf_pyclif_proto_library( + name = "protobuf/config_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/config.proto", + visibility = ["//visibility:public"], +) + +tf_pyclif_proto_library( name = "protobuf/device_properties_pyclif", proto_lib = ":protos_all_cc", proto_srcfile = "protobuf/device_properties.proto", diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 7f260b3139..4475fa979e 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -561,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph( grappler::GrapplerItem item; item.id = "tf_graph"; graph_->ToGraphDef(&item.graph); + // TODO(b/114748242): Add a unit test to test this bug fix. + if (flib_def_) { + *item.graph.mutable_library() = flib_def_->ToProto(); + } item.fetch.insert(item.fetch.end(), options.callable_options.fetch().begin(), diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 2e644fe987..f5b0105862 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index d273eddf81..56c8339d57 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) { } bool IsEnqueue(const NodeDef& n) { - return (n.op().find("Enqueue") != std::string::npos && - n.op().find("EnqueueMany") == std::string::npos); + return (n.op().find("Enqueue") != string::npos && + n.op().find("EnqueueMany") == string::npos); } bool IsDequeue(const NodeDef& n) { - return (n.op().find("Dequeue") != std::string::npos && - n.op().find("DequeueMany") == std::string::npos); + return (n.op().find("Dequeue") != string::npos && + n.op().find("DequeueMany") == string::npos); } bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index aad00ce039..83434ea40f 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -127,7 +127,7 @@ static void ExtractExtraProperties( // For filename input, the file size can also be useful. if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("filename") != std::string::npos) { + op_def->input_arg(i).name().find("filename") != string::npos) { Tensor tensor; if (!tensor.FromProto(t)) { continue; @@ -153,7 +153,7 @@ static void ExtractExtraProperties( // When the input is a handle (e.g. 
look up table handle), the information // in the op itself is not sufficient to predict the op memory. if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("handle") != std::string::npos) { + op_def->input_arg(i).name().find("handle") != string::npos) { string new_key = strings::StrCat("parent_", i, "_op"); AttrValue attr; attr.set_s(input_node->op()); @@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) { buckets_.begin(), std::plus<uint64>()); } -std::string TensorSizeHistogram::ToString() const { - std::string r; +string TensorSizeHistogram::ToString() const { + string r; char buf[200]; snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_); r.append(buf); diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index d2c7c67666..5fd6717712 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -80,7 +80,7 @@ class TensorSizeHistogram { uint64 Max() const { return max_; } uint64 NumElem() const { return num_elem_; } uint64 SumElem() const { return sum_elem_; } - std::string ToString() const; + string ToString() const; protected: const int Index(const uint64 value) const; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 02a379fca8..80889afc86 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { // Helper lambda to extract port num from _Send and _Recv op name. 
auto get_port_num = [](const string& name) -> int { - if (name.find("bn_0") != std::string::npos) { + if (name.find("bn_0") != string::npos) { return 0; - } else if (name.find("bn_1") != std::string::npos) { + } else if (name.find("bn_1") != string::npos) { return 1; - } else if (name.find("bn_2") != std::string::npos) { + } else if (name.find("bn_2") != string::npos) { return 2; - } else if (name.find("bn_minus1") != std::string::npos) { + } else if (name.find("bn_minus1") != string::npos) { return -1; } return -999; diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc index 5029dff877..def9198a69 100644 --- a/tensorflow/core/grappler/inputs/utils.cc +++ b/tensorflow/core/grappler/inputs/utils.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/inputs/utils.h" -#include "tensorflow/core/platform/env.h" #include <vector> +#include "tensorflow/core/platform/env.h" + namespace tensorflow { namespace grappler { @@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) { return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr); } -bool FileExists(const std::string& file, Status* status) { +bool FileExists(const string& file, Status* status) { *status = Env::Default()->FileExists(file); return status->ok(); } -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result) { Status status; if (FileExists(graph_def_pbtxt_path, &status)) { diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h index 627dd5359f..4b9cb0a9ad 100644 --- a/tensorflow/core/grappler/inputs/utils.h +++ b/tensorflow/core/grappler/inputs/utils.h @@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files, std::vector<Status>* status = nullptr); bool FilesExist(const 
std::set<string>& files); -bool FileExists(const std::string& file, Status* status); +bool FileExists(const string& file, Status* status); -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result); } // end namespace grappler diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e78239bd43..3521669b63 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -491,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) { } } // Queue ops modify the queue which is a side effect. - if (node.op().find("Queue") != std::string::npos) { + if (node.op().find("Queue") != string::npos) { return false; } return !ModifiesInputsInPlace(node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 39517edc06..bc838c6659 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); @@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); diff --git 
a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 5a7fe19265..d4ab444036 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -273,7 +273,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, string name = string(prefix); int id = graph->node_size(); while (ContainsGraphNodeWithName(name, *graph)) { - if (name.rfind("_generated") != std::string::npos && + if (name.rfind("_generated") != string::npos && (name.rfind("_generated") == (name.size() - strlen("_generated")))) { name.insert(name.rfind("_generated"), strings::StrCat("/_", id)); } else { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 8c99598748..7ed4a67333 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -72,6 +72,16 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer"; } +// Check if the graphdef contains nodes that indicate TPU execution. +bool IsTPUGraphDef(const GraphDef& def) { + for (auto node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} + } // namespace #define MK_OPT(NAME, VALUE) \ @@ -338,6 +348,19 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); VLOG(1) << "Optimized main graph."; + // Skip optimizing functions if this is a TPU graph. Currently, Grappler + // passes do not handle TPU functions correctly in a variety of ways (Note + // that due to the pre-placement TPU graph rewriting passes, the TPU-related + // ops are encapsulated away into functions). 
For example, TPU graphs contain + // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler + // passes could prune that away. Grappler passes could also cause issues + // around shape inference. Since the desired and existing behavior is to not + // optimize TPU functions with Grappler, this check preserves that. + if (IsTPUGraphDef(*optimized_graph)) { + VLOG(2) << "Skipping optimizing funcs for TPU graphs"; + return Status::OK(); + } + // 2. Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), optimized_graph->library()); diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc index 533d0bd5d2..da357339c9 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc @@ -26,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) } } +namespace { +// Determines what strategy to use for increasing the buffer size limit. For +// limits less than the threshold, an exponential increase is used, while for +// limits greater than or equal to the threshold, a linear increase is used. +size_t kBufferLimitThreshold = 2048; +} // namespace + void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { switch (mode_) { case Mode::kDisabled: @@ -37,7 +44,11 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { return; case Mode::kDownswing: if (current_buffer_size == 0) { - buffer_limit_ *= 2; // Increase the buffer size. 
+ if (buffer_limit_ >= kBufferLimitThreshold) { + buffer_limit_ += kBufferLimitThreshold; + } else { + buffer_limit_ *= 2; + } mode_ = Mode::kUpswing; } return; diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index 750efca592..ae451be7e2 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); - OP_REQUIRES(context, width > 0 && header_size >= 0, + OP_REQUIRES(context, width > 0, errors::InvalidArgument("Width must be positive")); + OP_REQUIRES(context, height != 0, + errors::InvalidArgument("Height must be nonzero")); OP_REQUIRES(context, header_size >= 0, errors::InvalidArgument("header size must be nonnegative")); diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 1a1ed04e0d..8a100fe975 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 13) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 14) @tf_export("compat.forward_compatible") diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index c6749468c8..fed07c4120 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -209,8 +209,27 @@ class Model(Network): for metric in metrics: metric_fn = training_utils.get_metric_function( metric, output_shape=output_shape, loss_fn=loss_fn) - metric_name = self._get_metric_name( - metric, output_index, weighted=weights is not None) + + if (context.executing_eagerly() and y_true is not None and + y_pred is not None): + # In eager mode, when executing metric_fn during 
training, we do not + # need to generate unique metric name and add it to the model + # as we have done that during compile already. + prefix = 'weighted_' if weights is not None else '' + suffix = metric_fn.name if hasattr(metric_fn, + 'name') else metric_fn.__name__ + metric_name = prefix + suffix + else: + # Get metric name that is to be added to the model. + metric_name = self._get_metric_name( + metric, output_index, weighted=weights is not None) + # Keep track of metric name. + self.metrics_names.append(metric_name) + + # Keep track of stateful metric attributes (name and metric function). + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: + self.stateful_metric_names.append(metric_name) + self.stateful_metric_functions.append(metric_fn) with K.name_scope(metric_name): # If both outputs and targets are available, call the metric function. @@ -250,16 +269,10 @@ class Model(Network): self.metrics_tensors.append(metric_result) metric_results.append(metric_result) - # Keep track of metric name. - self.metrics_names.append(metric_name) - - # Keep track of stateful metric attributes (name and metric function). - if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: - self.stateful_metric_names.append(metric_name) - self.stateful_metric_functions.append(metric_fn) - if not context.executing_eagerly(): - # Keep track of updates created by stateful metrics. - self.metrics_updates += metric_fn.updates + if (isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful and + not context.executing_eagerly()): + # Keep track of updates created by stateful metrics. 
+ self.metrics_updates += metric_fn.updates return metric_results def _handle_metrics(self, diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 380130095b..30be4131a4 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -2256,7 +2256,26 @@ class TestTrainingWithMetrics(test.TestCase): 'dense_binary_accuracy', 'dropout_mean_squared_error', 'dropout_binary_accuracy' ] + reference_stateful_metric_names = [ + 'dense_binary_accuracy', 'dropout_binary_accuracy' + ] + self.assertEqual(reference_metric_names, model.metrics_names) + self.assertEqual(reference_stateful_metric_names, + model.stateful_metric_names) + + # Verify that model metric names are not altered during training. + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + model.fit([input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5) self.assertEqual(reference_metric_names, model.metrics_names) + self.assertEqual(reference_stateful_metric_names, + model.stateful_metric_names) @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness(self): diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index 94c7d88b5c..a404507627 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -234,6 +234,7 @@ def create_file_writer(logdir, """ if logdir is None: return SummaryWriter(None, None) + logdir = str(logdir) with ops.device("cpu:0"): if max_queue is None: max_queue = constant_op.constant(10) diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py index 2936a279bd..14dec982a6 100644 --- a/tensorflow/python/summary/writer/event_file_writer.py +++ 
b/tensorflow/python/summary/writer/event_file_writer.py @@ -62,7 +62,7 @@ class EventFileWriter(object): filename_suffix: A string. Every event file's name is suffixed with `filename_suffix`. """ - self._logdir = logdir + self._logdir = str(logdir) if not gfile.IsDirectory(self._logdir): gfile.MakeDirs(self._logdir) self._event_queue = six.moves.queue.Queue(max_queue) diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 99bed5714f..d06c7f2d49 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -174,7 +174,7 @@ class ApiCompatibilityTest(test.TestCase): verbose_diff_message = diff_message else: # Do not truncate diff - self.maxDiffs = None # pylint: disable=invalid-name + self.maxDiff = None # pylint: disable=invalid-name # Now we can run an actual proto diff. try: self.assertProtoEquals(expected_dict[key], actual_dict[key]) diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh index 48b3989d86..03a2a07fb1 100755 --- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh +++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh @@ -31,6 +31,28 @@ TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4} # future and to use a rounder number, we set it to 1G. export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024 +# ******************************************************************* +# This section of the script is needed to +# make things work on windows under msys. +# ******************************************************************* +RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST" +function rlocation() { + if is_absolute "$1" ; then + # If the file path is already fully specified, simply return it. + echo "$1" + elif [[ -e "$TEST_SRCDIR/$1" ]]; then + # If the file exists in the $TEST_SRCDIR then just use it. 
+ echo "$TEST_SRCDIR/$1" + elif [[ -e "$RUNFILES_MANIFEST_FILE" ]]; then + # If a runfiles manifest file exists then use it. + echo "$(grep "^$1 " "$RUNFILES_MANIFEST_FILE" | sed 's/[^ ]* //')" + fi +} + +TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})" +shift +# ******************************************************************* + mkdir -p /var/lock # Try to acquire any of the TF_GPU_COUNT * TF_TESTS_PER_GPU # slots to run a test at. @@ -46,8 +68,8 @@ for j in `seq 0 $((TF_TESTS_PER_GPU-1))`; do # This export only works within the brackets, so it is isolated to one # single command. export CUDA_VISIBLE_DEVICES=$i - echo "Running test $@ on GPU $CUDA_VISIBLE_DEVICES" - $@ + echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES" + "$TEST_BINARY" $@ ) return_code=$? flock -u "$lock_fd" diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 65314a4a06..25698da1c9 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "com_google_absl", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/02451914b9ad5320f81f56a89f3eef1f8683227c.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/02451914b9ad5320f81f56a89f3eef1f8683227c.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz", ], - sha256 = "345fa25136484a9e5d918880d66ee577a9cb24377f8978d4e5a6c543706a1011", - strip_prefix = "abseil-cpp-02451914b9ad5320f81f56a89f3eef1f8683227c", + sha256 = "006931f9705484041eed65189038f87931a87cff200bb296f94b3d42339c4cd9", + strip_prefix = "abseil-cpp-8ff1374008259719b54a8cb128ef951c02da164c", build_file = clean_dep("//third_party:com_google_absl.BUILD"), ) |