author     (David) Siu-Kei Muk <muksiukei@gmail.com>  2018-09-14 19:58:26 +0800
committer  (David) Siu-Kei Muk <muksiukei@gmail.com>  2018-09-14 19:58:26 +0800
commit     ae7e8d01372a2df39dc5669b00735529c5cfffb9 (patch)
tree       3183e7729343d1426efce92a06dc4ce886d9844b
parent     51d72a7d7f74784b68916819edd04e890b36f957 (diff)
parent     54cbee5d034af8693aa39cc5877c3dfcd62d3740 (diff)
Merge branch 'master' of https://github.com/tensorflow/tensorflow into est_spec_metrics_ops_check_tensor
-rw-r--r--  tensorflow/compiler/jit/BUILD | 6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 17
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.h | 6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc | 360
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.h | 60
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc | 346
-rw-r--r--  tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 9
-rw-r--r--  tensorflow/compiler/jit/ops/xla_ops.cc | 19
-rw-r--r--  tensorflow/compiler/tests/BUILD | 3
-rw-r--r--  tensorflow/compiler/tests/concat_ops_test.py | 35
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/graph_compiler.h | 13
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/concat_op.cc | 33
-rw-r--r--  tensorflow/compiler/tf2xla/test_util.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/test_util.h | 16
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc | 3
-rw-r--r--  tensorflow/compiler/xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/protobuf_util.cc | 29
-rw-r--r--  tensorflow/compiler/xla/protobuf_util.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 64
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 4
-rw-r--r--  tensorflow/contrib/cmake/README.md | 4
-rw-r--r--  tensorflow/contrib/data/__init__.py | 9
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD | 1
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 3
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py | 39
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD | 1
-rw-r--r--  tensorflow/contrib/distribute/python/single_loss_example.py | 6
-rw-r--r--  tensorflow/contrib/lite/c/builtin_op_data.h | 7
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 3
-rw-r--r--  tensorflow/contrib/lite/experimental/writer/writer_lib.cc | 14
-rw-r--r--  tensorflow/contrib/lite/experimental/writer/writer_lib.h | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv.cc | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/depthwise_conv_test.cc | 116
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h | 20
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h | 24
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h | 24
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h | 28
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 4
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 38
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 2
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 7
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 2
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 5
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 14
-rw-r--r--  tensorflow/contrib/lite/tutorials/post_training_quant.ipynb | 702
-rw-r--r--  tensorflow/core/BUILD | 7
-rw-r--r--  tensorflow/core/common_runtime/graph_execution_state.cc | 4
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc | 1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 8
-rw-r--r--  tensorflow/core/grappler/costs/utils.cc | 8
-rw-r--r--  tensorflow/core/grappler/costs/utils.h | 2
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 8
-rw-r--r--  tensorflow/core/grappler/inputs/utils.cc | 7
-rw-r--r--  tensorflow/core/grappler/inputs/utils.h | 4
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils.cc | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 23
-rw-r--r--  tensorflow/core/kernels/data/prefetch_autotuner.cc | 13
-rw-r--r--  tensorflow/core/kernels/decode_bmp_op.cc | 4
-rw-r--r--  tensorflow/python/compat/compat.py | 2
-rw-r--r--  tensorflow/python/keras/engine/training.py | 37
-rw-r--r--  tensorflow/python/keras/engine/training_test.py | 19
-rw-r--r--  tensorflow/python/ops/summary_ops_v2.py | 1
-rw-r--r--  tensorflow/python/summary/writer/event_file_writer.py | 2
-rw-r--r--  tensorflow/tools/api/tests/api_compatibility_test.py | 2
-rwxr-xr-x  tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh | 26
-rwxr-xr-x  tensorflow/workspace.bzl | 8
75 files changed, 2249 insertions, 176 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7d5db713f6..f4e1bc5e83 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -363,6 +363,7 @@ cc_library(
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
+ "encapsulate_xla_computations_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -371,6 +372,7 @@ cc_library(
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
+ "encapsulate_xla_computations_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@@ -397,6 +399,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@@ -475,6 +478,7 @@ tf_cc_test(
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
+ "encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@@ -490,7 +494,9 @@ tf_cc_test(
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ae7a22f451..e0632ff7e4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+void SortControlInputs(GraphDef* gdef) {
+ int64 num_nodes = gdef->node_size();
+ for (int64 i = 0; i < num_nodes; ++i) {
+ NodeDef* node = gdef->mutable_node(i);
+ // Stable sort control inputs and leave the order of data inputs unchanged.
+ std::stable_sort(node->mutable_input()->begin(),
+ node->mutable_input()->end(),
+ [](const string& a, const string& b) {
+ bool a_is_control = absl::StartsWith(a, "^");
+ bool b_is_control = absl::StartsWith(b, "^");
+ return (!a_is_control && b_is_control) ||
+ (a_is_control && b_is_control && a < b);
+ });
+ }
+}
+
namespace {
bool AreAllParentsGuaranteedConst(
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 926589546f..90354a801a 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
+// Sorts each node's control inputs by their names. This guarantees that for two
+// structurally equivalent GraphDefs, we get the same traversal ordering on
+// each node's control input fields.
+// TODO(hpucha): Move the utilities to a more appropriate place.
+void SortControlInputs(GraphDef* gdef);
+
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
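The comparator used by SortControlInputs (added above in encapsulate_subgraphs_pass.cc) keeps data inputs in their original relative order and groups control inputs, whose names start with "^", after them sorted by name. The following is a minimal, self-contained C++ sketch of that ordering on plain strings; it is an illustration, not the NodeDef input list or the TensorFlow API.

// Standalone illustration of the SortControlInputs ordering: data inputs keep
// their original relative order, control inputs ("^"-prefixed names) are
// grouped after them and sorted by name.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> inputs = {"b:0", "^ctrl_z", "a:1", "^ctrl_a"};
  std::stable_sort(inputs.begin(), inputs.end(),
                   [](const std::string& a, const std::string& b) {
                     bool a_is_control = !a.empty() && a[0] == '^';
                     bool b_is_control = !b.empty() && b[0] == '^';
                     return (!a_is_control && b_is_control) ||
                            (a_is_control && b_is_control && a < b);
                   });
  for (const std::string& s : inputs) std::cout << s << "\n";
  // Prints: b:0, a:1, ^ctrl_a, ^ctrl_z
}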
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
new file mode 100644
index 0000000000..97ef8cd3cb
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -0,0 +1,360 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+
+namespace tensorflow {
+
+const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
+ "_xla_compile_id";
+
+namespace {
+
+const char* const kXlaClusterOutput = "XlaClusterOutput";
+
+// Checks if a graph node is marked to be a guaranteed constant.
+bool is_guaranteed_constant(const Node& n) {
+ bool guaranteed_constant = false;
+ if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
+ .ok()) {
+ return false;
+ }
+ return guaranteed_constant;
+}
+
+// Finds the `index` of an _Arg or _Retval node.
+Status GetIndexAttr(const Node& n, int num_args, int* index) {
+ TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
+ if (*index < 0 || *index >= num_args) {
+ return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
+ *index);
+ }
+ return Status::OK();
+}
+
+// Returns the data type of the destination of an edge.
+DataType EdgeType(const Edge* edge) {
+ return edge->dst()->input_type(edge->dst_input());
+}
+
+// Adds the control inputs of `node` to `*deps`.
+void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.in_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->src());
+ }
+ }
+}
+
+// Adds the control outputs of `node` to `*deps`.
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.out_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->dst());
+ }
+ }
+}
+
+// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
+// the arguments into the order expected by XlaLaunch computations:
+// 1) arguments
+// 2) resource variable arguments
+// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
+// of the arguments.
+//
+// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
+Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation,
+ NodeDef* call_def) {
+ Graph* graph = graph_ptr->get();
+ const int num_args = input_permutation->size();
+ const int num_retvals = output_permutation->size();
+
+ std::vector<Node*> args;
+ std::vector<Node*> retvals;
+ args.reserve(num_args);
+ retvals.reserve(num_retvals);
+ for (Node* n : graph->nodes()) {
+ if (n->type_string() == "_Arg") {
+ // Check if this is a guaranteed constant.
+ if (is_guaranteed_constant(*n)) {
+ return errors::InvalidArgument(
+ "Guaranteed constants are not supported (", n->name(), ")");
+ }
+ args.push_back(n);
+ } else if (n->type_string() == "_Retval") {
+ retvals.push_back(n);
+ }
+ }
+
+ if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
+ return errors::InvalidArgument("Missing or non-consecutive arguments");
+ }
+
+ // Reorders the arguments.
+ std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
+ // Non-resources appear before resources
+ bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
+ bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
+ // Uses the name as a tiebreaker so the output is deterministic.
+ StringPiece a_name(a->name());
+ StringPiece b_name(b->name());
+ return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
+ });
+
+ // Sorts the retvals by name so the order is deterministic.
+ std::sort(retvals.begin(), retvals.end(),
+ [](Node* a, Node* b) { return a->name() < b->name(); });
+
+ // Computes the permutation to produce the correct argument order, and update
+ // the argument indices.
+ int variable_start_index = num_args;
+ for (int i = 0; i < num_args; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
+ if (args[i]->output_type(0) == DT_RESOURCE &&
+ variable_start_index == num_args) {
+ variable_start_index = i;
+ }
+ (*input_permutation)[index] = i;
+ args[i]->AddAttr("index", i);
+ }
+ VLOG(4) << "variable_start_index: " << variable_start_index;
+
+ // Computes the permutation to produce the correct retval order, and update
+ // the argument indices.
+ for (int i = 0; i < num_retvals; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
+ (*output_permutation)[index] = i;
+ retvals[i]->AddAttr("index", i);
+ }
+
+ AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
+ call_def);
+ AddNodeAttr("_variable_start_index", variable_start_index, call_def);
+
+ // Uniquify the function name.
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+
+ // Before serialization, sort each node's control inputs to achieve
+  // determinism. Sorting control inputs helps (but does not by itself ensure)
+  // a deterministic serialization and fingerprint. Other sources of
+ // nondeterminism include unstable node ordering.
+ SortControlInputs(&gdef);
+ // Fingerprint the function.
+ // Nondeterminism in serialization would not lead to incorrect results, but
+ // may cause spurious cache misses. DeterministicSerialization is a
+ // best-effort deterministic serialization.
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+ call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // Check for undeclared outputs before Encapsulation, so we can give a better
+ // error message.
+ // TODO(phawkins): merge this with the encapsulation code to avoid the extra
+ // O(n) pass over the edges.
+ for (const Edge* e : (*graph)->edges()) {
+ if (!e->IsControlEdge() &&
+ e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
+ e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+ e->dst()->type_string() != kXlaClusterOutput) {
+ return errors::InvalidArgument(
+ "Undeclared output of XLA computation. A common cause of this error "
+ "is variable initializers that depend on the XLA computation. Edge: ",
+ e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
+ e->dst_input());
+ }
+ }
+
+ auto output = absl::make_unique<Graph>((*graph)->op_registry());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, "", **graph, RewriteSubgraph,
+ /*reuse_existing_functions=*/true, &output, flib_def),
+ "EncapsulateXlaComputationsPass failed");
+ graph->swap(output);
+ return Status::OK();
+}
+
+/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
+ Graph* graph) {
+ // Finds all of the XlaLaunch function calls, to avoid mutating the graph
+ // while iterating.
+ std::vector<Node*> launch_nodes;
+ for (Node* n : graph->nodes()) {
+ string name;
+ if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+ launch_nodes.push_back(n);
+ }
+ }
+
+ // Replaces each launch function call together with its neighboring
+ // XlaClusterOutput nodes with a XlaLaunch node.
+ for (Node* launch : launch_nodes) {
+ int variable_start_index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
+ &variable_start_index));
+
+ std::vector<const Edge*> in_edges;
+ TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
+
+ const int num_inputs = in_edges.size();
+ const int num_variables = num_inputs - variable_start_index;
+ const int num_args = variable_start_index;
+
+ VLOG(4) << "Launch node '" << launch->name() << "'"
+ << " input edges: " << in_edges.size() << " num_args: " << num_args
+ << " num_variables: " << num_variables;
+
+ std::vector<Node*> nodes_to_remove = {launch};
+
+ // Data and control inputs to the new XlaLaunch node.
+ std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
+ gtl::FlatSet<Node*> control_inputs;
+ DataTypeVector arg_types(num_args);
+
+ AddControlInputs(*launch, &control_inputs);
+
+ for (int i = 0; i < num_args; ++i) {
+ const Edge* edge = in_edges[i];
+ data_inputs[i] = {edge->src(), edge->src_output()};
+ arg_types[i] = EdgeType(edge);
+ }
+
+ // Appends the variable inputs.
+ for (int i = 0; i < num_variables; ++i) {
+ int pos = variable_start_index + i;
+ const Edge* edge = in_edges[pos];
+ data_inputs[pos] = {edge->src(), edge->src_output()};
+ }
+
+ // Outputs.
+ const int num_outputs = launch->output_types().size();
+ gtl::FlatSet<Node*> control_outputs;
+ std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
+ DataTypeVector output_types(num_outputs);
+
+ for (const Edge* le : launch->out_edges()) {
+ if (le->IsControlEdge()) {
+ control_outputs.insert(le->dst());
+ } else {
+ TF_RET_CHECK(le->src_output() < num_outputs);
+ Node* output_node = le->dst();
+
+ TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
+ << le->DebugString();
+ nodes_to_remove.push_back(output_node);
+
+ for (const Edge* oe : output_node->out_edges()) {
+ TF_RET_CHECK(!oe->IsControlEdge());
+ data_outputs[le->src_output()].push_back(
+ {oe->dst(), oe->dst_input()});
+ }
+ output_types[le->src_output()] = output_node->input_type(0);
+
+ AddControlOutputs(*output_node, &control_outputs);
+ }
+ }
+
+ NodeDef def;
+ def.set_name(launch->name());
+
+ // Target the XLA CPU/GPU backends.
+ VLOG(2) << "Replacing with XlaLaunch";
+ def.set_op("XlaLaunch");
+ AddNodeAttr("Tconstants", DataTypeVector{}, &def);
+ AddNodeAttr("Targs", arg_types, &def);
+ AddNodeAttr("Nresources", num_variables, &def);
+ AddNodeAttr("Tresults", output_types, &def);
+ NameAttrList function;
+ function.set_name(launch->type_string());
+ AddNodeAttr("function", function, &def);
+
+ for (Node* node : nodes_to_remove) {
+ VLOG(2) << "Deleting node " << node->DebugString();
+ // Ensure that we do not attempt to add control edges to nodes that are
+ // deleted.
+ control_inputs.erase(node);
+ control_outputs.erase(node);
+ graph->RemoveNode(node);
+ }
+
+ Status status;
+ Node* xla_launch = graph->AddNode(def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ for (int i = 0; i < data_inputs.size(); ++i) {
+ graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
+ i);
+ }
+ for (Node* n : control_inputs) {
+ graph->AddControlEdge(n, xla_launch);
+ }
+ for (int i = 0; i < data_outputs.size(); ++i) {
+ for (const auto& successor : data_outputs[i]) {
+ graph->AddEdge(xla_launch, i, successor.first, successor.second);
+ }
+ }
+ for (Node* n : control_outputs) {
+ graph->AddControlEdge(xla_launch, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status EncapsulateXlaComputationsPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "EncapsulateXlaComputations(): "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+ VLOG(1) << "EncapsulateXlaComputations() half-way: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+ VLOG(1) << "EncapsulateXlaComputations() finished: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
+ **options.graph, options.flib_def);
+ return Status::OK();
+}
+
+} // namespace tensorflow
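RewriteSubgraph above sorts the _Arg nodes so that non-resource arguments precede resource arguments, with the node name as a tiebreaker so the resulting order (and hence the serialized fingerprint) is deterministic. Below is a small standalone sketch of that std::tie comparator; the Arg struct is a hypothetical stand-in for tensorflow::Node, not part of the patch.

// Sketch of the argument-ordering comparator in RewriteSubgraph: non-resource
// args sort before resource args, and the name breaks ties deterministically.
#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

struct Arg {
  std::string name;
  bool is_resource;
};

int main() {
  std::vector<Arg> args = {{"v", true}, {"b", false}, {"u", true}, {"a", false}};
  std::sort(args.begin(), args.end(), [](const Arg& x, const Arg& y) {
    return std::tie(x.is_resource, x.name) < std::tie(y.is_resource, y.name);
  });
  for (const Arg& arg : args) std::cout << arg.name << " ";  // prints: a b u v
  std::cout << "\n";
}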
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
new file mode 100644
index 0000000000..99e9dfd598
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+// Rewrites computations generated by the xla.compile() Python code into
+// XlaLaunch nodes.
+//
+// xla.compile() does two main things:
+// a) marks operators that make up an XLA computation with the attribute
+// _xla_compile_id=XYZ, where XYZ is a unique key.
+// b) adds XlaClusterOutput nodes to represent outputs of the computation.
+// These nodes are not marked with the _xla_compile_id attribute.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+ namespace tensorflow {
+
+// Encapsulates nodes marked with the _xla_compile_id attribute into
+// XlaLaunch operators.
+class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
+ public:
+ static const char* const kXlaClusterAttr; // _xla_compile_id
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ // The following methods are public only for unit tests.
+
+ // This pass has two stages:
+ // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
+ // marked with the same _xla_compile_id attribute into functions. These
+ // functions contain the computations to be passed to XlaLaunch. During
+ // encapsulation, we sort the arguments into the order expected by
+ // XlaLaunch.
+ static Status Encapsulate(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // b) we rewrite the function calls generated in phase (a) into XlaLaunch
+ // operators. We also convert the XlaClusterOutput output nodes of the
+ // function call into the outputs of the XlaLaunch operator.
+ static Status BuildXlaLaunchOps(Graph* graph);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
new file mode 100644
index 0000000000..f643fb0cfe
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/test_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<Graph> MakeOuterGraph(
+ const FunctionLibraryDefinition& flib_def, const string& function) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NodeDef def;
+ TF_CHECK_OK(
+ NodeDefBuilder("launch0", function, &flib_def)
+ .Input(a.node()->name(), 0, DT_INT32)
+ .Input(b.node()->name(), 0, DT_FLOAT)
+ .Input(c.node()->name(), 0, DT_INT32)
+ .Input(d.node()->name(), 0, DT_FLOAT)
+ .Input(u.node()->name(), 0, DT_RESOURCE)
+ .Input(v.node()->name(), 0, DT_RESOURCE)
+ .Input(w.node()->name(), 0, DT_RESOURCE)
+ .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
+ .Attr("_variable_start_index", 4)
+ .Finalize(&def));
+
+ Status status;
+ Node* launch = scope.graph()->AddNode(def, &status);
+ TF_CHECK_OK(status);
+ TF_CHECK_OK(scope.DoShapeInference(launch));
+ scope.graph()->AddEdge(a.node(), 0, launch, 0);
+ scope.graph()->AddEdge(b.node(), 0, launch, 1);
+ scope.graph()->AddEdge(c.node(), 0, launch, 2);
+ scope.graph()->AddEdge(d.node(), 0, launch, 3);
+ scope.graph()->AddEdge(u.node(), 0, launch, 4);
+ scope.graph()->AddEdge(v.node(), 0, launch, 5);
+ scope.graph()->AddEdge(w.node(), 0, launch, 6);
+
+ auto out0 =
+ ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
+ auto out1 =
+ ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
+ auto out2 =
+ ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
+ auto out3 =
+ ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+// Makes an encapsulate body graph for use in tests.
+static std::unique_ptr<Graph> MakeBodyGraph() {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
+ auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
+ auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
+
+ auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
+ auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
+ auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, arg3);
+ add_attrs(g.node());
+
+ auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
+ b_identity, 0);
+ auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
+ auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
+ auto out3 =
+ ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
+ // Test that control edge insertion order doesn't affect the cache key
+ // (cluster name) generated by TPU encapsulate pass.
+ auto get_serialized_graph = [](bool control_input_reversed,
+ bool operand_reversed) -> string {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
+ auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
+
+ ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
+ : ops::Add(scope.WithOpName("E"), a1, a0);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
+ "launch0");
+ };
+ add_attrs(e.node());
+
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ auto get_node_in_graph = [&graph](Node* node) {
+ return graph->FindNodeId(node->id());
+ };
+ // Insert control edge in different order. The order should not affect
+ // the encapsulated or serialized graph.
+ if (!control_input_reversed) {
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ } else {
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ }
+ }
+ TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+ // Before serialization, sort control inputs first to remove
+ // nondeterminism.
+ SortControlInputs(&gdef);
+ string serialized;
+ SerializeToStringDeterministic(gdef, &serialized);
+ return serialized;
+ };
+
+ // Changing the order of control input shouldn't affect the graph generated.
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
+ /*operand_reversed=*/false),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+
+ // Changing the order of data input should affect the graph generated.
+ EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/true),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+}
+
+TEST(EncapsulateXlaComputations, Encapsulate) {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
+ add_attrs(b_identity.node());
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), a, c);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, d);
+ add_attrs(g.node());
+
+ auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
+ auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
+ auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
+ auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+
+ std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
+ CopyGraph(*graph, graph_copy.get());
+
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+
+ std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
+ string function = index.at("launch0")->type_string();
+
+ // Tests the outer graph is as expected.
+ {
+ std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
+ GraphDef expected_def;
+ outer->ToGraphDef(&expected_def);
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
+ }
+
+ // Tests the encapsulated body graph is as expected.
+ {
+ std::unique_ptr<Graph> body = MakeBodyGraph();
+ GraphDef expected_body_def;
+ body->ToGraphDef(&expected_body_def);
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
+
+ EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
+ DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
+ result.arg_types);
+ EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
+ result.ret_types);
+ TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
+ }
+
+ // Encapsulates the same computation again, verifies we reuse the same
+ // function. Encapsulation should be deterministic to avoid recompilation.
+ TF_ASSERT_OK(
+ EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
+ std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
+ string function_copy = index_copy.at("launch0")->type_string();
+ EXPECT_EQ(function, function_copy);
+}
+
+TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
+ std::unique_ptr<Graph> body_graph = MakeBodyGraph();
+ FunctionDefLibrary flib;
+ TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+ std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
+
+ Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NameAttrList function;
+ function.set_name("launch0");
+ auto launch = ops::XlaLaunch(
+ scope.WithOpName("launch0"), std::initializer_list<Input>{},
+ std::initializer_list<Input>{a, b, c, d},
+ std::initializer_list<Input>{u, v, w},
+ DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
+
+ auto consumer0_a =
+ ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
+ auto consumer0_b =
+ ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
+ auto consumer0_c =
+ ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
+ auto consumer1 =
+ ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
+ auto consumer2 =
+ ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
+ auto consumer3 =
+ ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
+
+ GraphDef expected_def;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 5dcf754969..3770eea6d0 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@@ -23,6 +24,11 @@ namespace tensorflow {
// PRE_PLACEMENT passes:
+// EncapsulateXlaComputationsPass rewrites computations generated by the
+// xla.compile() Python code into XlaLaunch nodes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
+ EncapsulateXlaComputationsPass);
+
// from
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
@@ -32,7 +38,8 @@ namespace tensorflow {
// control flow structure (XlaIf/XlaWhile). Following passes must
// handle those FunctionDef correctly.
-// POST_REWRITE_FOR_EXEC passes:
+// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
+
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d98ff..1a29c3caab 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
+REGISTER_OP("XlaClusterOutput")
+ .Input("input: T")
+ // Note: when replication is supported, this op will have N outputs.
+ .Output("outputs: T")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(
+ "Operator that connects the output of an XLA computation to other "
+ "consumer graph nodes.");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 2176eaebe4..97ed554171 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -277,9 +277,10 @@ tf_xla_py_test(
],
)
+# This test is marked "large" because the CPU run of
+# testConcatLargeNumberOfTensors occasionally takes a long time.
tf_xla_py_test(
name = "concat_ops_test",
- size = "medium",
+ size = "large",
srcs = ["concat_ops_test.py"],
deps = [
":xla_test",
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+  # Ensures that XLA on GPU does not run out of memory when a concat op is
+  # given a very large number of arguments.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index d549e7bb59..ba1e3b2b4f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -611,6 +611,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index ab7cac7100..e9f02201cf 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -55,17 +55,17 @@ namespace tensorflow {
// op registration infrastructure instead of FunctionLibraryRuntime.
class GraphCompiler {
public:
- GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
- Graph* graph, FunctionLibraryRuntime* flib,
+ GraphCompiler(XlaCompilationDevice* device, Graph* graph,
+ FunctionLibraryRuntime* flib,
ScopedStepContainer* step_container)
- : xla_context_(xla_context),
- device_(device),
+ : device_(device),
graph_(graph),
flib_(flib),
step_container_(step_container) {}
- // Compiles the graph. The results are written in `xla_context` that is passed
- // into the compiler.
+  // Compiles the graph. The results are written to the XlaContext stored in
+  // the resource_manager of the 'XlaCompilationDevice' that is passed into
+  // the constructor.
Status Compile();
private:
@@ -82,7 +82,6 @@ class GraphCompiler {
// using `compiler_`.
Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
- XlaContext* xla_context_;
XlaCompilationDevice* device_;
Graph* graph_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index df17da4c1c..0d9a768a6f 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -66,6 +66,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ if (DataTypeIsUnsigned(dtype)) {
+ return xla::Div(x, y);
+ }
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
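The early return added above relies on the fact that for unsigned operands truncating division already equals floor division; the sign correction that follows is only needed when exactly one operand is negative. A plain C++ illustration of that arithmetic (not the XLA builder API) follows.

// Floor division: for non-negative (unsigned) operands, truncating integer
// division already floors. For signed operands with differing signs and a
// nonzero remainder, truncation rounds toward zero and must be adjusted down.
#include <cstdint>
#include <iostream>

int64_t FloorDiv(int64_t x, int64_t y) {
  int64_t q = x / y;                              // truncates toward zero
  if ((x % y != 0) && ((x < 0) != (y < 0))) --q;  // different signs: round down
  return q;
}

int main() {
  std::cout << FloorDiv(-7, 2) << "\n";              // -4, not -3
  std::cout << (uint64_t{7} / uint64_t{2}) << "\n";  // 3: already the floor
}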
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605104..0ae23aa6df 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@ limitations under the License.
namespace tensorflow {
namespace {
+// The maximum number of Tensors allowed in a single Concat op, used to avoid
+// exceeding the maximum GPU parameter memory size. This is an issue because
+// concat is variadic and can be called with an unbounded number of arguments.
+// Concat ops with more Tensors than this are split into multiple concat ops.
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
+ std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+ // Concat is associative, so it can be split into many operations when too
+ // many arguments are in a single op. This is a temporary workaround for
+ // b/112613927 where too many parameters in an XlaLaunchOp later result in
+ // too many parameters to a single GPU kernel.
+ if (i && i % kMaxConcatArgsPerOp == 0) {
+ partial_concats.push_back(
+ xla::ConcatInDim(ctx->builder(), input_data, axis));
+ input_data.clear();
+ }
}
+ // Add any inputs that have not been put into another concat yet.
+ partial_concats.insert(partial_concats.end(), input_data.begin(),
+ input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+ // Don't add an additional "identity" concatenate, for better readability of
+ // the IR.
+ if (partial_concats.size() == 1) {
+ ctx->SetOutput(0, partial_concats.front());
+ } else {
+ ctx->SetOutput(0,
+ xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+ }
}
private:
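Because concatenation is associative, a concat with more than kMaxConcatArgsPerOp inputs can be rewritten as a concat of partial concats without changing the result, which is what the loop above does. The following is a self-contained sketch of the chunking idea, using std::vector<int> as a stand-in for a tensor along the concat axis; it illustrates the technique rather than the XLA kernel.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Concatenates all inputs in order (stand-in for xla::ConcatInDim).
std::vector<int> Concat(const std::vector<std::vector<int>>& inputs) {
  std::vector<int> out;
  for (const std::vector<int>& v : inputs) out.insert(out.end(), v.begin(), v.end());
  return out;
}

// Splits one large variadic concat into partial concats of at most max_per_op
// inputs, then concatenates the partial results. Associativity guarantees the
// same output as a single concat over all inputs.
std::vector<int> ChunkedConcat(const std::vector<std::vector<int>>& inputs,
                               std::size_t max_per_op) {
  std::vector<std::vector<int>> partials;
  for (std::size_t i = 0; i < inputs.size(); i += max_per_op) {
    std::size_t end = std::min(inputs.size(), i + max_per_op);
    std::vector<std::vector<int>> chunk(inputs.begin() + i, inputs.begin() + end);
    partials.push_back(Concat(chunk));
  }
  // A single partial needs no extra "identity" concat on top.
  return partials.size() == 1 ? partials.front() : Concat(partials);
}

int main() {
  std::vector<std::vector<int>> inputs = {{1, 2}, {3}, {4, 5}, {6}, {7, 8}};
  std::cout << (Concat(inputs) == ChunkedConcat(inputs, /*max_per_op=*/2)) << "\n";  // 1
}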
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a91b6..f31bfb45a2 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+ std::unordered_map<string, Node*> index;
+ for (Node* node : graph.nodes()) {
+ index[node->name()] = node;
+ }
+ return index;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae92ed..350a868568 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
} // namespace tensorflow
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \
+ do { \
+ string diff; \
+ EqualGraphDefOptions eq_options; \
+ eq_options.ignore_internal_attrs = false; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 105f3b61d5..739e47778a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -325,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
- GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
- step_container.get());
+ GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 76e36f3c46..ef70c1f8ac 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -193,6 +193,7 @@ cc_library(
":types",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/synchronization",
],
)
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index 787725e884..b507a2ef79 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
@@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) {
return safe_file_name;
}
+std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
+GetDirectoryExpanders() {
+ static auto* mutex = new tensorflow::mutex;
+ static auto* singleton = new std::vector<std::function<string(string)>>;
+ return {mutex, singleton};
+}
+
+// Runs all the directory expanders over x and returns the result.
+string Expand(string x) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ for (const auto& f : *pair.second) {
+ x = f(x);
+ }
+ return x;
+}
+
} // namespace
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
+ string expanded_dir = Expand(directory);
+ TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir));
string safe_file_name = SanitizeFileName(file_name) + ".pb";
- const string path = tensorflow::io::JoinPath(directory, safe_file_name);
+ const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name);
return tensorflow::WriteBinaryProto(env, path, message);
}
+void RegisterDirectoryExpander(const std::function<string(string)>& expander) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ pair.second->push_back(expander);
+}
+
} // namespace protobuf_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 3667621367..f22fc8b849 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
+// Registers a function that may either expand a dirpath or forward the original
+// dirpath along as-is.
+void RegisterDirectoryExpander(const std::function<string(string)>& expander);
+
} // namespace protobuf_util
} // namespace xla
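The expander registry added above lets callers rewrite the dump directory before DumpProtoToDirectory creates it; each registered function is applied in registration order under a mutex. A hedged usage sketch follows; the "<dump_root>" placeholder prefix and the /tmp target path are illustrative values made up here, not part of the XLA API.

// Hypothetical usage of RegisterDirectoryExpander: rewrite a placeholder
// prefix in dump paths before DumpProtoToDirectory expands and creates the
// directory. Unmatched paths are forwarded along as-is.
#include <string>
#include "tensorflow/compiler/xla/protobuf_util.h"

void InstallDumpRootExpander() {
  xla::protobuf_util::RegisterDirectoryExpander([](std::string dir) {
    const std::string prefix = "<dump_root>";              // illustrative only
    if (dir.compare(0, prefix.size(), prefix) == 0) {
      dir = "/tmp/xla_dumps" + dir.substr(prefix.size());  // illustrative only
    }
    return dir;
  });
}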
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28478..96bd2616f5 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(instance.computation, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index c607aea1a8..f528e62b17 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -221,25 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
allocator = &*se_allocator;
}
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(params.input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(params.filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(params.output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- if (cross_check_enabled) {
- // Broadcast a constant to the buffer, instead of zeroing the buffer. A
- // non-zero constant is useful for the cross checking, because zero-inputs
- // may not always reveal the bugs.
- const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ const auto initialize_buffer = [&stream, cross_check_enabled](
+ DeviceMemoryBase buffer) {
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for the cross checking, because zero-inputs
+ // may not always reveal the bugs.
CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
size_t left_over_bytes = buffer.size() % 4;
CHECK_EQ(0, left_over_bytes % 2);
@@ -257,19 +244,32 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
DeviceMemoryBase left_over(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
- };
- initialize_f16(params.input_buf);
- initialize_f16(params.filter_buf);
- initialize_f16(params.output_buf);
- } else {
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the
- // inputs might affect performance if e.g. the inputs contain denormals, and
- // this is easy enough.
- stream.ThenMemZero(&params.input_buf, params.input_buf.size())
- .ThenMemZero(&params.filter_buf, params.filter_buf.size())
- .ThenMemZero(&params.output_buf, params.output_buf.size());
- }
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ stream.ThenMemZero(&buffer, buffer.size());
+ }
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ TF_ASSIGN_OR_RETURN(params.input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(params.filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(params.output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
+
+ initialize_buffer(params.input_buf);
+ initialize_buffer(params.filter_buf);
+ initialize_buffer(params.output_buf);
DeviceMemoryBase* result_buf = [&] {
switch (params.kind) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 064b86493d..06b6d5b559 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1339,6 +1339,12 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) {
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+  // For convenience, the literal may have been produced with a different
+  // layout; relayout it as indicated by the HLO instruction.
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+ hlo->shape())) {
+ evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+ }
return Status::OK();
}
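For reference, the relayout that Postprocess() now performs can be reproduced directly with the literal APIs used in the new test below. This is a minimal illustrative sketch, not part of the change; the layouts are simply the ones from that test:

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Builds a 2x2x2 literal with layout {0,1,2} and relays it out to {1,0,2},
// which is what Postprocess() does when the evaluated literal's layout does
// not match the layout requested by the HLO instruction.
xla::Literal RelayoutExample() {
  xla::Literal lit = xla::LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      xla::LayoutUtil::MakeLayout({0, 1, 2}));
  return lit.Relayout(xla::LayoutUtil::MakeLayout({1, 0, 2}));
}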
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 16411eb078..01e88566a5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -2570,6 +2570,25 @@ ENTRY main {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
}
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+ // Regression test for b/114735354.
+ const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+ arg = f32[2,2,2]{0,1,2} parameter(0)
+ ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ LayoutUtil::MakeLayout({0, 1, 2}));
+ Literal actual = Evaluate({&arg});
+ EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 7f090a52db..8fb17a0033 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -249,12 +249,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -265,11 +260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -2350,8 +2341,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
+ Literal result(shape);
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 922ebdf0e3..b27a92f2a0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -1160,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
return replicas;
}
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
const string xla_dump_unoptimized_hlo_proto_to =
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1168,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const {
}
HloProto proto = MakeHloProto(module);
return protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+ proto, xla_dump_unoptimized_hlo_proto_to,
+ StrCat(module.name(), ".unoptimized"));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248b15..1f62fad4c8 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@ class Service : public ServiceInterface {
StatusOr<std::vector<se::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
- Status MaybeDumpHloModule(const HloModule& module) const;
+  // Dumps the given (unoptimized) module if the corresponding DebugOptions
+  // field has been set.
+ Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 0b79f718d4..789dab81ed 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -1,6 +1,10 @@
TensorFlow CMake build
======================
+The CMake build is deprecated for TensorFlow. Please use `bazel` to build TF on
+all platforms. For details, see the
+[TensorFlow install guide](https://www.tensorflow.org/install/).
+
This directory contains CMake files for building TensorFlow on Microsoft
Windows. [CMake](https://cmake.org) is a cross-platform tool that can
generate build scripts for multiple build systems, including Microsoft
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index baec238c62..c378b1ce8d 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sloppy_interleave
@@unbatch
@@unique
+
+@@AUTOTUNE
"""
from __future__ import absolute_import
@@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -113,6 +119,3 @@ from tensorflow.python.data.ops.optional_ops import Optional
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7e36..a14781cd93 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 4114b62e29..73840452df 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
# account for indexing) and transformation sequence.
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c466781f7..785b395707 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat(
return dataset
-def make_tf_record_dataset(
- file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=None,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
"""Reads and optionally parses TFRecord files into a dataset.
Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@ def make_tf_record_dataset(
parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
drop_remainder=drop_final_batch))
- if prefetch_buffer_size is None:
- prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE
if prefetch_buffer_size == 0:
return dataset
else:
@@ -323,7 +321,7 @@ def make_csv_dataset(
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
num_parallel_reads=1,
sloppy=False,
num_rows_for_inference=100,
@@ -386,9 +384,10 @@ def make_csv_dataset(
shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
ensures better shuffling, but increases memory usage and startup time.
shuffle_seed: Randomization seed to use for shuffling.
- prefetch_buffer_size: An int specifying the number of feature batches to
- prefetch for performance improvement. Recommended value is the number of
- batches consumed per training step.
+    prefetch_buffer_size: An int specifying the number of feature batches to
+      prefetch for performance improvement. Recommended value is the number of
+      batches consumed per training step. Defaults to auto-tune.
+
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern,
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
reader_num_threads=1,
parser_num_threads=2,
sloppy_ordering=False,
@@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern,
shuffle_seed: Randomization seed to use for shuffling.
prefetch_buffer_size: Number of feature batches to prefetch in order to
improve performance. Recommended value is the number of batches consumed
- per training step (default is 1).
+ per training step. Defaults to auto-tune.
reader_num_threads: Number of threads used to read `Example` records. If >1,
the results will be interleaved.
parser_num_threads: Number of threads to use for parsing `Example` tensors
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 87f76eaa94..aaecbb0eb1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -485,7 +485,6 @@ py_library(
srcs = ["single_loss_example.py"],
deps = [
":step_fn",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index 5aa19cf6a9..09b351ffa4 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import step_fn
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn,
def dataset_fn():
dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
- # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be
+ # TODO(isaprykin): batch with drop_remainder causes shapes to be
# fully defined for TPU. Remove this when XLA supports dynamic shapes.
- return dataset.apply(
- batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True))
+ return dataset.batch(1, drop_remainder=True)
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index fa43e6a024..be9d551ee4 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -25,6 +25,9 @@ extern "C" {
// TODO(aselle): Consider using "if this then that" for testing.
+// IMPORTANT: All new members of structs must be added at the end to ensure
+// backwards compatibility.
+
// Possible padding types (for convolutions)
typedef enum {
kTfLitePaddingUnknown = 0,
@@ -71,11 +74,15 @@ typedef struct {
} TfLitePoolParams;
typedef struct {
+ // Parameters for DepthwiseConv version 1 or above.
TfLitePadding padding;
int stride_width;
int stride_height;
int depth_multiplier;
TfLiteFusedActivation activation;
+ // Parameters for DepthwiseConv version 2 or above.
+ int dilation_width_factor;
+ int dilation_height_factor;
} TfLiteDepthwiseConvParams;
typedef struct {
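The backwards-compatibility rule above (new struct members only at the end) can be checked mechanically. A small compile-time sketch, purely illustrative and not part of the change; it only uses the field names declared in builtin_op_data.h:

#include <cstddef>

#include "tensorflow/contrib/lite/c/builtin_op_data.h"

// Appending dilation_width_factor / dilation_height_factor leaves the offsets
// of all version-1 members of TfLiteDepthwiseConvParams unchanged, which is
// what keeps previously compiled readers of this struct working.
static_assert(offsetof(TfLiteDepthwiseConvParams, dilation_width_factor) >
                  offsetof(TfLiteDepthwiseConvParams, activation),
              "new members must be appended after all version-1 members");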
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eef4b6d831..f4d2839b1b 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -216,6 +216,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
index 52b17faf82..555a9cc4b0 100644
--- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -117,6 +117,8 @@ Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
FlatBufferBuilder* fbb) {
+ // Initialized to -1.
+ // A value of -1 means this tensor will not be exported.
tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
std::vector<Offset<Tensor>> tensors;
@@ -135,15 +137,17 @@ Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
int curr_output_index = 0;
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
tensor_index++) {
- if (!tensor_is_temporary[tensor_index]) {
+ // Temporary tensors and unused tensors will not be written.
+ if (!tensor_is_temporary[tensor_index] &&
+ unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
tensor_to_written_tensor_[tensor_index] = curr_output_index++;
}
}
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
++tensor_index) {
- // Skip temporaries.
- if (tensor_is_temporary[tensor_index]) continue;
+ // Tensor not exported.
+ if (tensor_to_written_tensor_[tensor_index] == -1) continue;
if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
// We only need to convert non temporaries
@@ -215,7 +219,9 @@ std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
std::vector<int> output;
output.reserve(input.size());
for (int x : input) {
- output.push_back(tensor_to_written_tensor_[x]);
+ if (tensor_to_written_tensor_[x] != -1) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
}
return output;
}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
index a98108b496..a5f14697cf 100644
--- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -62,6 +62,10 @@ class InterpreterWriter {
// caller to change the custom data.
TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
CustomWriter custom_writer);
+ // Tensors that are unused and shouldn't be written.
+ void SetUnusedTensors(const std::set<int>& unused_tensors) {
+ unused_tensors_ = unused_tensors;
+ }
private:
template <class T>
@@ -111,8 +115,9 @@ class InterpreterWriter {
int builtin;
std::string custom;
};
+ std::set<int> unused_tensors_;
// For every tensor index in the interpreter, the index in the written.
- // This is different due to temporary tensors not being written.
+ // This is different due to temporary and unused tensors not being written.
std::vector<int> tensor_to_written_tensor_;
// List of used opcodes
std::vector<OpCode> opcodes_;
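A short usage sketch for the new SetUnusedTensors() hook, assuming an already-built interpreter and the writer's existing Write() entry point; the tensor indices and output path are hypothetical:

#include <set>

#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"

// Marks tensors 3 and 7 as unused so that ExportTensors() maps them to -1 and
// drops them from the written FlatBuffer, then writes the model out.
void WriteWithoutUnusedTensors(tflite::Interpreter* interpreter) {
  tflite::InterpreterWriter writer(interpreter);
  writer.SetUnusedTensors({3, 7});
  writer.Write("/tmp/model_without_unused.tflite");  // hypothetical path
}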
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 347515f289..3e1ce60113 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
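To make the new output-size computation concrete, here is a standalone sketch of the VALID-padding branch of compute_out_size above, evaluated on the 9x9 input, 3x3 filter, dilation 3 case used by the dilated tests later in this change. The function name and the main() harness are illustrative only:

#include <cstdio>

// VALID padding: the filter's effective extent grows with the dilation rate.
int ComputeValidOutSize(int image_size, int filter_size, int stride,
                        int dilation_rate) {
  const int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  return (image_size - effective_filter_size + stride) / stride;
}

int main() {
  // effective_filter_size = (3 - 1) * 3 + 1 = 7, so a 9x9 image yields a
  // 3x3 output: (9 - 7 + 1) / 1 = 3.
  std::printf("%d\n", ComputeValidOutSize(/*image_size=*/9, /*filter_size=*/3,
                                          /*stride=*/1, /*dilation_rate=*/3));
  return 0;
}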
@@ -177,8 +182,19 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, float, float, float*,
+ const Dims<4>&);
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -188,7 +204,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(filter), GetTensorDims(filter),
GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_activation_min, output_activation_max,
GetTensorData<float>(output), GetTensorDims(output));
}
@@ -204,9 +221,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, int, int, int32, int32, int,
+ int32, int32, uint8*, const Dims<4>&);
+
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -216,7 +244,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_offset, data->output_multiplier,
data->output_shift, data->output_activation_min,
data->output_activation_max, GetTensorData<uint8_t>(output),
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index c00cafb9fb..2af26ab80a 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -30,7 +30,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
// stride values.
BaseDepthwiseConvolutionOpModel(const TensorData& input,
const TensorData& filter,
- const TensorData& output) {
+ const TensorData& output,
+ int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -56,7 +57,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
- ActivationFunctionType_NONE)
+ ActivationFunctionType_NONE,
+ dilation_factor, dilation_factor)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
@@ -110,6 +112,58 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
+TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ DepthwiseConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, the effective filter size is 7x7, so the
+  // 9x9 input is reduced to a 3x3 output of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -207,6 +261,64 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ QuantizedDepthwiseConvolutionOpModel m(
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // Since the dilation rate is 3, the effective filter size is 7x7, so the
+  // 9x9 input is reduced to a 3x3 output of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2d5d..70810ca784 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -1067,6 +1067,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+  // TODO(suharshs): An optimized implementation of dilated depthwise conv
+  // needs to be implemented.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 3fd00c8930..f707279600 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1964,6 +1964,30 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+  // TODO(suharshs): An optimized implementation of dilated depthwise conv is
+  // not supported yet.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5000..bb5d590775 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -25,8 +25,9 @@ namespace reference_ops {
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
@@ -52,8 +53,9 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -81,6 +83,20 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
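The indexing change above is the heart of the dilation support in the reference kernels: instead of three adjacent input columns, a 3-tap filter now samples columns spaced dilation_width_factor apart. A tiny illustrative sketch (not part of the change; the anchor column is arbitrary):

#include <cstdio>

int main() {
  const int in_x_origin = 2;
  const int dilation_width_factor = 3;
  // Prints the input columns touched by a 3-tap filter anchored at column 2:
  // 2 5 8 (with dilation 1 this would be 2 3 4).
  for (int filter_x = 0; filter_x < 3; ++filter_x) {
    std::printf("%d ", in_x_origin + dilation_width_factor * filter_x);
  }
  std::printf("\n");
  return 0;
}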
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index d57739279f..5e3e8997fc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -30,8 +30,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
int32 output_offset, int32 output_multiplier,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
@@ -58,8 +59,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -90,6 +92,24 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index d5da4fcccf..f0db22d581 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -276,11 +276,15 @@ table Pool2DOptions {
}
table DepthwiseConv2DOptions {
+ // Parameters for DepthwiseConv version 1 or above.
padding:Padding;
stride_w:int;
stride_h:int;
depth_multiplier:int;
fused_activation_function:ActivationFunctionType;
+ // Parameters for DepthwiseConv version 2 or above.
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
}
table ConcatEmbeddingsOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 0b9c57480e..8c086a5e67 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -2339,12 +2339,16 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable {
int32_t stride_h;
int32_t depth_multiplier;
ActivationFunctionType fused_activation_function;
+ int32_t dilation_w_factor;
+ int32_t dilation_h_factor;
DepthwiseConv2DOptionsT()
: padding(Padding_SAME),
stride_w(0),
stride_h(0),
depth_multiplier(0),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ dilation_w_factor(1),
+ dilation_h_factor(1) {
}
};
@@ -2355,7 +2359,9 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VT_STRIDE_W = 6,
VT_STRIDE_H = 8,
VT_DEPTH_MULTIPLIER = 10,
- VT_FUSED_ACTIVATION_FUNCTION = 12
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_W_FACTOR = 14,
+ VT_DILATION_H_FACTOR = 16
};
Padding padding() const {
return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0));
@@ -2372,6 +2378,12 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ int32_t dilation_w_factor() const {
+ return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
+ }
+ int32_t dilation_h_factor() const {
+ return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING) &&
@@ -2379,6 +2391,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
verifier.EndTable();
}
DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2404,6 +2418,12 @@ struct DepthwiseConv2DOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_dilation_w_factor(int32_t dilation_w_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2422,8 +2442,12 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
int32_t stride_w = 0,
int32_t stride_h = 0,
int32_t depth_multiplier = 0,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1,
+ int32_t dilation_h_factor = 1) {
DepthwiseConv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_depth_multiplier(depth_multiplier);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
@@ -7064,6 +7088,8 @@ inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const
{ auto _e = stride_h(); _o->stride_h = _e; };
{ auto _e = depth_multiplier(); _o->depth_multiplier = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; };
+ { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; };
}
inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7079,13 +7105,17 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
auto _stride_h = _o->stride_h;
auto _depth_multiplier = _o->depth_multiplier;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _dilation_w_factor = _o->dilation_w_factor;
+ auto _dilation_h_factor = _o->dilation_h_factor;
return tflite::CreateDepthwiseConv2DOptions(
_fbb,
_padding,
_stride_w,
_stride_h,
_depth_multiplier,
- _fused_activation_function);
+ _fused_activation_function,
+ _dilation_w_factor,
+ _dilation_h_factor);
}
inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 5d0895c72f..3754b58b23 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1434,6 +1434,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
"filter_size": [[1, 1], [1, 2], [3, 3]],
"strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
"channel_multiplier": [1, 2],
"rate": [[1, 1]],
"padding": ["SAME", "VALID"],
@@ -1444,6 +1445,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3]],
"filter_size": [[1, 1]],
"strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]],
"channel_multiplier": [2],
"rate": [[2, 2]], # Only [1, 1] is supported
"padding": ["SAME"],
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 72c71b2841..96b88b60fc 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -331,7 +331,6 @@ cc_library(
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
] + select({
# Placeholder for internal darwin rule.
@@ -348,6 +347,7 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -408,8 +408,11 @@ tf_cc_binary(
":toco_port",
":toco_tooling",
":types_proto_cc",
- "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
],
)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index eb36b3411d..efc1007925 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1089,6 +1089,8 @@ tensorflow::Status ConvertUnsupportedOperator(
ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
} else {
LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
}
}
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e37f6..164b70f2df 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator {
int stride_height = 0;
int stride_width = 0;
int depth_multiplier = 0;
+  // A dilation rate of 0 is invalid, and this field is an optional attribute,
+  // so it is initialized to 1 to give the default conv behavior when the
+  // attribute is not present.
+ int dilation_width_factor = 1;
+ int dilation_height_factor = 1;
};
// Depth-to-space transform operator.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 5486012176..1061e7c7c4 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@ class DepthwiseConvolution
ActivationFunction::Serialize(op.fused_activation_function);
return ::tflite::CreateDepthwiseConv2DOptions(
*builder, padding, op.stride_width, op.stride_height,
- op.depth_multiplier, activation_function);
+ op.depth_multiplier, activation_function, op.dilation_width_factor,
+ op.dilation_height_factor);
}
void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@ class DepthwiseConvolution
op->depth_multiplier = options.depth_multiplier();
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ op->dilation_width_factor = options.dilation_w_factor();
+ op->dilation_height_factor = options.dilation_h_factor();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+ }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
new file mode 100644
index 0000000000..a96e2c4e1b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -0,0 +1,702 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "6Y8E0lw5eYWm"
+ },
+ "source": [
+ "# Post Training Quantization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CIGrZZPTZVeO"
+ },
+ "source": [
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ "\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BTC1rDAuei_1"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+        "converting weights to 8-bit precision as part of model conversion from\n",
+        "TensorFlow GraphDefs to TFLite's FlatBuffer format. Weight quantization\n",
+        "achieves a 4x reduction in model size. In addition, TFLite supports on-the-fly\n",
+        "quantization and dequantization of activations to allow for:\n",
+ "\n",
+ "1. Using quantized kernels for faster implementation when available.\n",
+ "\n",
+ "2. Mixing of floating-point kernels with quantized kernels for different parts\n",
+ " of the graph.\n",
+ "\n",
+ "Note that the activations are always stored in floating point. For ops that\n",
+ "support quantized kernels, the activations are quantized to 8 bits of precision\n",
+ "dynamically prior to processing and are de-quantized to float precision after\n",
+ "processing. Depending on the model being converted, this can give a speedup over\n",
+ "pure floating point computation.\n",
+ "\n",
+        "In contrast to\n",
+        "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize),\n",
+        "this method quantizes the weights post training and quantizes the activations\n",
+        "dynamically at inference.\n",
+ "Therefore, the model weights are not retrained to compensate for quantization\n",
+ "induced errors. It is important to check the accuracy of the quantized model to\n",
+ "ensure that the degradation is acceptable.\n",
+ "\n",
+        "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n",
+        "TensorFlow, and then convert the saved model into a TensorFlow Lite FlatBuffer\n",
+        "with weight quantization. Finally, we check the\n",
+        "accuracy of the converted model and compare it to the original saved model. We\n",
+        "run the training script mnist.py from the\n",
+        "[TensorFlow official MNIST tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2XsEP17Zelz9"
+ },
+ "source": [
+ "## Building an MNIST model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "dDqqUIZjZjac"
+ },
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "gyqAw1M9lyab"
+ },
+ "outputs": [],
+ "source": [
+ "! pip uninstall -y tensorflow\n",
+ "! pip install -U tf-nightly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WsN6s5L1ieNl"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "00U0taBoe-w7"
+ },
+ "outputs": [],
+ "source": [
+ "! git clone --depth 1 https://github.com/tensorflow/models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "4XZPtSh-fUOc"
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "if sys.version_info.major \u003e= 3:\n",
+ " import pathlib\n",
+ "else:\n",
+ " import pathlib2 as pathlib\n",
+ "\n",
+ "# Add `models` to the python path.\n",
+ "models_path = os.path.join(os.getcwd(), \"models\")\n",
+ "sys.path.append(models_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eQ6Q0qqKZogR"
+ },
+ "source": [
+ "### Train and export the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eMsw_6HujaqM"
+ },
+ "outputs": [],
+ "source": [
+ "saved_models_root = \"/tmp/mnist_saved_model\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "hWSAjQWagIHl"
+ },
+ "outputs": [],
+ "source": [
+        "# The above path addition is not visible to subprocesses, so add the path for the subprocess as well.\n",
+        "# Note: channels_last is required here or the conversion may fail.\n",
+ "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5NMaNZQCkW9X"
+ },
+ "source": [
+        "For this example, we trained the model for only a single epoch, so it only trains to ~96% accuracy.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xl8_fzVAZwOh"
+ },
+ "source": [
+ "### Convert to a TFLite model\n",
+ "\n",
+ "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Xp5oClaZkbtn"
+ },
+ "outputs": [],
+ "source": [
+ "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+ "saved_model_dir"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "AT8BgkKmljOy"
+ },
+ "source": [
+        "Using the Python `TocoConverter`, the saved model can be converted into a TFLite model.\n",
+        "\n",
+        "First, load the model using the `TocoConverter`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "_i8B2nDZmAgQ"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "F2o2ZfF0aiCx"
+ },
+ "source": [
+ "Write it out to a tflite file:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vptWZq2xnclo"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+ "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ie9pQaQrn5ue"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+ "tflite_model_file.write_bytes(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7BONhYtYocQY"
+ },
+ "source": [
+ "To quantize the model on export, set the `post_training_quantize` flag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g8PUvLWDlmmz"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: If you don't have a recent tf-nightly installed, the\n",
+ "# \"post_training_quantize\" line will have no effect.\n",
+ "tf.logging.set_verbosity(tf.logging.INFO)\n",
+ "converter.post_training_quantize = True\n",
+ "tflite_quant_model = converter.convert()\n",
+ "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+ "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PhMmUTl4sbkz"
+ },
+ "source": [
+ "Note that the resulting file, with `post_training_quantize` set, is approximately 1/4 the size of the original."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "JExfcfLDscu4"
+ },
+ "outputs": [],
+ "source": [
+ "!ls -lh {tflite_models_dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L8lQHMp_asCq"
+ },
+ "source": [
+ "## Run the TFLite models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-5l6-ciItvX6"
+ },
+ "source": [
+ "We can run the TensorFlow Lite models using the Python TensorFlow Lite\n",
+ "Interpreter.\n",
+ "\n",
+ "### Load the test data\n",
+ "\n",
+ "First, load the MNIST test data to feed to them:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eTIuU07NuKFL"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+ "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n",
+ "\n",
+ "# Note: If you change the batch size, use\n",
+ "# `tf.contrib.lite.Interpreter.resize_tensor_input` to change it for the\n",
+ "# interpreter as well (see the optional sketch after the interpreter is created).\n",
+ "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Ap_jE7QRvhPf"
+ },
+ "source": [
+ "### Load the model into an interpreter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Jn16Rc23zTss"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n",
+ "interpreter.allocate_tensors()\n",
+ "input_index = interpreter.get_input_details()[0][\"index\"]\n",
+ "output_index = interpreter.get_output_details()[0][\"index\"]"
+ ]
+ },
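+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Optionally, if you batch the test data differently, the interpreter's input must be resized to match. The next cell is a minimal sketch of doing so with `resize_tensor_input`, assuming the model's input shape is `[batch, 28, 28]` as above; the illustrative batch size of 32 is then reverted to the batch size of 1 used in the rest of this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: resize the input tensor for an illustrative batch size of 32,\n",
+ "# then restore the batch size of 1 that the rest of this notebook expects.\n",
+ "interpreter.resize_tensor_input(input_index, [32, 28, 28])\n",
+ "interpreter.allocate_tensors()\n",
+ "\n",
+ "interpreter.resize_tensor_input(input_index, [1, 28, 28])\n",
+ "interpreter.allocate_tensors()"
+ ]
+ },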
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "J8Pztk1mvNVL"
+ },
+ "outputs": [],
+ "source": [
+ "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+ "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Afl6yGvWyqAr"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter_quant.allocate_tensors()\n",
+ "input_index_quant = interpreter_quant.get_input_details()[0][\"index\"]\n",
+ "output_index_quant = interpreter_quant.get_output_details()[0][\"index\"]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2opUt_JTdyEu"
+ },
+ "source": [
+ "### Test the model on one image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "AKslvo2kwWac"
+ },
+ "outputs": [],
+ "source": [
+ "for img, label in mnist_ds.take(1):\n",
+ " break\n",
+ "\n",
+ "interpreter.set_tensor(input_index, img)\n",
+ "interpreter.invoke()\n",
+ "predictions = interpreter.get_tensor(output_index)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XZClM2vo3_bm"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pylab as plt\n",
+ "\n",
+ "plt.imshow(img[0])\n",
+ "template = \"True:{true}, predicted:{predict}\"\n",
+ "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+ " predict=str(predictions[0,0])))\n",
+ "plt.grid(False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LwN7uIdCd8Gw"
+ },
+ "source": [
+ "### Evaluate the models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "05aeAuWjvjPx"
+ },
+ "outputs": [],
+ "source": [
+ "def eval_model(interpreter, mnist_ds):\n",
+ "  # Look up the tensor indices from the interpreter itself so the same\n",
+ "  # function works for both the float and the quantized model.\n",
+ "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
+ "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
+ "  total_seen = 0\n",
+ "  num_correct = 0\n",
+ "\n",
+ " for img, label in mnist_ds:\n",
+ " total_seen += 1\n",
+ " interpreter.set_tensor(input_index, img)\n",
+ " interpreter.invoke()\n",
+ " predictions = interpreter.get_tensor(output_index)\n",
+ " if predictions == label.numpy():\n",
+ " num_correct += 1\n",
+ "\n",
+ " if total_seen % 500 == 0:\n",
+ " print(\"Accuracy after %i images: %f\" %\n",
+ " (total_seen, float(num_correct) / float(total_seen)))\n",
+ "\n",
+ " return float(num_correct) / float(total_seen)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "DqXBnDfJ7qxL"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter, mnist_ds))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Km3cY9ry8ZlG"
+ },
+ "source": [
+ "We can repeat the evaluation on the weight-quantized model to obtain:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "-9cnwiPp6EGm"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter_quant, mnist_ds))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L7lfxkor8pgv"
+ },
+ "source": [
+ "\n",
+ "In this example, we have compressed the model with no loss in accuracy."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "M0o1FtmWeKZm"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "## Optimizing an existing model\n",
+ "\n",
+ "We now consider another example. ResNets with pre-activation layers (ResNet-v2) are widely used for vision applications.\n",
+ "A pre-trained frozen graph for ResNet-v2-101 is available in the\n",
+ "[TensorFlow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n",
+ "\n",
+ "We can convert the frozen graph to a TFLite flatbuffer with post-training quantization as follows:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "v5p5VcNPjILQ"
+ },
+ "outputs": [],
+ "source": [
+ "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
+ "archive_path = pathlib.Path(archive_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-sxnXQuC4ThD"
+ },
+ "source": [
+ "The `info.txt` file lists the input and output names. You can also find them by visually inspecting the graph with TensorBoard."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g_Q_OMEJ4LIc"
+ },
+ "outputs": [],
+ "source": [
+ "! cat {archive_path}/resnet_v2_101_299_info.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ujCAFhqm-C6H"
+ },
+ "outputs": [],
+ "source": [
+ "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
+ "input_arrays = [\"input\"] \n",
+ "output_arrays = [\"output\"]\n",
+ "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n",
+ " str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
+ "converter.post_training_quantize = True\n",
+ "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
+ "resnet_tflite_file.write_bytes(converter.convert())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vhOjeg1x9Knp"
+ },
+ "outputs": [],
+ "source": [
+ "archive_dir = str(archive_path.parent)\n",
+ "!ls -lh {archive_dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qqHLaqFMCjRZ"
+ },
+ "source": [
+ "\n",
+ "The model size is reduced from 171 MB to 43 MB.\n",
+ "The accuracy of this model on ImageNet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n",
+ "\n",
+ "The optimized model's top-1 accuracy is 76.8, the same as the floating-point model."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "post-training-quant.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
+ },
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 8f32bc2844..1a86bff5cd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1921,6 +1921,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "protobuf/config_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/config.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "protobuf/device_properties_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "protobuf/device_properties.proto",
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 7f260b3139..4475fa979e 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -561,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph(
grappler::GrapplerItem item;
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
+ // TODO(b/114748242): Add a unit test to test this bug fix.
+ if (flib_def_) {
+ *item.graph.mutable_library() = flib_def_->ToProto();
+ }
item.fetch.insert(item.fetch.end(),
options.callable_options.fetch().begin(),
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2e644fe987..f5b0105862 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d273eddf81..56c8339d57 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsEnqueue(const NodeDef& n) {
- return (n.op().find("Enqueue") != std::string::npos &&
- n.op().find("EnqueueMany") == std::string::npos);
+ return (n.op().find("Enqueue") != string::npos &&
+ n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
- return (n.op().find("Dequeue") != std::string::npos &&
- n.op().find("DequeueMany") == std::string::npos);
+ return (n.op().find("Dequeue") != string::npos &&
+ n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index aad00ce039..83434ea40f 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -127,7 +127,7 @@ static void ExtractExtraProperties(
// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("filename") != std::string::npos) {
+ op_def->input_arg(i).name().find("filename") != string::npos) {
Tensor tensor;
if (!tensor.FromProto(t)) {
continue;
@@ -153,7 +153,7 @@ static void ExtractExtraProperties(
// When the input is a handle (e.g. look up table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("handle") != std::string::npos) {
+ op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
@@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
buckets_.begin(), std::plus<uint64>());
}
-std::string TensorSizeHistogram::ToString() const {
- std::string r;
+string TensorSizeHistogram::ToString() const {
+ string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67666..5fd6717712 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@ class TensorSizeHistogram {
uint64 Max() const { return max_; }
uint64 NumElem() const { return num_elem_; }
uint64 SumElem() const { return sum_elem_; }
- std::string ToString() const;
+ string ToString() const;
protected:
const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 02a379fca8..80889afc86 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
// Helper lambda to extract port num from _Send and _Recv op name.
auto get_port_num = [](const string& name) -> int {
- if (name.find("bn_0") != std::string::npos) {
+ if (name.find("bn_0") != string::npos) {
return 0;
- } else if (name.find("bn_1") != std::string::npos) {
+ } else if (name.find("bn_1") != string::npos) {
return 1;
- } else if (name.find("bn_2") != std::string::npos) {
+ } else if (name.find("bn_2") != string::npos) {
return 2;
- } else if (name.find("bn_minus1") != std::string::npos) {
+ } else if (name.find("bn_minus1") != string::npos) {
return -1;
}
return -999;
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff877..def9198a69 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
#include <vector>
+#include "tensorflow/core/platform/env.h"
+
namespace tensorflow {
namespace grappler {
@@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) {
return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
}
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
*status = Env::Default()->FileExists(file);
return status->ok();
}
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd5359f..4b9cb0a9ad 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files,
std::vector<Status>* status = nullptr);
bool FilesExist(const std::set<string>& files);
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e78239bd43..3521669b63 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -491,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
}
}
// Queue ops modify the queue which is a side effect.
- if (node.op().find("Queue") != std::string::npos) {
+ if (node.op().find("Queue") != string::npos) {
return false;
}
return !ModifiesInputsInPlace(node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 39517edc06..bc838c6659 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 5a7fe19265..d4ab444036 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -273,7 +273,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- if (name.rfind("_generated") != std::string::npos &&
+ if (name.rfind("_generated") != string::npos &&
(name.rfind("_generated") == (name.size() - strlen("_generated")))) {
name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
} else {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 8c99598748..7ed4a67333 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -72,6 +72,16 @@ bool IsRunOnceOptimizer(const string& name) {
name == "loop_optimizer";
}
+// Check if the graphdef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (auto node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
#define MK_OPT(NAME, VALUE) \
@@ -338,6 +348,19 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
VLOG(1) << "Optimized main graph.";
+ // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+ // passes do not handle TPU functions correctly in a variety of ways (Note
+ // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+ // ops are encapsulated away into functions). For example, TPU graphs contain
+ // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler
+ // passes could prune that away. Grappler passes could also cause issues
+ // around shape inference. Since the desired and existing behavior is to not
+ // optimize TPU functions with Grappler, this check preserves that.
+ if (IsTPUGraphDef(*optimized_graph)) {
+ VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+ return Status::OK();
+ }
+
// 2. Optimize function library
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index 533d0bd5d2..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -26,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -37,7 +44,11 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index 750efca592..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 1a1ed04e0d..8a100fe975 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 13)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 14)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index c6749468c8..fed07c4120 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -209,8 +209,27 @@ class Model(Network):
for metric in metrics:
metric_fn = training_utils.get_metric_function(
metric, output_shape=output_shape, loss_fn=loss_fn)
- metric_name = self._get_metric_name(
- metric, output_index, weighted=weights is not None)
+
+ if (context.executing_eagerly() and y_true is not None and
+ y_pred is not None):
+ # In eager mode, when executing metric_fn during training, we do not
+ # need to generate unique metric name and add it to the model
+ # as we have done that during compile already.
+ prefix = 'weighted_' if weights is not None else ''
+ suffix = metric_fn.name if hasattr(metric_fn,
+ 'name') else metric_fn.__name__
+ metric_name = prefix + suffix
+ else:
+ # Get metric name that is to be added to the model.
+ metric_name = self._get_metric_name(
+ metric, output_index, weighted=weights is not None)
+ # Keep track of metric name.
+ self.metrics_names.append(metric_name)
+
+ # Keep track of stateful metric attributes (name and metric function).
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
+ self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
with K.name_scope(metric_name):
# If both outputs and targets are available, call the metric function.
@@ -250,16 +269,10 @@ class Model(Network):
self.metrics_tensors.append(metric_result)
metric_results.append(metric_result)
- # Keep track of metric name.
- self.metrics_names.append(metric_name)
-
- # Keep track of stateful metric attributes (name and metric function).
- if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
- self.stateful_metric_names.append(metric_name)
- self.stateful_metric_functions.append(metric_fn)
- if not context.executing_eagerly():
- # Keep track of updates created by stateful metrics.
- self.metrics_updates += metric_fn.updates
+ if (isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful and
+ not context.executing_eagerly()):
+ # Keep track of updates created by stateful metrics.
+ self.metrics_updates += metric_fn.updates
return metric_results
def _handle_metrics(self,
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 380130095b..30be4131a4 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -2256,7 +2256,26 @@ class TestTrainingWithMetrics(test.TestCase):
'dense_binary_accuracy', 'dropout_mean_squared_error',
'dropout_binary_accuracy'
]
+ reference_stateful_metric_names = [
+ 'dense_binary_accuracy', 'dropout_binary_accuracy'
+ ]
+ self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
+
+ # Verify that model metric names are not altered during training.
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5)
self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 94c7d88b5c..a404507627 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -234,6 +234,7 @@ def create_file_writer(logdir,
"""
if logdir is None:
return SummaryWriter(None, None)
+ logdir = str(logdir)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py
index 2936a279bd..14dec982a6 100644
--- a/tensorflow/python/summary/writer/event_file_writer.py
+++ b/tensorflow/python/summary/writer/event_file_writer.py
@@ -62,7 +62,7 @@ class EventFileWriter(object):
filename_suffix: A string. Every event file's name is suffixed with
`filename_suffix`.
"""
- self._logdir = logdir
+ self._logdir = str(logdir)
if not gfile.IsDirectory(self._logdir):
gfile.MakeDirs(self._logdir)
self._event_queue = six.moves.queue.Queue(max_queue)
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 99bed5714f..d06c7f2d49 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -174,7 +174,7 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = diff_message
else:
# Do not truncate diff
- self.maxDiffs = None # pylint: disable=invalid-name
+ self.maxDiff = None # pylint: disable=invalid-name
# Now we can run an actual proto diff.
try:
self.assertProtoEquals(expected_dict[key], actual_dict[key])
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 48b3989d86..03a2a07fb1 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -31,6 +31,28 @@ TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
# future and to use a rounder number, we set it to 1G.
export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024
+# *******************************************************************
+# This section of the script is needed to
+# make things work on windows under msys.
+# *******************************************************************
+RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST"
+function rlocation() {
+ if is_absolute "$1" ; then
+ # If the file path is already fully specified, simply return it.
+ echo "$1"
+ elif [[ -e "$TEST_SRCDIR/$1" ]]; then
+ # If the file exists in the $TEST_SRCDIR then just use it.
+ echo "$TEST_SRCDIR/$1"
+ elif [[ -e "$RUNFILES_MANIFEST_FILE" ]]; then
+ # If a runfiles manifest file exists then use it.
+ echo "$(grep "^$1 " "$RUNFILES_MANIFEST_FILE" | sed 's/[^ ]* //')"
+ fi
+}
+
+TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})"
+shift
+# *******************************************************************
+
mkdir -p /var/lock
# Try to acquire any of the TF_GPU_COUNT * TF_TESTS_PER_GPU
# slots to run a test at.
@@ -46,8 +68,8 @@ for j in `seq 0 $((TF_TESTS_PER_GPU-1))`; do
# This export only works within the brackets, so it is isolated to one
# single command.
export CUDA_VISIBLE_DEVICES=$i
- echo "Running test $@ on GPU $CUDA_VISIBLE_DEVICES"
- $@
+ echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
+ "$TEST_BINARY" $@
)
return_code=$?
flock -u "$lock_fd"
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 65314a4a06..25698da1c9 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/02451914b9ad5320f81f56a89f3eef1f8683227c.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/02451914b9ad5320f81f56a89f3eef1f8683227c.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
],
- sha256 = "345fa25136484a9e5d918880d66ee577a9cb24377f8978d4e5a6c543706a1011",
- strip_prefix = "abseil-cpp-02451914b9ad5320f81f56a89f3eef1f8683227c",
+ sha256 = "006931f9705484041eed65189038f87931a87cff200bb296f94b3d42339c4cd9",
+ strip_prefix = "abseil-cpp-8ff1374008259719b54a8cb128ef951c02da164c",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)