aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <vomjom@vomjom.net>2017-06-08 12:33:47 -0700
committerGravatar GitHub <noreply@github.com>2017-06-08 12:33:47 -0700
commitf1c0b328700ac5ba9eb0368374e78990212c3bbe (patch)
tree3bdc6145d6aa3cc0f3438527f6cebe682c9de8d4
parent4bc9ac90d0dd82360561dae092959d12fb45c3c3 (diff)
parent50d80ddf926423c16864f886a4fd2297d7725da1 (diff)
Merge pull request #10570 from jhseu/branch_158391996
Branch 158391996
-rwxr-xr-xconfigure18
-rw-r--r--tensorflow/BUILD2
-rw-r--r--tensorflow/cc/BUILD3
-rw-r--r--tensorflow/compiler/tests/BUILD19
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc1327
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py1018
-rw-r--r--tensorflow/compiler/tests/xla_test.py16
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/arg_op.cc13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc538
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h39
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc39
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h4
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc29
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h44
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc41
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h11
-rw-r--r--tensorflow/compiler/xla/literal_util.h1
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/BUILD7
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc95
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc237
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc498
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc58
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc109
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h3
-rw-r--r--tensorflow/compiler/xla/shape_util.cc46
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc28
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc9
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc10
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h6
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h1
-rwxr-xr-xtensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD47
-rw-r--r--tensorflow/contrib/cluster_resolver/README.md5
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/__init__.py23
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py171
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py238
-rw-r--r--tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc218
-rw-r--r--tensorflow/contrib/framework/ops/checkpoint_ops.cc4
-rw-r--r--tensorflow/contrib/framework/python/ops/checkpoint_ops.py50
-rw-r--r--tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py239
-rw-r--r--tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj431
-rw-r--r--tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc92
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py6
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py1130
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py9
-rw-r--r--tensorflow/contrib/rnn/python/ops/gru_ops.py2
-rw-r--r--tensorflow/core/grappler/costs/virtual_placer.cc24
-rw-r--r--tensorflow/core/grappler/costs/virtual_placer.h5
-rw-r--r--tensorflow/core/grappler/costs/virtual_placer_test.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc3
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc107
-rw-r--r--tensorflow/core/kernels/fft_ops.cc55
-rw-r--r--tensorflow/core/kernels/iterator_ops.cc2
-rw-r--r--tensorflow/core/kernels/sparse_cross_op.cc92
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt14
-rw-r--r--tensorflow/core/ops/spectral_ops.cc26
-rw-r--r--tensorflow/core/public/version.h3
-rw-r--r--tensorflow/examples/ios/.gitignore (renamed from tensorflow/contrib/ios_examples/.gitignore)0
-rw-r--r--tensorflow/examples/ios/README.md (renamed from tensorflow/contrib/ios_examples/README.md)86
-rw-r--r--tensorflow/examples/ios/benchmark/AppDelegate.h (renamed from tensorflow/contrib/ios_examples/benchmark/AppDelegate.h)0
-rw-r--r--tensorflow/examples/ios/benchmark/AppDelegate.mm (renamed from tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm)0
-rw-r--r--tensorflow/examples/ios/benchmark/Benchmark-Info.plist (renamed from tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist)6
-rw-r--r--tensorflow/examples/ios/benchmark/BenchmarkViewController.h (renamed from tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h)0
-rw-r--r--tensorflow/examples/ios/benchmark/BenchmarkViewController.mm (renamed from tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm)26
-rw-r--r--tensorflow/examples/ios/benchmark/BenchmarkViewController.xib (renamed from tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib)0
-rw-r--r--tensorflow/examples/ios/benchmark/Podfile5
-rw-r--r--tensorflow/examples/ios/benchmark/data/grace_hopper.jpg (renamed from tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg)bin73746 -> 73746 bytes
-rw-r--r--tensorflow/examples/ios/benchmark/ios_image_load.h (renamed from tensorflow/contrib/ios_examples/benchmark/ios_image_load.h)0
-rw-r--r--tensorflow/examples/ios/benchmark/ios_image_load.mm (renamed from tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm)0
-rw-r--r--tensorflow/examples/ios/benchmark/main.mm (renamed from tensorflow/contrib/ios_examples/benchmark/main.mm)0
-rw-r--r--tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj (renamed from tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj)305
-rw-r--r--tensorflow/examples/ios/camera/CameraExampleAppDelegate.h (renamed from tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h)0
-rw-r--r--tensorflow/examples/ios/camera/CameraExampleAppDelegate.m (renamed from tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m)0
-rw-r--r--tensorflow/examples/ios/camera/CameraExampleViewController.h (renamed from tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h)1
-rw-r--r--tensorflow/examples/ios/camera/CameraExampleViewController.mm (renamed from tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm)31
-rw-r--r--tensorflow/examples/ios/camera/Info.plist (renamed from tensorflow/contrib/ios_examples/camera/Info.plist)2
-rw-r--r--tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard (renamed from tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard)0
-rw-r--r--tensorflow/examples/ios/camera/Podfile5
-rw-r--r--tensorflow/examples/ios/camera/data/grace_hopper.jpg (renamed from tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg)bin73746 -> 73746 bytes
-rw-r--r--tensorflow/examples/ios/camera/ios_image_load.h (renamed from tensorflow/contrib/ios_examples/camera/ios_image_load.h)0
-rw-r--r--tensorflow/examples/ios/camera/ios_image_load.mm (renamed from tensorflow/contrib/ios_examples/camera/ios_image_load.mm)0
-rw-r--r--tensorflow/examples/ios/camera/main.mm (renamed from tensorflow/contrib/ios_examples/camera/main.mm)0
-rw-r--r--tensorflow/examples/ios/camera/tensorflow_utils.h (renamed from tensorflow/contrib/ios_examples/camera/tensorflow_utils.h)0
-rw-r--r--tensorflow/examples/ios/camera/tensorflow_utils.mm (renamed from tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm)14
-rw-r--r--tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj412
-rw-r--r--tensorflow/examples/ios/simple/AppDelegate.h (renamed from tensorflow/contrib/ios_examples/simple/AppDelegate.h)0
-rw-r--r--tensorflow/examples/ios/simple/AppDelegate.mm (renamed from tensorflow/contrib/ios_examples/simple/AppDelegate.mm)0
-rw-r--r--tensorflow/examples/ios/simple/Podfile5
-rw-r--r--tensorflow/examples/ios/simple/RunModel-Info.plist (renamed from tensorflow/contrib/ios_examples/simple/RunModel-Info.plist)6
-rw-r--r--tensorflow/examples/ios/simple/RunModelViewController.h (renamed from tensorflow/contrib/ios_examples/simple/RunModelViewController.h)0
-rw-r--r--tensorflow/examples/ios/simple/RunModelViewController.mm (renamed from tensorflow/contrib/ios_examples/simple/RunModelViewController.mm)20
-rw-r--r--tensorflow/examples/ios/simple/RunModelViewController.xib (renamed from tensorflow/contrib/ios_examples/simple/RunModelViewController.xib)0
-rw-r--r--tensorflow/examples/ios/simple/data/grace_hopper.jpgbin0 -> 73746 bytes
-rw-r--r--tensorflow/examples/ios/simple/ios_image_load.h (renamed from tensorflow/contrib/ios_examples/simple/ios_image_load.h)0
-rw-r--r--tensorflow/examples/ios/simple/ios_image_load.mm (renamed from tensorflow/contrib/ios_examples/simple/ios_image_load.mm)0
-rw-r--r--tensorflow/examples/ios/simple/main.mm (renamed from tensorflow/contrib/ios_examples/simple/main.mm)0
-rw-r--r--tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj (renamed from tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj)197
-rw-r--r--tensorflow/go/op/wrappers.go28
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py31
-rw-r--r--tensorflow/python/estimator/canned/dnn_test.py51
-rw-r--r--tensorflow/python/estimator/canned/head.py139
-rw-r--r--tensorflow/python/estimator/canned/head_test.py149
-rw-r--r--tensorflow/python/estimator/canned/linear_test.py382
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py552
-rw-r--r--tensorflow/python/estimator/estimator.py6
-rw-r--r--tensorflow/python/estimator/estimator_test.py44
-rw-r--r--tensorflow/python/kernel_tests/fft_ops_test.py99
-rw-r--r--tensorflow/python/layers/convolutional.py6
-rw-r--r--tensorflow/python/ops/rnn.py3
-rw-r--r--tensorflow/python/ops/spectral_ops.py52
-rw-r--r--tensorflow/python/training/saver_test_utils.py12
-rw-r--r--tensorflow/tensorboard/BUILD1
-rw-r--r--tensorflow/tensorboard/components/BUILD18
-rw-r--r--tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html1
-rw-r--r--tensorflow/tensorboard/components/tf_trace_viewer/BUILD30
-rw-r--r--tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json105
-rw-r--r--tensorflow/tensorboard/components/tf_trace_viewer/demo.html30
-rw-r--r--tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html127
-rw-r--r--tensorflow/tensorboard/components/trace_viewer.html28
-rw-r--r--tensorflow/tensorboard/components/vz_projector/bundle.html8
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-app.html4
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html2
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html3
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-input.html4
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html5
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html4
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html2
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html2
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector.html3
-rw-r--r--tensorflow/workspace.bzl23
-rw-r--r--third_party/llvm/llvm.BUILD239
140 files changed, 6872 insertions, 3796 deletions
diff --git a/configure b/configure
index c968a1855b..e1aaddabda 100755
--- a/configure
+++ b/configure
@@ -3,6 +3,8 @@
set -e
set -o pipefail
+MIN_BAZEL_VERSION=0.4.5
+
# Find out the absolute path to where ./configure resides
pushd `dirname $0` > /dev/null
SOURCE_BASE_DIR=`pwd -P`
@@ -151,6 +153,22 @@ function setup_python {
echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh
}
+function version {
+ echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }';
+}
+
+
+bazel version > bazel.version
+curr_bazel_version=$(head -n 1 bazel.version | cut -d ' ' -f3)
+rm -f bazel.version
+
+echo "You have bazel $curr_bazel_version installed."
+if [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then
+ echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!"
+ echo "Exiting..."
+ exit 1
+fi
+
# This file contains customized config settings.
rm -f .tf_configure.bazelrc
touch .tf_configure.bazelrc
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index d16f9ccd14..44ce95f9da 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -243,6 +243,7 @@ filegroup(
"//tensorflow/contrib/boosted_trees/resources:all_files",
"//tensorflow/contrib/cloud:all_files",
"//tensorflow/contrib/cloud/kernels:all_files",
+ "//tensorflow/contrib/cluster_resolver:all_files",
"//tensorflow/contrib/compiler:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/crf:all_files",
@@ -392,6 +393,7 @@ filegroup(
"//tensorflow/tensorboard/components/tf_storage/test:all_files",
"//tensorflow/tensorboard/components/tf_tensorboard:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_trace_viewer:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 71f375d048..fbc96685c8 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -454,7 +454,6 @@ cc_library(
":client_session",
":ops",
":scope",
- "//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow",
@@ -479,7 +478,7 @@ cc_binary(
],
deps = [
":cc_ops",
- "//tensorflow/core:all_kernels",
+ "//tensorflow/core:all_kernels", # buildcleaner: keep
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 19f7ff8354..d18e51e32c 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -347,6 +347,25 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "tensor_array_ops_test",
+ size = "small",
+ srcs = ["tensor_array_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:tensor_array_grad",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 2a71543f3f..50ac4a6c25 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -76,6 +76,7 @@ namespace {
// Command line flags: see main() below.
int64 tf_xla_random_seed = 0;
int32 tf_xla_test_repetitions = 20;
+int64 tf_xla_max_tensor_size = 100000LL;
string* tf_xla_test_device_ptr; // initial value set in main()
bool tf_xla_test_use_jit = true;
@@ -96,6 +97,11 @@ class OpTestBuilder {
// Adds an input 'tensor'.
OpTestBuilder& Input(const Tensor& tensor);
+ // Adds a random input tensor with 'type'. If 'dims' is not provided,
+ // RandomDims() is used.
+ OpTestBuilder& RandomInput(DataType type);
+ OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims);
+
// Sets an attribute.
template <class T>
OpTestBuilder& Attr(StringPiece attr_name, T&& value);
@@ -116,11 +122,19 @@ class OpTestBuilder {
std::vector<string>* inputs,
std::vector<string>* outputs) const;
- const std::vector<Tensor>& inputs() const { return inputs_; }
+ struct InputDescription {
+ Tensor tensor;
+
+ DataType type = DT_INVALID;
+ bool has_dims = false;
+ std::vector<int64> dims;
+ };
+
+ const std::vector<InputDescription>& inputs() const { return inputs_; }
private:
NodeDef node_def_;
- std::vector<Tensor> inputs_;
+ std::vector<InputDescription> inputs_;
};
OpTestBuilder::OpTestBuilder(const string& op_name) {
@@ -129,7 +143,28 @@ OpTestBuilder::OpTestBuilder(const string& op_name) {
OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) {
VLOG(1) << "Adding input: " << tensor.DebugString();
- inputs_.push_back(tensor);
+ InputDescription input;
+ input.tensor = tensor;
+ inputs_.push_back(input);
+ return *this;
+}
+
+OpTestBuilder& OpTestBuilder::RandomInput(DataType type) {
+ VLOG(1) << "Adding random input: " << type;
+ InputDescription input;
+ input.type = type;
+ inputs_.push_back(input);
+ return *this;
+}
+
+OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
+ std::vector<int64> dims) {
+ VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
+ InputDescription input;
+ input.type = type;
+ input.has_dims = true;
+ input.dims = std::move(dims);
+ inputs_.push_back(input);
return *this;
}
@@ -207,16 +242,30 @@ class OpTest : public ::testing::Test {
public:
OpTest();
- // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs;
- // whichever happens first.
- void Repeatedly(const std::function<void(void)>& fn);
+ enum TestResult {
+ // The test saw an unrecoverable error. Don't try any more runs.
+ kFatalError,
+ // The parameters of the test were invalid (e.g., the "golden"
+ // implementation failed, or the parameters are oversize). Reruns are ok.
+ kInvalid,
+ // The test ran successfully, and we have a verdict. Does *not* mean the
+ // test passed.
+ kOk,
+ };
+
+ // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure
+ // occurs; whichever happens first. Reruns if the TestResult is kInvalid.
+ void Repeatedly(const std::function<TestResult(void)>& fn);
// Select a random element from 'candidates'.
template <typename T>
T Choose(gtl::ArraySlice<T> candidates);
static constexpr int kDefaultMaxRank = 5;
- static constexpr int64 kDefaultMaxDimensionSize = 20LL;
+ static constexpr int64 kDefaultMaxDimensionSize = 256LL;
+
+ // Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
+ bool TensorSizeIsOk(gtl::ArraySlice<int64> dims);
// Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
@@ -278,8 +327,9 @@ class OpTest : public ::testing::Test {
// element-wise difference between x and y must no more than
// atol + rtol * abs(x); or both elements may be NaN or infinity. For
// non-floating-point tensors the element values must match exactly.
- void ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
- double atol = 1e-2, double rtol = 1e-2);
+ TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
+ double atol = 1e-2,
+ double rtol = 1e-2);
protected:
// Per-test state:
@@ -315,10 +365,35 @@ OpTest::OpTest() {
TF_CHECK_OK(session_->Create(def));
}
-void OpTest::Repeatedly(const std::function<void(void)>& fn) {
+void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
int const max_repetitions = tf_xla_test_repetitions;
- for (int i = 0; !HasFailure() && i < max_repetitions; ++i) {
- fn();
+ int valid_test_runs = 0;
+ // We run up to 10 * max_repetitions times; the idea is that if we roll the
+ // dice enough times we will find some valid parameters. We want to put an
+ // upper limit on the number iterations just in case the probability of
+ // finding feasible parameters is very low.
+ for (int i = 0; !HasFailure() && i < max_repetitions * 10 &&
+ valid_test_runs < max_repetitions;
+ ++i) {
+ TestResult result = fn();
+ switch (result) {
+ case kOk:
+ ++valid_test_runs;
+ break;
+
+ case kFatalError:
+ ASSERT_TRUE(false) << "Test had fatal failure";
+ return;
+
+ case kInvalid:
+ break;
+ }
+ }
+ if (!HasFailure()) {
+ EXPECT_GE(valid_test_runs, max_repetitions)
+ << "Not enough test instances passed; this means that either the "
+ "golden implementation is buggy or the operator harness is not "
+ "producing well-formed test cases with a high probability.";
}
}
@@ -333,6 +408,14 @@ int64 OpTest::RandomDim(int64 min, int64 max) {
return size_distribution(generator());
}
+bool OpTest::TensorSizeIsOk(gtl::ArraySlice<int64> dims) {
+ int64 size = 1LL;
+ for (int64 dim : dims) {
+ size *= dim;
+ }
+ return size < tf_xla_max_tensor_size;
+}
+
std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
int64 min_size, int64 max_size) {
CHECK_LE(0, min_rank);
@@ -340,9 +423,13 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
std::uniform_int_distribution<int> rank_distribution(min_rank, max_rank);
int rank = rank_distribution(generator());
std::vector<int64> dims(rank);
- std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() {
- return RandomDim(min_size, max_size);
- });
+ // TODO(phawkins): too small a maximum tensor size could lead to an infinite
+ // loop here.
+ do {
+ std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() {
+ return RandomDim(min_size, max_size);
+ });
+ } while (!TensorSizeIsOk(dims));
return dims;
}
@@ -606,53 +693,84 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
}
}
-void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
- double atol, double rtol) {
+OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
+ const OpTestBuilder& builder, double atol, double rtol) {
+ const std::vector<OpTestBuilder::InputDescription>& inputs = builder.inputs();
+ std::vector<Tensor> input_tensors;
+ input_tensors.reserve(inputs.size());
+ for (const OpTestBuilder::InputDescription& input : inputs) {
+ if (input.type == DT_INVALID) {
+ VLOG(1) << "Input: " << input.tensor.DebugString();
+ input_tensors.push_back(input.tensor);
+ } else {
+ VLOG(1) << "Input: " << input.type << " "
+ << TensorShape(input.dims).DebugString();
+ std::vector<int64> dims;
+ if (input.has_dims) {
+ dims = input.dims;
+ } else {
+ dims = RandomDims();
+ }
+ if (!TensorSizeIsOk(dims)) {
+ VLOG(1) << "Ignoring oversize dims.";
+ return kInvalid;
+ }
+ input_tensors.push_back(RandomTensor(input.type, dims));
+ }
+ }
+
string cpu_device =
LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
DeviceNameUtils::ParsedName parsed_name;
- ASSERT_TRUE(
- DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name));
+ if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) {
+ LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr;
+ return kFatalError;
+ }
DeviceType test_device_type(parsed_name.type);
++num_tests_;
GraphDef graph;
std::vector<string> expected_inputs, test_inputs;
std::vector<string> expected_fetches, test_fetches;
- TF_ASSERT_OK(builder.BuildGraph(
+ Status status = builder.BuildGraph(
strings::StrCat("test", num_tests_, "_expected"), cpu_device,
/* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
- &expected_inputs, &expected_fetches));
+ &expected_inputs, &expected_fetches);
+ if (!status.ok()) {
+ LOG(ERROR) << "Expected graph construction failed: " << status;
+ return kFatalError;
+ }
NodeDef* node_def;
- TF_ASSERT_OK(builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
- test_device, tf_xla_test_use_jit, &graph,
- &node_def, &test_inputs, &test_fetches));
+ status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
+ test_device, tf_xla_test_use_jit, &graph,
+ &node_def, &test_inputs, &test_fetches);
+ if (!status.ok()) {
+ LOG(ERROR) << "Test graph construction failed: " << status;
+ return kFatalError;
+ }
// Check that there's a kernel corresponding to 'node_def' on the device under
// test.
- Status status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
+ status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
if (!status.ok()) {
VLOG(1) << "Skipping test because there is no corresponding registered "
<< "kernel on the test device: " << status;
- return;
+ return kInvalid;
}
- TF_ASSERT_OK(session_->Extend(graph));
-
- const std::vector<Tensor>& input_tensors = builder.inputs();
- if (VLOG_IS_ON(1)) {
- for (const Tensor& input : input_tensors) {
- VLOG(1) << "Input: " << input.DebugString();
- }
+ status = session_->Extend(graph);
+ if (!status.ok()) {
+ LOG(ERROR) << "Session::Extend() failed: " << status;
+ return kFatalError;
}
std::vector<std::pair<string, Tensor>> expected_feeds(expected_inputs.size());
std::vector<std::pair<string, Tensor>> test_feeds(test_inputs.size());
- ASSERT_EQ(input_tensors.size(), expected_inputs.size());
- ASSERT_EQ(input_tensors.size(), test_inputs.size());
+ CHECK_EQ(input_tensors.size(), expected_inputs.size());
+ CHECK_EQ(input_tensors.size(), test_inputs.size());
for (int i = 0; i < input_tensors.size(); ++i) {
expected_feeds[i] = {expected_inputs[i], input_tensors[i]};
@@ -664,21 +782,27 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
Status s =
session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs);
if (!s.ok()) {
- VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
- return;
+ VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test";
+ return kInvalid;
}
for (const Tensor& expected : expected_outputs) {
VLOG(1) << "Expected: " << expected.DebugString();
}
VLOG(1) << "Running test graph";
- TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
+ status = session_->Run(test_feeds, test_fetches, {}, &test_outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Test graph failed: " << status;
+ return kFatalError;
+ }
- ASSERT_EQ(expected_outputs.size(), test_outputs.size());
+ CHECK_EQ(expected_outputs.size(), test_outputs.size());
for (int j = 0; s.ok() && j < test_outputs.size(); ++j) {
s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol);
}
TF_EXPECT_OK(s);
+
+ return kOk;
}
// Helper that converts 'values' to an int32 or int64 Tensor.
@@ -698,8 +822,8 @@ Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
TEST_F(OpTest, Abs) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Abs").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
});
}
@@ -707,10 +831,10 @@ TEST_F(OpTest, Add) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -725,49 +849,50 @@ TEST_F(OpTest, AddN) {
builder.Attr("T", type);
builder.Attr("N", n);
for (int i = 0; i < n; ++i) {
- builder.Input(RandomTensor(type, shape));
+ builder.RandomInput(type, shape);
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
TEST_F(OpTest, All) {
Repeatedly([this]() {
- Tensor data = RandomTensor(DT_BOOL);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("All").Input(data).Input(indices).Attr("keep_dims",
- keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("All")
+ .RandomInput(DT_BOOL, data_dims)
+ .Input(indices)
+ .Attr("keep_dims", keep_dims));
});
}
TEST_F(OpTest, Any) {
Repeatedly([this]() {
- Tensor data = RandomTensor(DT_BOOL);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Any").Input(data).Input(indices).Attr("keep_dims",
- keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Any")
+ .RandomInput(DT_BOOL, data_dims)
+ .Input(indices)
+ .Attr("keep_dims", keep_dims));
});
}
TEST_F(OpTest, AvgPool) {
Repeatedly([this]() {
- WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
std::uniform_int_distribution<int> random_int(1, 5);
-
- int kernel_rows = random_int(generator()),
- kernel_cols = random_int(generator());
+ std::vector<int64> dims = RandomDims(4, 4, 1);
+ int kernel_rows =
+ std::uniform_int_distribution<int>(1, dims[1])(generator());
+ int kernel_cols =
+ std::uniform_int_distribution<int>(1, dims[2])(generator());
int stride_rows = random_int(generator()),
stride_cols = random_int(generator());
string padding = Choose<string>({"SAME", "VALID"});
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("AvgPool")
- .Input(
- RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows),
- RandomDim(kernel_cols), RandomDim(1)}))
+ .RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT)
.Attr("ksize", {1, kernel_rows, kernel_cols, 1})
.Attr("strides", {1, stride_rows, stride_cols, 1})
@@ -781,23 +906,28 @@ TEST_F(OpTest, AvgPool) {
TEST_F(OpTest, AvgPool3D) {
Repeatedly([this]() {
std::uniform_int_distribution<int> random_int(1, 5);
+ std::vector<int64> dims = RandomDims(5, 5, 1);
+
std::vector<int64> input_dims, kernel_dims, stride_dims;
for (int i = 0; i < 3; ++i) {
- kernel_dims.push_back(random_int(generator()));
- input_dims.push_back(RandomDim(kernel_dims.back()));
+ kernel_dims.push_back(
+ std::uniform_int_distribution<int>(1, dims[i])(generator()));
+ input_dims.push_back(dims[i]);
stride_dims.push_back(random_int(generator()));
}
+ int64 batch = dims[3];
+ int64 feature = dims[4];
string padding = Choose<string>({"SAME", "VALID"});
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("AvgPool3D")
- .Input(RandomTensor(DT_FLOAT, ImageDims(FORMAT_NHWC, RandomDim(1),
- RandomDim(1), input_dims)))
+ .RandomInput(DT_FLOAT,
+ ImageDims(FORMAT_NHWC, batch, feature, input_dims))
.Attr("T", DT_FLOAT)
.Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, kernel_dims))
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, stride_dims))
.Attr("padding", padding)
- .Attr("data_format", "NHWC"));
+ .Attr("data_format", "NDHWC"));
});
// TODO(phawkins): test NCHW format (not supported by CPU)
}
@@ -810,15 +940,15 @@ TEST_F(OpTest, AvgPoolGrad) {
AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims));
std::vector<int64> output_dims =
ImageDims(FORMAT_NHWC, batch, features, d.output_dims);
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("AvgPoolGrad")
.Input(test::AsTensor<int32>(input_dims))
- .Input(RandomTensor(DT_FLOAT, output_dims))
+ .RandomInput(DT_FLOAT, output_dims)
.Attr("T", DT_FLOAT)
.Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims))
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
- .Attr("data_format", "NHWC"));
+ .Attr("data_format", "NDHWC"));
});
}
@@ -830,15 +960,15 @@ TEST_F(OpTest, AvgPool3DGrad) {
AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims));
std::vector<int64> output_dims =
ImageDims(FORMAT_NHWC, batch, features, d.output_dims);
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("AvgPool3DGrad")
.Input(test::AsTensor<int32>(input_dims))
- .Input(RandomTensor(DT_FLOAT, output_dims))
+ .RandomInput(DT_FLOAT, output_dims)
.Attr("T", DT_FLOAT)
.Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims))
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
- .Attr("data_format", "NHWC"));
+ .Attr("data_format", "NDHWC"));
});
}
@@ -850,32 +980,23 @@ TEST_F(OpTest, BatchMatMul) {
std::vector<int64> x_dims(output_dims), y_dims(output_dims);
x_dims[ndims - 1] = inner_dim;
y_dims[ndims - 2] = inner_dim;
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
- .Input(RandomTensor(DT_FLOAT, x_dims))
- .Input(RandomTensor(DT_FLOAT, y_dims))
- .Attr("T", DT_FLOAT));
-
- std::swap(x_dims[ndims - 1], x_dims[ndims - 2]);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
- .Input(RandomTensor(DT_FLOAT, x_dims))
- .Input(RandomTensor(DT_FLOAT, y_dims))
- .Attr("T", DT_FLOAT)
- .Attr("adj_x", true));
-
- std::swap(y_dims[ndims - 1], y_dims[ndims - 2]);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
- .Input(RandomTensor(DT_FLOAT, x_dims))
- .Input(RandomTensor(DT_FLOAT, y_dims))
- .Attr("T", DT_FLOAT)
- .Attr("adj_x", true)
- .Attr("adj_y", true));
-
- std::swap(x_dims[ndims - 1], x_dims[ndims - 2]);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
- .Input(RandomTensor(DT_FLOAT, x_dims))
- .Input(RandomTensor(DT_FLOAT, y_dims))
- .Attr("T", DT_FLOAT)
- .Attr("adj_y", true));
+
+ std::bernoulli_distribution random_bool;
+ bool adj_x = random_bool(generator());
+ bool adj_y = random_bool(generator());
+ if (adj_x) {
+ std::swap(x_dims[ndims - 1], x_dims[ndims - 2]);
+ }
+ if (adj_y) {
+ std::swap(y_dims[ndims - 1], y_dims[ndims - 2]);
+ }
+
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
+ .RandomInput(DT_FLOAT, x_dims)
+ .RandomInput(DT_FLOAT, y_dims)
+ .Attr("T", DT_FLOAT)
+ .Attr("adj_x", adj_x)
+ .Attr("adj_y", adj_y));
});
}
@@ -905,11 +1026,11 @@ TEST_F(OpTest, BatchToSpace) {
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
- .Input(RandomTensor(DT_FLOAT, input_dims))
- .Input(crops)
- .Attr("T", DT_FLOAT)
- .Attr("block_size", block_size));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
+ .RandomInput(DT_FLOAT, input_dims)
+ .Input(crops)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
});
}
@@ -942,9 +1063,9 @@ TEST_F(OpTest, BatchToSpaceND) {
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("BatchToSpaceND")
- .Input(RandomTensor(DT_FLOAT, input_dims))
+ .RandomInput(DT_FLOAT, input_dims)
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(crops)
@@ -954,29 +1075,32 @@ TEST_F(OpTest, BatchToSpaceND) {
TEST_F(OpTest, BiasAdd) {
Repeatedly([this]() {
- auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
- auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)});
+ auto x_dims = RandomDims(2, kDefaultMaxRank);
+ auto y_dims = {x_dims[x_dims.size() - 1]};
// TODO(phawkins): test both data formats.
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BiasAdd").Input(x).Input(y).Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
+ .RandomInput(DT_FLOAT, x_dims)
+ .RandomInput(DT_FLOAT, y_dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, BiasAddGrad) {
Repeatedly([this]() {
- auto x = RandomTensor(DT_FLOAT);
// TODO(phawkins): test both data formats.
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BiasAddGrad").Input(x).Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, BiasAddV1) {
Repeatedly([this]() {
- auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
- auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("BiasAddV1").Input(x).Input(y).Attr("T", DT_FLOAT));
+ auto x_dims = RandomDims(2, kDefaultMaxRank);
+ auto y_dims = {x_dims[x_dims.size() - 1]};
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
+ .RandomInput(DT_FLOAT, x_dims)
+ .RandomInput(DT_FLOAT, y_dims)
+ .Attr("T", DT_FLOAT));
});
}
@@ -986,10 +1110,11 @@ TEST_F(OpTest, BroadcastGradientArgs) {
// DataType type = Choose<DataType>({DT_INT32, DT_INT64});
DataType type = DT_INT32;
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BroadcastGradientArgs")
- .Input(AsIntTensor(type, dims.first))
- .Input(AsIntTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BroadcastGradientArgs")
+ .Input(AsIntTensor(type, dims.first))
+ .Input(AsIntTensor(type, dims.second))
+ .Attr("T", type));
});
}
@@ -998,18 +1123,17 @@ TEST_F(OpTest, Cast) {
DataType src_type, dst_type;
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
- .Input(RandomTensor(src_type))
- .Attr("SrcT", src_type)
- .Attr("DstT", dst_type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
+ .RandomInput(src_type)
+ .Attr("SrcT", src_type)
+ .Attr("DstT", dst_type));
});
}
TEST_F(OpTest, Ceil) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Ceil")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Ceil").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
@@ -1029,9 +1153,9 @@ TEST_F(OpTest, Concat) {
for (int i = 0; i < n; ++i) {
std::vector<int64> shape = dims;
shape[concat_dim] = RandomDim();
- builder.Input(RandomTensor(type, shape));
+ builder.RandomInput(type, shape);
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
@@ -1051,7 +1175,7 @@ TEST_F(OpTest, ConcatOffset) {
shape[concat_dim] = RandomDim();
builder.Input(test::AsTensor<int32>(shape));
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
@@ -1064,15 +1188,15 @@ TEST_F(OpTest, Conv2D) {
int64 batch = RandomDim();
- Tensor data = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims));
+ std::vector<int64> data_dims =
+ ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
- Tensor kernel = RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1],
- features_in, features_out});
- ExpectTfAndXlaOutputsAreClose(
+ std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
+ features_in, features_out};
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2D")
- .Input(data)
- .Input(kernel)
+ .RandomInput(DT_FLOAT, data_dims)
+ .RandomInput(DT_FLOAT, kernel_dims)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
@@ -1087,17 +1211,17 @@ TEST_F(OpTest, Conv2DBackpropFilter) {
int features_in = random_int(generator());
int features_out = random_int(generator());
int32 batch = RandomDim();
- Tensor activations = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims));
- Tensor backprop = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims));
+ std::vector<int64> activations =
+ ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
+ std::vector<int64> backprop =
+ ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
{d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2DBackpropFilter")
- .Input(activations)
+ .RandomInput(DT_FLOAT, activations)
.Input(kernel_shape)
- .Input(backprop)
+ .RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
@@ -1114,15 +1238,15 @@ TEST_F(OpTest, Conv2DBackpropInput) {
int32 batch = RandomDim();
Tensor in_shape = test::AsTensor<int32>(
AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
- Tensor backprop = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims));
- Tensor kernel = RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1],
- features_in, features_out});
- ExpectTfAndXlaOutputsAreClose(
+ std::vector<int64> backprop =
+ ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
+ std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
+ features_in, features_out};
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2DBackpropInput")
.Input(in_shape)
- .Input(kernel)
- .Input(backprop)
+ .RandomInput(DT_FLOAT, kernel)
+ .RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
@@ -1136,17 +1260,15 @@ TEST_F(OpTest, Conv3D) {
std::uniform_int_distribution<int> random_int(1, 5);
int features_in = random_int(generator());
int features_out = random_int(generator());
- Tensor data =
- RandomTensor(DT_FLOAT, {RandomDim(), d.input_dims[0], d.input_dims[1],
- d.input_dims[2], features_in});
-
- Tensor kernel =
- RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1],
- d.kernel_dims[2], features_in, features_out});
- ExpectTfAndXlaOutputsAreClose(
+ std::vector<int64> data = {RandomDim(), d.input_dims[0], d.input_dims[1],
+ d.input_dims[2], features_in};
+
+ std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
+ d.kernel_dims[2], features_in, features_out};
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3D")
- .Input(data)
- .Input(kernel)
+ .RandomInput(DT_FLOAT, data)
+ .RandomInput(DT_FLOAT, kernel)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
@@ -1160,18 +1282,18 @@ TEST_F(OpTest, Conv3DBackpropFilter) {
int features_in = random_int(generator());
int features_out = random_int(generator());
int32 batch = RandomDim(1);
- Tensor activations = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims));
- Tensor backprop = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims));
+ std::vector<int64> activations =
+ ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
+ std::vector<int64> backprop =
+ ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
Tensor kernel_shape = test::AsTensor<int32>(
AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
features_in, features_out}));
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3DBackpropFilterV2")
- .Input(activations)
+ .RandomInput(DT_FLOAT, activations)
.Input(kernel_shape)
- .Input(backprop)
+ .RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
@@ -1187,16 +1309,15 @@ TEST_F(OpTest, Conv3DBackpropInput) {
int32 batch = RandomDim(1);
Tensor in_shape = test::AsTensor<int32>(
AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
- Tensor backprop = RandomTensor(
- DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims));
- Tensor kernel =
- RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1],
- d.kernel_dims[2], features_in, features_out});
- ExpectTfAndXlaOutputsAreClose(
+ std::vector<int64> backprop =
+ ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
+ std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
+ d.kernel_dims[2], features_in, features_out};
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3DBackpropInputV2")
.Input(in_shape)
- .Input(kernel)
- .Input(backprop)
+ .RandomInput(DT_FLOAT, kernel)
+ .RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
@@ -1206,9 +1327,15 @@ TEST_F(OpTest, Conv3DBackpropInput) {
TEST_F(OpTest, Diag) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Diag")
- .Input(RandomTensor(type, RandomDims(1)))
- .Attr("T", type));
+ std::vector<int64> dims;
+ // Diag causes a quadratic blowup in output size.
+ int64 size;
+ do {
+ dims = RandomDims(1);
+ size = TensorShape(dims).num_elements();
+ } while (size * size < tf_xla_max_tensor_size);
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
});
}
@@ -1220,9 +1347,9 @@ TEST_F(OpTest, DiagPart) {
std::vector<int64> doubled_dims(dims.size() * 2);
std::copy(dims.begin(), dims.end(), doubled_dims.begin());
std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size());
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart")
- .Input(RandomTensor(type, doubled_dims))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart")
+ .RandomInput(type, doubled_dims)
+ .Attr("T", type));
});
}
@@ -1230,10 +1357,10 @@ TEST_F(OpTest, Div) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1282,27 +1409,26 @@ TEST_F(OpTest, DynamicStitch) {
std::vector<int64> dims(index_dims[i].begin(), index_dims[i].end());
std::copy(constant_dims.begin(), constant_dims.end(),
std::back_inserter(dims));
- Tensor t = RandomTensor(type, dims);
- builder.Input(t);
+ builder.RandomInput(type, dims);
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
TEST_F(OpTest, Elu) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Elu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, EluGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
@@ -1310,50 +1436,51 @@ TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, Exp) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Exp").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, ExpandDims) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor in = RandomTensor(type);
+ std::vector<int64> in_dims = RandomDims();
Tensor dim(DT_INT32, TensorShape());
- std::uniform_int_distribution<int32> d(-1 - in.dims(), in.dims());
+ std::uniform_int_distribution<int32> d(-1 - in_dims.size(), in_dims.size());
dim.scalar<int32>()() = d(generator());
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("ExpandDims").Input(in).Input(dim).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims")
+ .RandomInput(type, in_dims)
+ .Input(dim)
+ .Attr("T", type));
});
}
TEST_F(OpTest, Fill) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor scalar = RandomTensor(type, {});
std::vector<int64> dims = RandomDims();
std::vector<int32> shape(dims.begin(), dims.end());
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Fill")
- .Input(test::AsTensor<int32>(shape))
- .Input(scalar)
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Fill")
+ .Input(test::AsTensor<int32>(shape))
+ .RandomInput(type, {})
+ .Attr("T", type));
});
}
TEST_F(OpTest, Floor) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Floor")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Floor").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
@@ -1361,10 +1488,10 @@ TEST_F(OpTest, FloorDiv) {
Repeatedly([this]() {
DataType type = DT_INT32;
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1372,10 +1499,10 @@ TEST_F(OpTest, FloorMod) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1383,10 +1510,10 @@ TEST_F(OpTest, Greater) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1394,18 +1521,10 @@ TEST_F(OpTest, GreaterEqual) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
- });
-}
-
-TEST_F(OpTest, Reciprocal) {
- Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reciprocal")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1413,9 +1532,9 @@ TEST_F(OpTest, L2Loss) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
// TODO(b/31644876): scalars currently crash.
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
- .Input(RandomTensor(type, RandomDims(1)))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
+ .RandomInput(type, RandomDims(1))
+ .Attr("T", type));
});
}
@@ -1423,10 +1542,10 @@ TEST_F(OpTest, Less) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1434,10 +1553,10 @@ TEST_F(OpTest, LessEqual) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1449,10 +1568,10 @@ TEST_F(OpTest, LinSpace) {
};
std::uniform_int_distribution<int> distribution(-50, 50);
DataType type = Choose<DataType>({DT_INT32, DT_INT64});
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("LinSpace")
- .Input(RandomTensor(DT_FLOAT, {}))
- .Input(RandomTensor(DT_FLOAT, {}))
+ .RandomInput(DT_FLOAT, {})
+ .RandomInput(DT_FLOAT, {})
.Input(ToScalar(type, distribution(generator())))
.Attr("T", DT_FLOAT)
.Attr("Tidx", type));
@@ -1461,62 +1580,62 @@ TEST_F(OpTest, LinSpace) {
TEST_F(OpTest, Log) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Log").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, LogicalAnd) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("LogicalAnd")
- .Input(RandomTensor(DT_BOOL, dims.first))
- .Input(RandomTensor(DT_BOOL, dims.second)));
+ .RandomInput(DT_BOOL, dims.first)
+ .RandomInput(DT_BOOL, dims.second));
});
}
TEST_F(OpTest, LogicalNot) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("LogicalNot").Input(RandomTensor(DT_BOOL)));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LogicalNot").RandomInput(DT_BOOL));
});
}
TEST_F(OpTest, LogicalOr) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("LogicalOr")
- .Input(RandomTensor(DT_BOOL, dims.first))
- .Input(RandomTensor(DT_BOOL, dims.second)));
+ .RandomInput(DT_BOOL, dims.first)
+ .RandomInput(DT_BOOL, dims.second));
});
}
TEST_F(OpTest, LogSoftmax) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("LogSoftmax")
- .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2)))
+ .RandomInput(DT_FLOAT, RandomDims(2, 2))
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, LRN) {
Repeatedly([this]() {
- Tensor data;
// TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
- data = RandomTensor(DT_FLOAT, RandomDims(4, 4, 1, 8));
+ std::vector<int64> data_dims = RandomDims(4, 4, 1, 8);
// CuDNN requires depth_radius > 0.
- std::uniform_int_distribution<int> radius(1, data.dim_size(3));
+ std::uniform_int_distribution<int> radius(1, data_dims[3]);
std::uniform_real_distribution<float> coeff(0.01, 2.0);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRN")
- .Input(data)
- .Attr("T", DT_FLOAT)
- .Attr("depth_radius", radius(generator()))
- .Attr("bias", coeff(generator()))
- .Attr("alpha", coeff(generator()))
- .Attr("beta", coeff(generator())));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LRN")
+ .RandomInput(DT_FLOAT, data_dims)
+ .Attr("T", DT_FLOAT)
+ .Attr("depth_radius", radius(generator()))
+ .Attr("bias", coeff(generator()))
+ .Attr("alpha", coeff(generator()))
+ .Attr("beta", coeff(generator())));
});
}
@@ -1524,21 +1643,19 @@ TEST_F(OpTest, LRNGrad) {
Repeatedly([this]() {
// TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
std::vector<int64> dims = RandomDims(4, 4, 1, 8);
- Tensor input_grads = RandomTensor(DT_FLOAT, dims);
- Tensor input_image = RandomTensor(DT_FLOAT, dims);
- Tensor output_image = RandomTensor(DT_FLOAT, dims);
// CuDNN requires depth_radius > 0.
- std::uniform_int_distribution<int> radius(1, input_grads.dim_size(3));
+ std::uniform_int_distribution<int> radius(1, dims[3]);
std::uniform_real_distribution<float> coeff(0.0, 2.0);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRNGrad")
- .Input(input_grads)
- .Input(input_image)
- .Input(output_image)
- .Attr("T", DT_FLOAT)
- .Attr("depth_radius", radius(generator()))
- .Attr("bias", coeff(generator()))
- .Attr("alpha", coeff(generator()))
- .Attr("beta", coeff(generator())));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LRNGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT)
+ .Attr("depth_radius", radius(generator()))
+ .Attr("bias", coeff(generator()))
+ .Attr("alpha", coeff(generator()))
+ .Attr("beta", coeff(generator())));
});
}
@@ -1548,59 +1665,57 @@ TEST_F(OpTest, MatMul) {
int64 y = RandomDim();
int64 z = RandomDim();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
- .Input(RandomTensor(DT_FLOAT, {x, y}))
- .Input(RandomTensor(DT_FLOAT, {y, z}))
- .Attr("T", DT_FLOAT));
-
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
- .Input(RandomTensor(DT_FLOAT, {y, x}))
- .Input(RandomTensor(DT_FLOAT, {y, z}))
- .Attr("T", DT_FLOAT)
- .Attr("transpose_a", true));
+ std::vector<int64> a_dims = {x, y};
+ std::vector<int64> b_dims = {y, z};
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
- .Input(RandomTensor(DT_FLOAT, {x, y}))
- .Input(RandomTensor(DT_FLOAT, {z, y}))
- .Attr("T", DT_FLOAT)
- .Attr("transpose_b", true));
+ std::bernoulli_distribution random_bool;
+ bool transpose_a = random_bool(generator());
+ bool transpose_b = random_bool(generator());
+ if (transpose_a) {
+ std::swap(a_dims[0], a_dims[1]);
+ }
+ if (transpose_b) {
+ std::swap(b_dims[0], b_dims[1]);
+ }
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
- .Input(RandomTensor(DT_FLOAT, {y, x}))
- .Input(RandomTensor(DT_FLOAT, {z, y}))
- .Attr("T", DT_FLOAT)
- .Attr("transpose_a", true)
- .Attr("transpose_b", true));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
+ .RandomInput(DT_FLOAT, a_dims)
+ .RandomInput(DT_FLOAT, b_dims)
+ .Attr("T", DT_FLOAT)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b));
});
}
TEST_F(OpTest, MatrixDiag) {
Repeatedly([this]() {
- DataType type = Choose<DataType>({DT_BOOL, DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
- .Input(RandomTensor(type, RandomDims(1)))
- .Attr("T", type));
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
+ .RandomInput(type, RandomDims(1))
+ .Attr("T", type));
});
}
TEST_F(OpTest, MatrixDiagPart) {
Repeatedly([this]() {
- DataType type = Choose<DataType>({DT_BOOL, DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
- .Input(RandomTensor(type, RandomDims(2)))
- .Attr("T", type));
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
+ .RandomInput(type, RandomDims(2))
+ .Attr("T", type));
});
}
TEST_F(OpTest, Max) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- Tensor data = RandomTensor(type);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Max").Input(data).Input(indices).Attr("T", type).Attr(
- "keep_dims", keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Max")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", type)
+ .Attr("keep_dims", keep_dims));
});
}
@@ -1608,26 +1723,28 @@ TEST_F(OpTest, Maximum) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, MaxPool) {
Repeatedly([this]() {
std::uniform_int_distribution<int> random_int(1, 5);
- int kernel_rows = random_int(generator()),
- kernel_cols = random_int(generator());
+ std::vector<int64> dims = RandomDims(4, 4, 1);
+ int kernel_rows =
+ std::uniform_int_distribution<int>(1, dims[1])(generator());
+ int kernel_cols =
+ std::uniform_int_distribution<int>(1, dims[2])(generator());
int stride_rows = random_int(generator()),
stride_cols = random_int(generator());
+
string padding = Choose<string>({"SAME", "VALID"});
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("MaxPool")
- .Input(
- RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows),
- RandomDim(kernel_cols), RandomDim(1)}))
+ .RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT)
.Attr("ksize", {1, kernel_rows, kernel_cols, 1})
.Attr("strides", {1, stride_rows, stride_cols, 1})
@@ -1640,28 +1757,32 @@ TEST_F(OpTest, MaxPool) {
TEST_F(OpTest, MaxPool3D) {
Repeatedly([this]() {
std::uniform_int_distribution<int> random_int(1, 5);
- std::vector<int64> input_dims;
- std::vector<int32> kernel_dims, stride_dims;
- input_dims.push_back(RandomDim(1));
+ std::vector<int64> dims = RandomDims(5, 5, 1);
+
+ std::vector<int64> input_dims, kernel_dims, stride_dims;
kernel_dims.push_back(1);
stride_dims.push_back(1);
for (int i = 0; i < 3; ++i) {
- kernel_dims.push_back(random_int(generator()));
- input_dims.push_back(RandomDim(kernel_dims.back()));
+ kernel_dims.push_back(
+ std::uniform_int_distribution<int>(1, dims[i])(generator()));
+ input_dims.push_back(dims[i]);
stride_dims.push_back(random_int(generator()));
}
- input_dims.push_back(RandomDim(1));
kernel_dims.push_back(1);
stride_dims.push_back(1);
+ int64 batch = dims[3];
+ int64 feature = dims[4];
string padding = Choose<string>({"SAME", "VALID"});
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MaxPool3D")
- .Input(RandomTensor(DT_FLOAT, input_dims))
- .Attr("T", DT_FLOAT)
- .Attr("ksize", kernel_dims)
- .Attr("strides", stride_dims)
- .Attr("padding", padding)
- .Attr("data_format", "NHWC"));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("MaxPool3D")
+ .RandomInput(DT_FLOAT,
+ ImageDims(FORMAT_NHWC, batch, feature, input_dims))
+ .Attr("T", DT_FLOAT)
+ .Attr("ksize", kernel_dims)
+ .Attr("strides", stride_dims)
+ .Attr("padding", padding)
+ .Attr("data_format", "NDHWC"));
});
// TODO(phawkins): test NCHW format (not supported by CPU)
}
@@ -1671,24 +1792,28 @@ TEST_F(OpTest, Mean) {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
// TODO(phawkins): CPU and XLA differ output for reducing across a
// size-0 dimension (nan vs 0). For now, require size >= 1.
- Tensor data = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 1));
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims(0, kDefaultMaxRank, 1);
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Mean").Input(data).Input(indices).Attr("T", type).Attr(
- "keep_dims", keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mean")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", type)
+ .Attr("keep_dims", keep_dims));
});
}
TEST_F(OpTest, Min) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- Tensor data = RandomTensor(type);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Min").Input(data).Input(indices).Attr("T", type).Attr(
- "keep_dims", keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Min")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", type)
+ .Attr("keep_dims", keep_dims));
});
}
@@ -1696,21 +1821,20 @@ TEST_F(OpTest, Minimum) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, Mod) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Mod")
- .Input(RandomTensor(DT_INT32, dims.first))
- .Input(RandomTensor(DT_INT32, dims.second))
- .Attr("T", DT_INT32));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mod")
+ .RandomInput(DT_INT32, dims.first)
+ .RandomInput(DT_INT32, dims.second)
+ .Attr("T", DT_INT32));
});
}
@@ -1718,18 +1842,18 @@ TEST_F(OpTest, Mul) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, Neg) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Neg").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
});
}
@@ -1737,10 +1861,10 @@ TEST_F(OpTest, NotEqual) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -1768,9 +1892,17 @@ TEST_F(OpTest, OneHot) {
builder.Attr("axis", axis);
builder.Input(indices);
builder.Input(test::AsScalar<int32>(depth));
- builder.Input(RandomTensor(type, {}));
- builder.Input(RandomTensor(type, {}));
- ExpectTfAndXlaOutputsAreClose(builder);
+ builder.RandomInput(type, {});
+ builder.RandomInput(type, {});
+ return ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, OnesLike) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
});
}
@@ -1789,9 +1921,9 @@ TEST_F(OpTest, Pack) {
builder.Attr("N", n);
builder.Attr("axis", axis);
for (int i = 0; i < n; ++i) {
- builder.Input(RandomTensor(type, dims));
+ builder.RandomInput(type, dims);
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
@@ -1799,23 +1931,26 @@ TEST_F(OpTest, Pack) {
TEST_F(OpTest, Pad) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor t = RandomTensor(type);
+ std::vector<int64> t_dims = RandomDims();
// TODO(b/31741996): re-enable DT_INT64 when bug is fixed.
// DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
DataType tpaddings = DT_INT32;
std::vector<int64> paddings_vec;
std::uniform_int_distribution<int> distribution(0, 7);
- for (int i = 0; i < t.dims(); ++i) {
+ for (int i = 0; i < t_dims.size(); ++i) {
paddings_vec.push_back(distribution(generator()));
paddings_vec.push_back(distribution(generator()));
}
Tensor paddings;
- CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec),
- TensorShape({t.dims(), 2})));
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Pad").Input(t).Input(paddings).Attr("T", type).Attr(
- "Tpaddings", tpaddings));
+ CHECK(
+ paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec),
+ TensorShape({static_cast<int64>(t_dims.size()), 2})));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pad")
+ .RandomInput(type, t_dims)
+ .Input(paddings)
+ .Attr("T", type)
+ .Attr("Tpaddings", tpaddings));
});
}
@@ -1824,23 +1959,24 @@ TEST_F(OpTest, Pow) {
// nontermination.
Repeatedly([this]() {
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Pow")
- .Input(RandomTensor(DT_FLOAT, dims.first))
- .Input(RandomTensor(DT_FLOAT, dims.second))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
+ .RandomInput(DT_FLOAT, dims.first)
+ .RandomInput(DT_FLOAT, dims.second)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Prod) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- Tensor data = RandomTensor(type);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Prod").Input(data).Input(indices).Attr("T", type).Attr(
- "keep_dims", keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Prod")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", type)
+ .Attr("keep_dims", keep_dims));
});
}
@@ -1855,7 +1991,7 @@ TEST_F(OpTest, Range) {
};
std::uniform_int_distribution<int> distribution(-50, 50);
DataType tidx = Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE});
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Range")
.Input(ToScalar(tidx, distribution(generator())))
.Input(ToScalar(tidx, distribution(generator())))
@@ -1867,8 +2003,8 @@ TEST_F(OpTest, Range) {
TEST_F(OpTest, Rank) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Rank").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
});
}
@@ -1876,46 +2012,51 @@ TEST_F(OpTest, RealDiv) {
Repeatedly([this]() {
DataType type = DT_FLOAT;
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Reciprocal) {
+ Repeatedly([this]() {
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Relu) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Relu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Relu6) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Relu6").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Relu6Grad) {
Repeatedly([this]() {
auto dims = RandomDims(1);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, ReluGrad) {
Repeatedly([this]() {
auto dims = RandomDims(1);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
@@ -1937,10 +2078,9 @@ TEST_F(OpTest, Reshape) {
}
}
}
- Tensor data = RandomTensor(type, dims_before);
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Reshape")
- .Input(data)
+ .RandomInput(type, dims_before)
.Input(test::AsTensor<int32>(
std::vector<int32>(dims_after.begin(), dims_after.end())))
.Attr("T", type));
@@ -1952,56 +2092,54 @@ TEST_F(OpTest, Reverse) {
std::vector<int64> dims = RandomDims(1);
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
int64 rank = dims.size();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
- .Input(RandomTensor(type, dims))
- .Input(RandomTensor(DT_BOOL, {rank}))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
+ .RandomInput(type, dims)
+ .RandomInput(DT_BOOL, {rank})
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, ReverseV2) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- Tensor data = RandomTensor(type);
- Tensor indices = RandomReductionIndices(data.dims());
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
- .Input(data)
- .Input(indices)
- .Attr("T", DT_FLOAT));
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Round) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Round")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Round").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Rsqrt) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, RsqrtGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Shape) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Shape").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Shape").RandomInput(type).Attr("T", type));
});
}
@@ -2013,72 +2151,72 @@ TEST_F(OpTest, ShapeN) {
builder.Attr("T", type);
builder.Attr("N", n);
for (int i = 0; i < n; ++i) {
- builder.Input(RandomTensor(type));
+ builder.RandomInput(type);
}
- ExpectTfAndXlaOutputsAreClose(builder);
+ return ExpectTfAndXlaOutputsAreClose(builder);
});
}
TEST_F(OpTest, Sigmoid) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sigmoid")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SigmoidGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Sign) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Sign").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Size) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Size").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Size").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Slice) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor data = RandomTensor(type);
+ std::vector<int64> data_dims = RandomDims();
- std::vector<int32> begin(data.dims()), size(data.dims());
- for (int i = 0; i < data.dims(); ++i) {
- begin[i] = std::uniform_int_distribution<int32>(
- 0, data.dim_size(i))(generator());
+ std::vector<int32> begin(data_dims.size()), size(data_dims.size());
+ for (int i = 0; i < data_dims.size(); ++i) {
+ begin[i] =
+ std::uniform_int_distribution<int32>(0, data_dims[i])(generator());
size[i] = std::uniform_int_distribution<int32>(
- -1, data.dim_size(i) - begin[i])(generator());
+ -1, data_dims[i] - begin[i])(generator());
}
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice")
- .Input(data)
- .Input(test::AsTensor<int32>(begin))
- .Input(test::AsTensor<int32>(size))
- .Attr("T", type)
- .Attr("Index", DT_INT32));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Slice")
+ .RandomInput(type, data_dims)
+ .Input(test::AsTensor<int32>(begin))
+ .Input(test::AsTensor<int32>(size))
+ .Attr("T", type)
+ .Attr("Index", DT_INT32));
});
}
TEST_F(OpTest, Softmax) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Softmax")
- .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2)))
+ .RandomInput(DT_FLOAT, RandomDims(2, 2))
.Attr("T", DT_FLOAT));
});
}
@@ -2086,28 +2224,28 @@ TEST_F(OpTest, Softmax) {
TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims(2, 2, 1);
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftmaxCrossEntropyWithLogits")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("SoftmaxCrossEntropyWithLogits")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Softplus) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Softplus")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Softplus").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SoftplusGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
@@ -2141,11 +2279,11 @@ TEST_F(OpTest, SpaceToBatch) {
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
- .Input(RandomTensor(DT_FLOAT, input_dims))
- .Input(paddings)
- .Attr("T", DT_FLOAT)
- .Attr("block_size", block_size));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
+ .RandomInput(DT_FLOAT, input_dims)
+ .Input(paddings)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
});
}
@@ -2182,9 +2320,9 @@ TEST_F(OpTest, SpaceToBatchND) {
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("SpaceToBatchND")
- .Input(RandomTensor(DT_FLOAT, input_dims))
+ .RandomInput(DT_FLOAT, input_dims)
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(paddings)
@@ -2198,33 +2336,26 @@ TEST_F(OpTest, SparseMatMul) {
int64 y = RandomDim();
int64 z = RandomDim();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
- .Input(RandomTensor(DT_FLOAT, {x, y}))
- .Input(RandomTensor(DT_FLOAT, {y, z}))
- .Attr("Ta", DT_FLOAT)
- .Attr("Tb", DT_FLOAT));
-
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
- .Input(RandomTensor(DT_FLOAT, {y, x}))
- .Input(RandomTensor(DT_FLOAT, {y, z}))
- .Attr("Ta", DT_FLOAT)
- .Attr("Tb", DT_FLOAT)
- .Attr("transpose_a", true));
+ std::vector<int64> a_dims = {x, y};
+ std::vector<int64> b_dims = {y, z};
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
- .Input(RandomTensor(DT_FLOAT, {x, y}))
- .Input(RandomTensor(DT_FLOAT, {z, y}))
- .Attr("Ta", DT_FLOAT)
- .Attr("Tb", DT_FLOAT)
- .Attr("transpose_b", true));
+ std::bernoulli_distribution random_bool;
+ bool transpose_a = random_bool(generator());
+ bool transpose_b = random_bool(generator());
+ if (transpose_a) {
+ std::swap(a_dims[0], a_dims[1]);
+ }
+ if (transpose_b) {
+ std::swap(b_dims[0], b_dims[1]);
+ }
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
- .Input(RandomTensor(DT_FLOAT, {y, x}))
- .Input(RandomTensor(DT_FLOAT, {z, y}))
- .Attr("Ta", DT_FLOAT)
- .Attr("Tb", DT_FLOAT)
- .Attr("transpose_a", true)
- .Attr("transpose_b", true));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
+ .RandomInput(DT_FLOAT, a_dims)
+ .RandomInput(DT_FLOAT, b_dims)
+ .Attr("Ta", DT_FLOAT)
+ .Attr("Tb", DT_FLOAT)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b));
});
}
@@ -2240,9 +2371,9 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) {
std::uniform_int_distribution<int32>(0, num_classes - 1)(generator());
}
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits")
- .Input(RandomTensor(DT_FLOAT, dims))
+ .RandomInput(DT_FLOAT, dims)
.Input(test::AsTensor<int32>(indices))
.Attr("T", DT_FLOAT)
.Attr("Tlabels", DT_INT32));
@@ -2260,56 +2391,54 @@ TEST_F(OpTest, Split) {
// Ensure 'dim' is evenly divisible by 'n'.
dims[dim] /= n;
dims[dim] *= n;
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
- .Input(test::AsScalar<int32>(dim))
- .Input(RandomTensor(type, dims))
- .Attr("T", type)
- .Attr("num_split", n));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
+ .Input(test::AsScalar<int32>(dim))
+ .RandomInput(type, dims)
+ .Attr("T", type)
+ .Attr("num_split", n));
});
}
TEST_F(OpTest, Sqrt) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SquaredDifference) {
Repeatedly([this]() {
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("SquaredDifference")
- .Input(RandomTensor(DT_FLOAT, dims.first))
- .Input(RandomTensor(DT_FLOAT, dims.second))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SquaredDifference")
+ .RandomInput(DT_FLOAT, dims.first)
+ .RandomInput(DT_FLOAT, dims.second)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Square) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Square").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Square").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Squeeze) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor t = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 0, 5));
+ std::vector<int64> t_dims = RandomDims(0, kDefaultMaxRank, 0, 5);
std::bernoulli_distribution random_bool;
std::vector<int> squeeze_dims;
- for (int i = 0; i < t.dims(); ++i) {
- if (t.dim_size(i) == 1 && random_bool(generator())) {
+ for (int i = 0; i < t_dims.size(); ++i) {
+ if (t_dims[i] == 1 && random_bool(generator())) {
squeeze_dims.push_back(i);
}
}
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze")
- .Input(t)
- .Attr("squeeze_dims", squeeze_dims)
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze")
+ .RandomInput(type, t_dims)
+ .Attr("squeeze_dims", squeeze_dims)
+ .Attr("T", type));
});
}
@@ -2317,58 +2446,59 @@ TEST_F(OpTest, Sub) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, Sum) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- Tensor data = RandomTensor(type);
- Tensor indices = RandomReductionIndices(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("Sum").Input(data).Input(indices).Attr("T", type).Attr(
- "keep_dims", keep_dims));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sum")
+ .RandomInput(type, data_dims)
+ .Input(indices)
+ .Attr("T", type)
+ .Attr("keep_dims", keep_dims));
});
}
TEST_F(OpTest, StridedSlice) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor data = RandomTensor(type);
-
- std::vector<int32> begin(data.dims()), end(data.dims());
- std::vector<int32> strides(data.dims());
- for (int i = 0; i < data.dims(); ++i) {
+ std::vector<int64> data_dims = RandomDims();
+ std::vector<int32> begin(data_dims.size()), end(data_dims.size());
+ std::vector<int32> strides(data_dims.size());
+ for (int i = 0; i < data_dims.size(); ++i) {
begin[i] = std::uniform_int_distribution<int32>(
- -2 * data.dim_size(i), 2 * data.dim_size(i))(generator());
+ -2 * data_dims[i], 2 * data_dims[i])(generator());
end[i] = std::uniform_int_distribution<int32>(
- -2 * data.dim_size(i), 2 * data.dim_size(i))(generator());
+ -2 * data_dims[i], 2 * data_dims[i])(generator());
// TODO(b/31360685): support strides other than 1 or -1
strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1;
}
- int64 max_bitmask = (1LL << data.dims()) - 1;
+ int64 max_bitmask = (1LL << data_dims.size()) - 1;
std::uniform_int_distribution<int64> bitmask_distribution(0, max_bitmask);
int64 begin_mask = bitmask_distribution(generator());
int64 end_mask = bitmask_distribution(generator());
// Create a ellipsis bitmask with at most one 1 bit set.
int64 ellipsis_mask = 0;
- if (data.dims() > 0 && std::bernoulli_distribution()(generator())) {
- int ellipsis_pos =
- std::uniform_int_distribution<int>(0, data.dims() - 1)(generator());
+ if (!data_dims.empty() && std::bernoulli_distribution()(generator())) {
+ int ellipsis_pos = std::uniform_int_distribution<int>(
+ 0, data_dims.size() - 1)(generator());
ellipsis_mask = 1LL << ellipsis_pos;
}
int64 new_axis_mask = bitmask_distribution(generator());
int64 shrink_axis_mask = bitmask_distribution(generator());
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("StridedSlice")
- .Input(data)
+ .RandomInput(type, data_dims)
.Input(test::AsTensor<int32>(begin))
.Input(test::AsTensor<int32>(end))
.Input(test::AsTensor<int32>(strides))
@@ -2418,13 +2548,13 @@ TEST_F(OpTest, StridedSliceGrad) {
// TODO(phawkins): use shape inference for the forward op to compute the
// gradient shape for the backward op. At present, there is a low
// probability of the golden op succeeding.
- ExpectTfAndXlaOutputsAreClose(
+ return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("StridedSliceGrad")
.Input(test::AsTensor<int64>(dims))
.Input(test::AsTensor<int64>(begin))
.Input(test::AsTensor<int64>(end))
.Input(test::AsTensor<int64>(strides))
- .Input(RandomTensor(type, RandomDims(1)))
+ .RandomInput(type, RandomDims(1))
.Attr("T", type)
.Attr("Index", DT_INT64)
.Attr("begin_mask", begin_mask)
@@ -2437,48 +2567,48 @@ TEST_F(OpTest, StridedSliceGrad) {
TEST_F(OpTest, Tanh) {
Repeatedly([this]() {
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tanh")
- .Input(RandomTensor(DT_FLOAT))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, TanhGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
- .Input(RandomTensor(DT_FLOAT, dims))
- .Input(RandomTensor(DT_FLOAT, dims))
- .Attr("T", DT_FLOAT));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Tile) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor t = RandomTensor(type, RandomDims(1));
- std::vector<int32> multiples(t.dims());
- for (int i = 0; i < t.dims(); ++i) {
+ std::vector<int64> t_dims = RandomDims(1);
+ std::vector<int32> multiples(t_dims.size());
+ for (int i = 0; i < t_dims.size(); ++i) {
multiples[i] = std::uniform_int_distribution<int>(1, 3)(generator());
}
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tile")
- .Input(t)
- .Input(test::AsTensor<int32>(multiples))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Tile")
+ .RandomInput(type, t_dims)
+ .Input(test::AsTensor<int32>(multiples))
+ .Attr("T", type));
});
}
TEST_F(OpTest, Transpose) {
Repeatedly([this]() {
DataType type = Choose<DataType>(kAllXlaTypes);
- Tensor data = RandomTensor(type);
- std::vector<int32> perm(data.dims());
+ std::vector<int64> data_dims = RandomDims();
+ std::vector<int32> perm(data_dims.size());
std::iota(perm.begin(), perm.end(), 0);
std::shuffle(perm.begin(), perm.end(), generator());
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose")
- .Input(data)
- .Input(test::AsTensor<int32>(perm))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose")
+ .RandomInput(type, data_dims)
+ .Input(test::AsTensor<int32>(perm))
+ .Attr("T", type));
});
}
@@ -2486,10 +2616,10 @@ TEST_F(OpTest, TruncateDiv) {
Repeatedly([this]() {
DataType type = DT_INT32;
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
@@ -2497,26 +2627,18 @@ TEST_F(OpTest, TruncateMod) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
auto dims = BroadcastableDims();
- ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
- .Input(RandomTensor(type, dims.first))
- .Input(RandomTensor(type, dims.second))
- .Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
+ .RandomInput(type, dims.first)
+ .RandomInput(type, dims.second)
+ .Attr("T", type));
});
}
TEST_F(OpTest, ZerosLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("ZerosLike").Input(RandomTensor(type)).Attr("T", type));
- });
-}
-
-TEST_F(OpTest, OnesLike) {
- Repeatedly([this]() {
- DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
- ExpectTfAndXlaOutputsAreClose(
- OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
});
}
@@ -2535,6 +2657,9 @@ int main(int argc, char** argv) {
tensorflow::Flag("tf_xla_test_repetitions",
&tensorflow::tf_xla_test_repetitions,
"Number of repetitions for each test."),
+ tensorflow::Flag("tf_xla_max_tensor_size",
+ &tensorflow::tf_xla_max_tensor_size,
+ "Maximum number of elements for random input tensors."),
tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr,
"Tensorflow device type to use for test"),
tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit,
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
new file mode 100644
index 0000000000..27a2977305
--- /dev/null
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for XLA TensorArray Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _make_converter(dtype):
+ def _converter(x):
+ return np.asarray(x).astype(dtype.as_numpy_dtype)
+ return _converter
+
+
+class TensorArrayTest(xla_test.XLATestCase):
+
+ def testTensorArrayWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0, 5.0]])
+ w1 = w0.write(1, [[1.0, 3.0]])
+ w2 = w1.write(2, [[7.0, -8.5]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual([[4.0, 5.0]], d0)
+ self.assertAllEqual([[1.0, 3.0]], d1)
+ self.assertAllEqual([[7.0, -8.5]], d2)
+
+ def _testTensorArrayWritePack(self, tf_dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ w0 = ta.write(0, convert([[4.0, 5.0]]))
+ w1 = w0.write(1, convert([[6.0, 7.0]]))
+ w2 = w1.write(2, convert([[8.0, 9.0]]))
+
+ c0 = w2.stack()
+
+ self.assertAllEqual(
+ convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
+
+ def testTensorArrayWritePack(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWritePack(dtype)
+
+ def testEmptyTensorArrayPack(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ empty_element = np.zeros((0, 1), dtype=np.float32)
+ w0 = ta.write(0, empty_element)
+ w1 = w0.write(1, empty_element)
+ w2 = w1.write(2, empty_element)
+
+ c0 = w2.stack()
+
+ self.assertAllEqual([3, 0, 1], c0.eval().shape)
+
+ def _testTensorArrayWriteConcat(self, tf_dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]]))
+ w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]]))
+ w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]]))
+
+ c0 = w2.concat()
+
+ self.assertAllEqual(
+ convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0],
+ [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval())
+
+ def testTensorArrayWriteConcat(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWriteConcat(dtype)
+
+ def _testTensorArrayUnpackRead(self, tf_dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ # Unpack a vector into scalars
+ w0 = ta.unstack(convert([1.0, 2.0, 3.0]))
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert(1.0), d0)
+ self.assertAllEqual(convert(2.0), d1)
+ self.assertAllEqual(convert(3.0), d2)
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ # Unpack a matrix into vectors
+ w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ r2 = w1.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([1.0, 1.1]), d0)
+ self.assertAllEqual(convert([2.0, 2.1]), d1)
+ self.assertAllEqual(convert([3.0, 3.1]), d2)
+
+ # Reset ta because we're going to change the shape, else shape
+ # inference will throw an error.
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ # Try unpacking an empty matrix, which should not cause an error.
+ w2 = ta.unstack(convert([[], [], []]))
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([]), d0)
+ self.assertAllEqual(convert([]), d1)
+ self.assertAllEqual(convert([]), d2)
+
+ def _testTensorArrayUnpackReadMaybeLegacy(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayUnpackRead(dtype)
+
+ def testTensorArrayUnpackRead(self):
+ self._testTensorArrayUnpackReadMaybeLegacy()
+
+ def _testTensorArraySplitRead(self, tf_dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ # Split an empty vector
+ lengths = constant_op.constant([0, 0, 0])
+ w0 = ta.split(convert([]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([]), d0)
+ self.assertAllEqual(convert([]), d1)
+ self.assertAllEqual(convert([]), d2)
+
+ # Split a vector
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+ lengths = constant_op.constant([1, 1, 1])
+ w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([1.0]), d0)
+ self.assertAllEqual(convert([2.0]), d1)
+ self.assertAllEqual(convert([3.0]), d2)
+
+ # Split a matrix
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+ lengths = constant_op.constant([1, 1, 1])
+ w0 = ta.split(
+ convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([[1.0, 101.0]]), d0)
+ self.assertAllEqual(convert([[2.0, 201.0]]), d1)
+ self.assertAllEqual(convert([[3.0, 301.0]]), d2)
+
+ def testTensorArraySplitRead(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArraySplitRead(dtype)
+
+ def testTensorGradArrayWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0]])
+ w1 = w0.write(1, [[1.0]])
+ w2 = w1.write(2, [[-3.0]])
+
+ g_ta = w2.grad("grad")
+
+ g_w0 = g_ta.write(0, [[5.0]])
+ g_w1 = g_w0.write(1, [[2.0]])
+ g_w2 = g_w1.write(2, [[-2.0]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ g_r0 = g_w2.read(0)
+ g_r1 = g_w2.read(1)
+ g_r2 = g_w2.read(2)
+
+ d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2])
+ self.assertAllEqual([[4.0]], d0)
+ self.assertAllEqual([[1.0]], d1)
+ self.assertAllEqual([[-3.0]], d2)
+ self.assertAllEqual([[5.0]], g_d0)
+ self.assertAllEqual([[2.0]], g_d1)
+ self.assertAllEqual([[-2.0]], g_d2)
+
+ def testTensorGradArrayDynamicWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0]])
+ w1 = w0.write(1, [[1.0]])
+ w2 = w1.write(2, [[-3.0]])
+
+ g_ta = w2.grad("grad") # Get gradient array here so we know the shape
+
+ s = w2.size()
+ g_s = g_ta.size()
+
+ g_w0 = g_ta.write(0, [[5.0]])
+ g_w1 = g_w0.write(1, [[2.0]])
+ g_w2 = g_w1.write(2, [[-2.0]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ g_r0 = g_w2.read(0)
+ g_r1 = g_w2.read(1)
+ g_r2 = g_w2.read(2)
+
+ d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run(
+ [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s])
+ self.assertAllEqual([[4.0]], d0)
+ self.assertAllEqual([[1.0]], d1)
+ self.assertAllEqual([[-3.0]], d2)
+ self.assertAllEqual([[5.0]], g_d0)
+ self.assertAllEqual([[2.0]], g_d1)
+ self.assertAllEqual([[-2.0]], g_d2)
+ self.assertAllEqual(3, vs)
+ self.assertAllEqual(3, g_vs)
+
+ def testTensorGradAccessTwiceReceiveSameObject(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3,
+ element_shape=[1, 2])
+ g_ta_0 = ta.grad("grad")
+ g_ta_1 = ta.grad("grad")
+
+ with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]):
+ # Write with one gradient handle, read with another copy of it
+ r1_0 = g_ta_1.read(0)
+
+ t_g_ta_0, t_g_ta_1, d_r1_0 = session.run(
+ [g_ta_0.handle.op, g_ta_1.handle.op, r1_0])
+ self.assertAllEqual(t_g_ta_0, t_g_ta_1)
+ self.assertAllEqual([[4.0, 5.0]], d_r1_0)
+
+ def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ # Test writing the wrong datatype
+ with self.assertRaisesOpError(
+ "TensorArray dtype is float but op has dtype int32"):
+ ta.write(-1, np.int32(7)).flow.eval()
+
+ def testTensorArrayReadWrongIndexOrDataTypeFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ w0 = ta.write(0, [[4.0, 5.0]])
+
+ # Test reading wrong datatype
+ r0_bad = gen_data_flow_ops._tensor_array_read_v3(
+ handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
+ with self.assertRaisesOpError(
+ "TensorArray dtype is float but Op requested dtype double."):
+ r0_bad.eval()
+
+ # Test reading from a different index than the one we wrote to
+ w0.read(1)
+
+ def testTensorArraySplitIncompatibleShapesFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=False)
+
+ with self.assertRaisesOpError(
+ r"value is not 1D"):
+ lengths = array_ops.placeholder(dtypes.int64)
+ ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
+
+ with self.assertRaisesOpError(
+ r"lengths must be equal: 1 vs. 2"):
+ ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow.eval()
+
+ with self.assertRaisesOpError(
+ r"value must have rank >= 1"):
+ ta.split(1.0, [1]).flow.eval()
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ infer_shape=False)
+
+ with self.assertRaisesOpError(
+ r"TensorArray's size is not equal to the size of lengths "
+ r"\(1 vs. 2\)"):
+ ta.split([1.0], [1]).flow.eval()
+
+ def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
+
+ c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype)
+
+ w0 = ta.write(2, c(3.0))
+ w1 = w0.write(2, c(4.0))
+
+ ta_grad = w1.grad("grad")
+
+ w0_grad = ta_grad.write(2, c(3.0))
+ w1_grad = w0_grad.write(2, c(4.0))
+ w2_grad = w1_grad.write(2, c(5.0))
+
+ # Assert that aggregation works correctly
+ self.assertAllEqual(c(12.00), w2_grad.read(2).eval())
+
+ # Using differing shapes causes an exception
+ wb0_grad = ta_grad.write(1, c(1.0))
+ wb1_grad = wb0_grad.write(1, c([1.0]))
+
+ with self.assertRaisesOpError(
+ r"Mismatched TensorArray sizes"):
+ wb1_grad.flow.eval()
+
+ def testTensorArrayWriteGradientAddMultipleAdds(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
+
+ def testMultiTensorArray(self):
+ with self.test_session(), self.test_scope():
+ h1 = tensor_array_ops.TensorArray(
+ size=1, dtype=dtypes.float32, tensor_array_name="foo")
+ w1 = h1.write(0, 4.0)
+ r1 = w1.read(0)
+
+ h2 = tensor_array_ops.TensorArray(
+ size=1, dtype=dtypes.float32, tensor_array_name="bar")
+
+ w2 = h2.write(0, 5.0)
+ r2 = w2.read(0)
+ r = r1 + r2
+ self.assertAllClose(9.0, r.eval())
+
+ def _testTensorArrayGradientWriteReadType(self, dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.as_dtype(dtype),
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=False)
+
+ c = lambda x: np.array(x, dtype=dtype)
+
+ value_0 = constant_op.constant(c([[4.0, 5.0]]))
+ value_1 = constant_op.constant(c([[3.0, 3.5]]))
+
+ w0 = ta.write(0, value_0)
+ w1 = w0.write(1, value_1)
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ r0_2 = w1.read(0)
+
+ # Test individual components' gradients
+ grad_just_r0 = gradients_impl.gradients(
+ ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])])
+ grad_just_r0_vals = session.run(grad_just_r0)
+ self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0])
+
+ grad_r0_r0_2 = gradients_impl.gradients(
+ ys=[r0, r0_2],
+ xs=[value_0],
+ grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])])
+ grad_r0_r0_2_vals = session.run(grad_r0_r0_2)
+ self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0])
+
+ grad_just_r1 = gradients_impl.gradients(
+ ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])])
+ grad_just_r1_vals = session.run(grad_just_r1)
+ self.assertAllEqual(c([[-2.0, -4.0]]), grad_just_r1_vals[0])
+
+ # Test combined gradients
+ grad = gradients_impl.gradients(
+ ys=[r0, r0_2, r1],
+ xs=[value_0, value_1],
+ grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c([[-2.0, -10.0]])])
+ grad_vals = session.run(grad)
+ self.assertEqual(len(grad_vals), 2)
+ self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0])
+ self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
+
+ def testTensorArrayGradientWriteRead(self):
+ for dtype in self.numeric_types:
+ self._testTensorArrayGradientWriteReadType(dtype)
+
+ def _testTensorArrayGradientWritePackConcatAndRead(self):
+ with self.test_session() as sess, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+
+ value_0 = constant_op.constant([-1.0, 1.0])
+ value_1 = constant_op.constant([-10.0, 10.0])
+
+ w0 = ta.write(0, value_0)
+ w1 = w0.write(1, value_1)
+ p0 = w1.stack()
+ r0 = w1.read(0)
+ s0 = w1.concat()
+
+ # Test gradient accumulation between read(0), pack(), and concat()
+ with ops.control_dependencies([p0, r0, s0]):
+ grad_r = gradients_impl.gradients(
+ ys=[p0, r0, s0],
+ xs=[value_0, value_1],
+ grad_ys=[
+ [[2.0, 3.0], [4.0, 5.0]], # stack gradient
+ [-0.5, 1.5], # read(0) gradient
+ [20.0, 30.0, 40.0, 50.0], # concat gradient
+ ])
+ grad_vals = sess.run(grad_r) # 2 + 2 entries
+
+ self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
+ self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
+
+ def testTensorArrayGradientWritePackConcatAndRead(self):
+ self._testTensorArrayGradientWritePackConcatAndRead()
+
+ def testTensorArrayReadTwice(self):
+ with self.test_session(), self.test_scope():
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ ta_readtwice = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+ w_readtwice = ta_readtwice.unstack(value)
+ r0_readtwice = w_readtwice.read(0)
+ with ops.control_dependencies([r0_readtwice]):
+ r1_readtwice = w_readtwice.read(0)
+
+ self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
+
+ def _testTensorArrayGradientUnpackRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ w = ta.unstack(value)
+ r0 = w.read(0)
+ r0_1 = w.read(0)
+ r1 = w.read(1)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[r0, r0_1, r1],
+ xs=[value],
+ grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]])
+ grad_vals = session.run(grad)
+
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0])
+
+ def testTensorArrayGradientUnpackRead(self):
+ self._testTensorArrayGradientUnpackRead()
+
+ def testTensorArrayGradientSplitConcat(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=2)
+
+ value = constant_op.constant(
+ [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0], [1000.0, -1000.0]])
+
+ w = ta.split(value, [2, 2])
+ r = w.concat()
+
+ # Test combined gradients
+ grad = gradients_impl.gradients(
+ ys=[r],
+ xs=[value],
+ grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0],
+ [2000.0, -2000.0]]])
+ grad_vals = session.run(grad)
+
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0],
+ [2000.0, -2000.0]],
+ grad_vals[0])
+
+ # TODO(phawkins): implement TensorArrayClose
+ # def testCloseTensorArray(self):
+ # with self.test_session() as session, self.test_scope():
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ # c1 = ta.close()
+ # session.run(c1)
+
+ def testSizeTensorArray(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ s = ta.size()
+ self.assertAllEqual(3, s.eval())
+
+ # TODO(phawkins): implement TensorArrayClose
+ # def testWriteCloseTensorArray(self):
+ # with self.test_session(), self.test_scope():
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32,
+ # tensor_array_name="foo",
+ # size=3,
+ # infer_shape=False)
+ # w0 = ta.write(0, [[4.0, 5.0]])
+ # w1 = w0.write(1, [3.0])
+ # w1.close().run() # Expected to run without problems
+
+ # TODO(phawkins): implement while loops.
+ # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
+ # np_dtype = dtype.as_numpy_dtype
+ # with self.test_session() as session, self.test_scope():
+ # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
+ # var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
+ # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtype,
+ # tensor_array_name="foo",
+ # size=0 if dynamic_size else 3,
+ # dynamic_size=dynamic_size)
+ # time_0 = array_ops.identity(0)
+
+ # def body(time, ta_t, state):
+ # sliced = array_ops.slice(
+ # v0, begin=array_ops.stack([time, 0]), size=[1, -1])
+ # sliced = array_ops.squeeze(sliced)
+ # out = sliced + var + state
+ # state += sliced
+ # ta_t = ta_t.write(time, out)
+ # return (time + 1, ta_t, state)
+
+ # (unused_0, h_final, unused_2) = control_flow_ops.while_loop(
+ # cond=lambda time, unused_1, unused_2: time < 3,
+ # body=body,
+ # loop_vars=(time_0, ta, state0),
+ # shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(),
+ # tensor_shape.unknown_shape()),
+ # parallel_iterations=3)
+ # vout = h_final.stack()
+
+ # grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
+ # v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
+ # state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
+ # var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
+
+ # variables.global_variables_initializer().run()
+ # state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
+ # session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad])
+ # )
+ # just_v0_grad_t, = session.run([v0_grad])
+
+ # # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
+ # # vout = [ v0[0] + var + state[0] |
+ # # v0[1] + var + state[1] |
+ # # v0[2] + var + state[2] ]
+ # # = [ v0[0] + var + state0 |
+ # # v0[1] + var + state0 + v0[0] |
+ # # v0[2] + var + state0 + v0[0] + v0[1] ]
+ # #
+ # # d(vout[0])/d(v0) = [1 | 0 | 0 ]
+ # # d(vout[1])/d(v0) = [1 | 1 | 0 ]
+ # # d(vout[2])/d(v0) = [1 | 1 | 1 ]
+ # # d(vout)/d(var) = [1 | 1 | 1]
+ # # d(vout)/d(state0) = [ 1 | 1 | 1 ]
+
+ # state_per_time = np.array(
+ # [state0_t, state0_t + v0_t[0, :],
+ # state0_t + v0_t[0, :] + v0_t[1, :]])
+
+ # # Compare forward prop
+ # self.assertAllClose(v0_t + var_t + state_per_time, vout_t)
+
+ # # Compare backward prop
+ # expected_v0_grad_t = np.array([
+ # grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
+ # grad_val[1, :] + grad_val[2, :], grad_val[2, :]
+ # ])
+
+ # self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
+ # self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
+ # self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
+ # self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
+
+ # def testWhileLoopWritePackGradients(self):
+ # self._testWhileLoopWritePackGradients(
+ # dynamic_size=False, dtype=dtypes.float32)
+ # # TODO(ebrevdo): re-enable when While supports non-float32 gradients.
+ # # self._testWhileLoopWritePackGradients(
+ # # dynamic_size=False, dtype=tf.int64)
+
+ # def testWhileLoopDynamicWritePackGradients(self):
+ # self._testWhileLoopWritePackGradients(
+ # dynamic_size=True, dtype=dtypes.float32)
+
+ # def testGradSerialTwoLoops(self):
+ # with self.test_session(), self.test_scope():
+ # num_steps = 100
+ # acc = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32,
+ # size=num_steps,
+ # clear_after_read=False,
+ # element_shape=tensor_shape.scalar())
+ # i = constant_op.constant(0, name="i")
+ # x = constant_op.constant(2.0, name="x")
+
+ # c = lambda i, acc: i < 5
+
+ # def b(i, acc):
+ # x1 = control_flow_ops.cond(
+ # math_ops.equal(i, 0), lambda: x,
+ # lambda: math_ops.multiply(acc.read(i - 1), 2.0))
+ # return i + 1, acc.write(i, x1)
+
+ # i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc])
+
+ # z = constant_op.constant(0.0)
+
+ # def fn(i, acc):
+ # return i + 1, acc.write(i, z)
+
+ # _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn,
+ # [i1, acc1])
+
+ # r = acc2.stack()
+ # grad = gradients_impl.gradients(r, [x])[0]
+ # self.assertAllClose(31.0, grad.eval())
+
+ def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
+ with self.test_session() as session, self.test_scope():
+ a = array_ops.identity(
+ np.arange(
+ 3 * 5, dtype=np.float32).reshape(3, 5) + 1)
+ b = array_ops.identity(
+ np.arange(
+ 3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5)
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ ta = ta.write(0, a, name="write_a")
+ ta = ta.write(1, b, name="write_b")
+ c = (
+ ta.read(
+ 0, name="read_a_0") + # a + b
+ ta.read(
+ 1, name="read_b_0"))
+ g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1)
+ grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1
+ grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1
+
+ # Test gradients calculated individually
+ grad_a_t, = session.run([grad_a])
+ self.assertAllEqual(grad_a_t, g0)
+
+ grad_b_t, = session.run([grad_b])
+ self.assertAllEqual(grad_b_t, g0)
+
+ # Test gradients calculated jointly
+ joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b])
+ self.assertAllEqual(joint_grad_a_t, g0)
+ self.assertAllEqual(joint_grad_b_t, g0)
+
+ def testWriteShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c0 = constant_op.constant([4.0, 5.0])
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual(c0.get_shape(), r0.get_shape())
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c1 = constant_op.constant([6.0, 7.0])
+ w1 = w0.write(1, c1)
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ self.assertAllEqual(c0.get_shape(), r0.get_shape())
+ self.assertAllEqual(c1.get_shape(), r1.get_shape())
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c2 = constant_op.constant([4.0, 5.0, 6.0])
+ with self.assertRaises(ValueError):
+ w0.write(0, c2)
+
+ def testPartlyUnknownShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=6)
+
+ c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list())
+
+ c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
+ w1 = w0.write(1, c1)
+ r1 = w1.read(0)
+ self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list())
+
+ # Writing less specific shape (doesn't change type.)
+ c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None])
+ w2 = w1.write(2, c2)
+ r2 = w2.read(0)
+ self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list())
+
+ # Writing more specific shape in one dimension and less specific in
+ # another.
+ c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None])
+ w3 = w2.write(3, c3)
+ r3 = w3.read(0)
+ self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list())
+
+ # Writing partly defined shape using TensorArray.scatter.
+ c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3])
+ w4 = w3.scatter([4, 5], c4)
+ r4 = w4.read(0)
+ self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list())
+
+ # Writing fully defined shape using TensorArray.split.
+ c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3])
+ w5 = w4.split(c5, constant_op.constant([5, 5]))
+ r5 = w5.read(0)
+ self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
+
+ def _testUnpackShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=0,
+ infer_shape=True)
+ value = constant_op.constant(
+ [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
+ w0 = ta.unstack(value)
+ r0 = w0.read(0)
+ self.assertAllEqual((2,), r0.get_shape())
+
+ c1 = constant_op.constant([4.0, 5.0])
+ w1 = w0.write(3, c1)
+ r1 = w1.read(0)
+ self.assertAllEqual(c1.get_shape(), r1.get_shape())
+
+ c2 = constant_op.constant([4.0, 5.0, 6.0])
+ with self.assertRaises(ValueError):
+ w1.write(4, c2)
+
+ def testUnpackShape(self):
+ self._testUnpackShape()
+
+ def testSplitShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=0,
+ infer_shape=True)
+ value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]])
+ w0 = ta.split(value, [1, 1, 1])
+ r0 = w0.read(0)
+ self.assertAllEqual((1, 2), r0.get_shape())
+
+ ta1 = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo1",
+ size=0,
+ infer_shape=True)
+ w0 = ta1.split(value, [1, 2])
+ r0 = w0.read(0)
+ self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
+
+ def testWriteUnknownShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=True)
+ c0 = array_ops.placeholder(dtypes.float32)
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
+
+ def _testGradientWhenNotAllComponentsRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ x = constant_op.constant([2.0, 3.0])
+ w = ta.unstack(x)
+ r0 = w.read(0)
+ # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
+ grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
+ grad_r0_vals = session.run(grad_r0)[0]
+ self.assertAllEqual(grad_r0_vals, [1.0, 0.0])
+
+ def testGradientWhenNotAllComponentsRead(self):
+ self._testGradientWhenNotAllComponentsRead()
+
+ def _testTensorArrayEvalEmpty(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, size=0, infer_shape=False)
+ with self.assertRaisesOpError(
+ "TensorArray has size zero, but element shape <unknown> is not fully "
+ "defined. Currently only static shapes are supported when packing "
+ "zero-size TensorArrays."):
+ ta.stack().eval()
+
+ def testTensorArrayEvalEmpty(self):
+ self._testTensorArrayEvalEmpty()
+
+ def _testTensorArrayEvalEmptyWithDefault(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, size=0, infer_shape=True)
+ self.assertEqual(0, ta.size().eval())
+ ta = ta.unstack(array_ops.zeros([0, 3, 5]))
+ packed = ta.stack()
+ self.assertAllEqual([0, 3, 5], packed.eval().shape)
+ # Concatenating zero tensors along their first dimension gives a
+ # first dimension of zero
+ self.assertAllEqual([0, 5], ta.concat().eval().shape)
+
+ def testTensorArrayEvalEmptyWithDefault(self):
+ self._testTensorArrayEvalEmptyWithDefault()
+
+ def testTensorArrayScatterReadAndGradients(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=10)
+
+ indices = constant_op.constant([1, 8])
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ w = ta.scatter(indices, value)
+ r0 = w.read(1)
+ r1 = w.read(8)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
+ read_vals, grad_vals = session.run([[r0, r1], grad])
+
+ self.assertEqual(len(read_vals), 2)
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([1.0, -1.0], read_vals[0])
+ self.assertAllEqual([10.0, -10.0], read_vals[1])
+ self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
+
+ def testTensorArrayWriteGatherAndGradients(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=10)
+
+ values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
+ indices = constant_op.constant([1, 8])
+
+ w = ta.unstack(values)
+ g = w.gather(indices)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]])
+ g_vals, grad_vals = session.run([[g], grad])
+
+ # Gradients for 8 of the 10 unread components are zero.
+ expected_grad = np.zeros((10, 2))
+ expected_grad[1] = [2.0, 3.0]
+ expected_grad[8] = [4.0, 5.0]
+
+ self.assertEqual(len(g_vals), 1)
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
+ self.assertAllEqual(expected_grad, grad_vals[0])
+
+ def testTensorArrayIdentity(self):
+ with self.test_session() as session, self.test_scope():
+ ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
+ infer_shape=False)
+ ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
+ infer_shape=True)
+
+ ta0 = ta0.write(0, 0.)
+ ta1 = ta1.write(0, 1)
+
+ v0 = resource_variable_ops.ResourceVariable(0)
+ v1 = resource_variable_ops.ResourceVariable(0)
+
+ with ops.control_dependencies([v0.assign_add(1)]):
+ ta0 = ta0.identity()
+
+ with ops.control_dependencies([v1.assign_add(1)]):
+ ta1 = ta1.identity()
+
+ read0 = ta0.read(0)
+ read1 = ta1.read(0)
+
+ size0 = ta0.size()
+ size1 = ta1.size()
+
+ # Tests correct properties on new TensorArrays.
+ self.assertEqual(dtypes.float32, ta0.dtype)
+ self.assertEqual(dtypes.int32, ta1.dtype)
+ self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
+ self.assertEqual(tensor_shape.scalar(), read1.get_shape())
+
+ variables.global_variables_initializer().run()
+
+ read0_v, read1_v, size0_v, size1_v = session.run(
+ (read0, read1, size0, size1))
+
+ # Tests that the control dependencies was added and executed.
+ self.assertEqual(1, v0.eval())
+ self.assertEqual(1, v1.eval())
+
+ # Tests correct TensorArray.
+ self.assertEqual(read0_v, 0)
+ self.assertEqual(read1_v, 1)
+ self.assertEqual(size0_v, 2)
+ self.assertEqual(size1_v, 4)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index f7fe186cf8..79549644ea 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -54,16 +54,20 @@ class XLATestCase(test.TestCase):
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')
self.all_tf_types = [
- dtypes.DType(types_pb2.DataType.Value(name))
+ dtypes.as_dtype(types_pb2.DataType.Value(name))
for name in FLAGS.types.split(',')
]
- self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
- self.int_types = [
- dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
+ self.int_tf_types = [
+ dtype for dtype in self.all_tf_types if dtype.is_integer
]
- self.float_types = [
- dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
+ self.float_tf_types = [
+ dtype for dtype in self.all_tf_types if dtype.is_floating
]
+ self.numeric_tf_types = self.int_tf_types + self.float_tf_types
+
+ self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
+ self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
+ self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
self.numeric_types = self.int_types + self.float_types
# Parse the manifest file, if any, into a regex identifying tests to
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index c4cbaebb25..36a6c90af4 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -89,6 +89,8 @@ Status BackwardsConstAnalysis(const Graph& g,
{"StridedSliceGrad", "end"},
{"StridedSliceGrad", "strides"},
{"Sum", "reduction_indices"},
+ {"TensorArrayV3", "size"},
+ {"TensorArraySplitV3", "lengths"},
{"Tile", "multiples"},
{"Transpose", "perm"}};
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 81b065689d..a434c74680 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -55,6 +55,7 @@ tf_kernel_library(
"spacetobatch_op.cc",
"split_op.cc",
"strided_slice_op.cc",
+ "tensor_array_ops.cc",
"tile_ops.cc",
"training_ops.cc",
"transpose_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index d6897d6e33..620fc84437 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -49,14 +49,15 @@ class ArgOp : public XlaOpKernel {
return;
}
- XlaContext& tc = XlaContext::Get(ctx);
- const XlaContext::Argument& arg = tc.args()[index_];
+ XlaContext& xc = XlaContext::Get(ctx);
+ const XlaContext::Argument& arg = xc.args()[index_];
if (arg.is_variable) {
- // We use the argument position of the variable input as a unique ID.
// TODO(phawkins): this code assumes that variables do not alias.
- OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type,
- arg.value.handle));
- ctx->SetVariableOutput(0, index_);
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
+ arg.value.handle, &var));
+ var->tensor_array_size = arg.tensor_array_size;
+ ctx->SetVariableOutput(0, var);
} else if (arg.value.is_constant) {
ctx->SetConstantOutput(0, arg.value.constant_value);
} else {
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
new file mode 100644
index 0000000000..de542d55e8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -0,0 +1,538 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorArray operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// Since the element shape is not always provided to the TensorArrayV3 operator,
+// we must support lazily initialization of the TensorArray at the time of the
+// first write.
+// If a TensorArray `var` has not been initialized, constructs storage for the
+// TensorArray with elements of `elem_shape`. For both initialized and
+// uninitialized TensorArrays, checks that the tensor has a type compatible with
+// 'dtype' and shape compatible with 'elem_shape'.
+Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+ XlaVariable* var, DataType dtype,
+ const TensorShape& elem_shape) {
+ if (var->type != dtype) {
+ return errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(var->type),
+ " but op has dtype ", DataTypeString(dtype), ".");
+ }
+
+ TF_RET_CHECK(var->tensor_array_size >= 0)
+ << var->name << " size " << var->tensor_array_size;
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ if (var->value.handle() == 0) {
+ // TensorArray has not been initialized.
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
+ var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ } else {
+ // Checks the elem_shape matches the TensorArray shape.
+ auto shape_or_status = builder->GetShape(var->value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+ if (ta_shape != shape) {
+ return errors::InvalidArgument(
+ "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
+ shape.DebugString());
+ }
+ }
+ return Status::OK();
+}
+
+// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
+xla::ComputationDataHandle PadIndexWithZeros(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ int count) {
+ xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
+ std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
+ xs[0] = builder->Reshape(x, {1});
+ return builder->ConcatInDim(xs, 0);
+}
+
+// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// relevant slice of 'operand'.
+xla::ComputationDataHandle DynamicAddSlice(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
+ const xla::ComputationDataHandle& update,
+ const gtl::ArraySlice<int64>& update_dims,
+ const xla::ComputationDataHandle& start_indices) {
+ xla::ComputationDataHandle current =
+ builder->DynamicSlice(operand, start_indices, update_dims);
+ xla::ComputationDataHandle sum = builder->Add(current, update);
+ return builder->DynamicUpdateSlice(operand, sum, start_indices);
+}
+
+class TensorArrayOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ bool dynamic_size;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size));
+ OP_REQUIRES(
+ ctx, !dynamic_size,
+ errors::Unimplemented(
+ "TensorArrays with dynamic size are not supported by XLA."));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ int64 size;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
+ OP_REQUIRES(ctx, size >= 0,
+ errors::InvalidArgument("TensorArray size must be >= 0"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+ b->set_die_immediately_on_error(true);
+
+ // Initializes the TensorArray value if we know the element shape.
+ // Otherwise, defer initialization to the first write.
+ xla::ComputationDataHandle value;
+ if (element_shape_.IsFullyDefined()) {
+ TensorShape shape;
+ CHECK(element_shape_.AsTensorShape(&shape));
+ TensorShape ta_shape;
+ ta_shape.AddDim(size);
+ ta_shape.AppendShape(shape);
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ value = b->Broadcast(zero, ta_shape.dim_sizes());
+ }
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ XlaVariable* var;
+ string name = strings::StrCat("TensorArray: ", tensor_array_name_);
+ OP_REQUIRES_OK(ctx,
+ xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+ var->tensor_array_size = size;
+ ctx->SetVariableOutput(0, var);
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ PartialTensorShape element_shape_;
+ DataType dtype_;
+ string tensor_array_name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+
+class TensorArrayWriteOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ TensorShape elem_shape = ctx->InputShape(2);
+
+ // Initializes the TensorArray, if the element shape was not known at
+ // construction time.
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle index = ctx->Input(1);
+ xla::ComputationDataHandle value = ctx->Input(2);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = b->Reshape(value, slice_shape.dim_sizes());
+
+ xla::ComputationDataHandle written =
+ DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp);
+
+class TensorArrayReadOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(ta_type),
+ " but Op requested dtype ", DataTypeString(dtype_), "."));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ xla::ComputationDataHandle index = ctx->Input(1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ xla::ComputationDataHandle read =
+ b->DynamicSlice(ta, start_indices, slice_shape);
+
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+ ctx->SetOutput(0, b->Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp);
+
+class TensorArrayGatherOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+ errors::InvalidArgument("indices must be rank 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ auto indices = ctx->Input(1);
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
+ // For each index in `indices`, add the corresponding slice to `slices`.
+ std::vector<xla::ComputationDataHandle> slices(num_indices);
+ for (int i = 0; i < num_indices; ++i) {
+ // Slices the i-th index out of `indices`, and pads it with zeros in the
+ // minor dimensions to form an index into the TensorArray storage.
+ auto index = b->Slice(indices, {i}, {i + 1});
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ slices[i] = b->DynamicSlice(ta, start_indices, slice_shape);
+ }
+
+ xla::ComputationDataHandle gather;
+ if (slices.empty()) {
+ auto shape = ta_shape.dim_sizes();
+ shape[0] = 0;
+ gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape);
+ } else {
+ gather = b->ConcatInDim(slices, 0);
+ }
+ ctx->SetOutput(0, gather);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp);
+
+class TensorArrayScatterOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ const TensorShape value_shape = ctx->InputShape(2);
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ TensorShape elem_shape = value_shape;
+ elem_shape.RemoveDim(0);
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+ errors::InvalidArgument("indices must be rank 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ const xla::ComputationDataHandle indices = ctx->Input(1);
+
+ xla::ComputationDataHandle ta = var->value;
+ const xla::ComputationDataHandle value = ctx->Input(2);
+
+ auto slice_dims = value_shape.dim_sizes();
+ slice_dims[0] = 1LL;
+
+ std::vector<int64> value_starts(value_shape.dims(), 0);
+ auto value_ends = value_shape.dim_sizes();
+
+ // For every (index, value) pair, update the corresponding TensorArray
+ // storage.
+ for (int i = 0; i < num_indices; ++i) {
+ // Slice out part of the value.
+ value_starts[0] = i;
+ value_ends[0] = i + 1;
+ auto slice = b->Slice(value, value_starts, value_ends);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto index = b->Slice(indices, {i}, {i + 1});
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ }
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp);
+
+class TensorArrayConcatOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
+ auto ta_dims = ta_shape.dim_sizes();
+ std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
+ shape[0] *= ta_shape.dim_size(0);
+ ctx->SetOutput(0, b->Reshape(ta, shape));
+
+ Tensor lengths(DT_INT64, {ta_dims[0]});
+ auto lengths_vec = lengths.vec<int64>();
+ for (int i = 0; i < ta_dims[0]; ++i) {
+ lengths_vec(i) = ta_dims[1];
+ }
+ ctx->SetConstantOutput(1, lengths);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp);
+
+class TensorArraySplitOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> lengths;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
+
+ int64 length = 0;
+ if (!lengths.empty()) {
+ length = lengths[0];
+ for (int i = 1; i < lengths.size(); ++i) {
+ OP_REQUIRES(ctx, lengths[i] == length,
+ errors::InvalidArgument("lengths must be equal: ", length,
+ " vs. ", lengths[i]));
+ }
+ }
+
+ TensorShape value_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, value_shape.dims() >= 1,
+ errors::InvalidArgument("value must have rank >= 1, got ",
+ value_shape.DebugString()));
+ TensorShape elem_shape = value_shape;
+ elem_shape.set_dim(0, length);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ xla::ComputationDataHandle ta = var->value;
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+ errors::InvalidArgument(
+ "TensorArray's size is not equal to the size of lengths (",
+ lengths.size(), " vs. ", var->tensor_array_size, ")"));
+
+ const xla::ComputationDataHandle value = ctx->Input(1);
+
+ OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
+ errors::InvalidArgument("mismatched element count ",
+ value_shape.DebugString(), " vs. ",
+ ta_shape.DebugString()));
+
+ ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+
+class TensorArraySizeOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ Tensor size_tensor(DT_INT32, {});
+ size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
+ ctx->SetConstantOutput(0, size_tensor);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp);
+
+class TensorArrayGradOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ // Finds or looks up the corresponding gradient TensorArray, which stores
+ // gradients computed during backpropagation.
+ XlaVariable*& gradient = var->tensor_array_gradient[source_];
+ if (!gradient) {
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
+ xla::ComputationDataHandle value =
+ b->Broadcast(zero, ta_shape.dim_sizes());
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ string name = strings::StrCat("TensorArrayGrad: ", var->name);
+ OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
+ value, &gradient));
+ gradient->tensor_array_size = var->tensor_array_size;
+ }
+
+ ctx->SetVariableOutput(0, gradient);
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ string source_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 362a101895..1d0098591e 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -119,6 +119,4 @@ void XlaExpression::set_constant_value(Tensor value) {
constant_value_ = std::move(value);
}
-void XlaExpression::set_variable_id(int id) { variable_id_ = id; }
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 1ee96e5e6c..75630bee39 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -64,6 +64,39 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};
+struct XlaVariable {
+ // If this variable is visible externally, what was its argument number?
+ int arg_num = -1;
+
+ // A descriptive name for the variable, used in error messages.
+ string name;
+
+ // Current type and value of the variable. Uninitialized variables are
+ // represented by a default (zero) handle and type DT_INVALID.
+ // While the type of a variable is notionally fixed during execution, when
+ // a variable is first initialized we do not yet know its type, so we keep
+ // track of its type dynamically.
+ DataType type = DT_INVALID;
+ xla::ComputationDataHandle value;
+
+ // Value of the variable at computation entry. Used to detect which
+ // variables have new values that need to be written back.
+ xla::ComputationDataHandle initial_value;
+
+ // We treat TensorArrays as a Variable with some extra metadata.
+
+ // 'tensor_array_size' stores the expected size of the TensorArray. We need
+ // to store this since sometimes TensorArrays must be initialized lazily since
+ // we do not know the element shape at construction time.
+ int64 tensor_array_size = -1;
+
+ // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
+ // to an XlaVariable containing the gradient TensorArrays. We store a pointer
+ // here since there should only be one gradient TensorArray per 'source'
+ // string, irrespective of the number of calls to TensorArrayGrad.
+ std::unordered_map<string, XlaVariable*> tensor_array_gradient;
+};
+
// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
// matches the shape of the subcomputation in the ComputationDataHandle. Each
@@ -82,8 +115,8 @@ class XlaExpression {
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
- void set_variable_id(int id);
- int variable_id() const { return variable_id_; }
+ void set_variable(XlaVariable* variable) { variable_ = variable; }
+ XlaVariable* variable() const { return variable_; }
private:
// The XLA handle of the expression's computation.
@@ -95,7 +128,7 @@ class XlaExpression {
bool has_constant_value_ = false;
Tensor constant_value_;
- int variable_id_ = -1;
+ XlaVariable* variable_ = nullptr; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 820e8dd56f..580ce3d802 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -59,8 +59,9 @@ Status CheckSignature(const DataTypeVector& types,
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, type, shape, name) !=
- std::tie(other.kind, other.type, other.shape, other.name)) {
+ if (std::tie(kind, type, shape, name, tensor_array_size) !=
+ std::tie(other.kind, other.type, other.shape, other.name,
+ other.tensor_array_size)) {
return false;
}
if (constant_value.shape() != other.constant_value.shape()) {
@@ -264,8 +265,9 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
switch (args[i].kind) {
case XlaCompiler::Argument::kVariable:
variables.push_back(i);
- context_arg.value.is_constant = false;
context_arg.is_variable = true;
+ context_arg.value.is_constant = false;
+ context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kParameter:
parameters.push_back(i);
@@ -274,6 +276,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
case XlaCompiler::Argument::kUninitializedVariable:
context_arg.is_variable = true;
context_arg.value.is_constant = true;
+ context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kConstant:
context_arg.value.is_constant = true;
@@ -337,7 +340,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// type of the final output.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
- const std::unordered_map<int, XlaContext::Variable>& variable_map,
+ const std::vector<std::unique_ptr<XlaVariable>>& variables,
bool has_side_effects, bool return_updated_values_for_all_variables,
xla::ComputationBuilder* builder, xla::Computation* computation,
int* num_nonconst_outputs,
@@ -352,27 +355,27 @@ Status BuildComputation(
*num_nonconst_outputs = elems.size();
// Add return values for variables whose values have changed.
- std::vector<std::pair<int, const XlaContext::Variable*>> variables;
- variables.reserve(variable_map.size());
- for (const auto& entry : variable_map) {
- variables.emplace_back(entry.first, &entry.second);
+ std::vector<const XlaVariable*> arg_vars;
+ arg_vars.reserve(variables.size());
+ for (const auto& var : variables) {
+ if (var->arg_num >= 0) {
+ arg_vars.push_back(var.get());
+ }
}
- std::sort(variables.begin(), variables.end(),
- [](const std::pair<int, const XlaContext::Variable*>& a,
- const std::pair<int, const XlaContext::Variable*>& b) {
- return a.first < b.first;
+ std::sort(arg_vars.begin(), arg_vars.end(),
+ [](const XlaVariable* a, const XlaVariable* b) {
+ return a->arg_num < b->arg_num;
});
- for (const auto& entry : variables) {
- bool modified =
- entry.second->value.handle() != entry.second->initial_value.handle();
+ for (const XlaVariable* var : arg_vars) {
+ bool modified = var->value.handle() != var->initial_value.handle();
if (return_updated_values_for_all_variables || modified) {
variable_updates->emplace_back();
XlaCompiler::VariableUpdate& update = variable_updates->back();
- update.input_index = entry.first;
- update.type = entry.second->type;
+ update.input_index = var->arg_num;
+ update.type = var->type;
update.modified = modified;
- elems.push_back(entry.second->value);
+ elems.push_back(var->value);
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 15f723ad78..1314305532 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -114,6 +114,10 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
+ // For a kVariable or kUninitializedVariable corresponding to a TensorArray,
+ // what is the tensor array's declared size?
+ int64 tensor_array_size = -1;
+
bool operator==(const Argument& other) const;
};
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 3592680303..4440b53069 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -53,6 +54,10 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
return *context;
}
+/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
+ return Get(ctx->op_kernel_context());
+}
+
void XlaContext::set_args(std::vector<Argument> args) {
args_ = std::move(args);
}
@@ -124,29 +129,19 @@ void XlaContext::AddSideEffects() {
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
-Status XlaContext::CreateVariable(int variable_id, string name, DataType type,
- const xla::ComputationDataHandle& handle) {
- auto result = variables_.emplace(variable_id, Variable());
- if (!result.second) {
- return errors::InvalidArgument("Duplicate ID ", variable_id,
- " for variable ", name);
- }
- Variable& var = result.first->second;
+Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
+ const xla::ComputationDataHandle& handle,
+ XlaVariable** variable) {
+ variables_.emplace_back(new XlaVariable);
+ *variable = variables_.back().get();
+ XlaVariable& var = **variable;
+ var.arg_num = arg_num;
var.name = std::move(name);
var.type = type;
var.initial_value = var.value = handle;
return Status::OK();
}
-Status XlaContext::GetVariable(int variable_id, Variable** variable) {
- auto it = variables_.find(variable_id);
- if (it == variables_.end()) {
- return errors::InvalidArgument("Unknown variable ID ", variable_id);
- }
- *variable = &it->second;
- return Status::OK();
-}
-
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
return LookupOrCreate(type, &max_func_, [this, type] {
const string type_string = DataTypeString(type);
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 657ead5391..3978baaf63 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -31,6 +30,8 @@ limitations under the License.
namespace tensorflow {
+class XlaOpKernelContext;
+
// The XlaContext is the data structure that holds the state of an XLA
// compilation, that is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
@@ -55,16 +56,16 @@ class XlaContext : public ResourceBase {
string name;
// Is this a variable?
- bool is_variable;
+ bool is_variable = false;
HandleOrConstant value;
+
+ int64 tensor_array_size = -1;
};
// Retrieves the XlaContext of the current compilation.
static XlaContext& Get(const OpKernelContext* ctx);
- static XlaContext& Get(const XlaOpKernelContext* ctx) {
- return Get(ctx->op_kernel_context());
- }
+ static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
@@ -105,33 +106,16 @@ class XlaContext : public ResourceBase {
bool has_side_effects() const { return has_side_effects_; }
- struct Variable {
- // A descriptive name for the variable, used in error messages.
- string name;
-
- // Current type and value of the variable. Uninitialized variables are
- // represented by a default (zero) handle and type DT_INVALID.
- // While the type of a variable is notionally fixed during execution, when
- // a variable is first initialized we do not yet know its type, so we keep
- // track of its type dynamically.
- DataType type = DT_INVALID;
- xla::ComputationDataHandle value;
-
- // Value of the variable at computation entry. Used to detect which
- // variables have new values that need to be written back.
- xla::ComputationDataHandle initial_value;
- };
-
// Creates a variable with variable `variable_id` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Fails if the variable already exists.
- Status CreateVariable(int variable_id, string name, DataType type,
- const xla::ComputationDataHandle& handle);
+ Status CreateVariable(int arg_num, string name, DataType type,
+ const xla::ComputationDataHandle& handle,
+ XlaVariable** variable);
- // Retrieves variable `variable_id`. Fails if the variable does not exist.
- Status GetVariable(int variable_id, Variable** variable);
-
- const std::unordered_map<int, Variable>& variables() { return variables_; }
+ const std::vector<std::unique_ptr<XlaVariable>>& variables() {
+ return variables_;
+ }
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -182,8 +166,8 @@ class XlaContext : public ResourceBase {
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;
- // Map from variable ID to the current value of each variable.
- std::unordered_map<int, Variable> variables_;
+ // Holds ownership of variables. The variables are not ordered.
+ std::vector<std::unique_ptr<XlaVariable>> variables_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 4de69ee43c..3272b1efa1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -38,7 +38,8 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
- CHECK(expression->handle().handle() != 0 || expression->variable_id() >= 0);
+ CHECK(expression->handle().handle() != 0 ||
+ expression->variable() != nullptr);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
@@ -251,11 +252,8 @@ Status XlaOpKernelContext::ReadVariableInput(
int index, xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -267,11 +265,8 @@ Status XlaOpKernelContext::ReadVariableInput(
string XlaOpKernelContext::VariableDebugString(int index) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- if (!context.GetVariable(variable_id, &variable).ok()) {
+ XlaVariable* variable = expression->variable();
+ if (!variable) {
return "<invalid variable ID>";
}
return variable->name;
@@ -281,11 +276,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -345,14 +337,22 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
expression->set_constant_value(constant);
}
-void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
+void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the variable resource
// (i.e., a scalar), not the shape of the variable's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
- expression->set_variable_id(variable_id);
+ expression->set_variable(variable);
+}
+
+Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
+ const XlaExpression* expression =
+ CastExpressionFromTensor(context_->input(index));
+ TF_RET_CHECK(expression->variable() != nullptr);
+ *variable = expression->variable();
+ return Status::OK();
}
Status XlaOpKernelContext::AssignVariable(
@@ -362,9 +362,8 @@ Status XlaOpKernelContext::AssignVariable(
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(index));
- XlaContext& context = XlaContext::Get(this);
- XlaContext::Variable* variable;
- TF_RETURN_IF_ERROR(context.GetVariable(expression->variable_id(), &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
(variable->type == type))) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 0a8a928418..a25774c3a6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -157,15 +157,18 @@ class XlaOpKernelContext {
// 'index'.
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
- // Sets output 'index' to be a reference to variable 'variable_id'. Used
- // to propagate resource variables through the compilation.
- void SetVariableOutput(int index, int variable_id);
-
// Assigns the value `handle` to the variable referenced by input
// `variable_index`. Marks the operator as having side effects.
Status AssignVariable(int variable_index, DataType type,
const xla::ComputationDataHandle& handle);
+ // Sets '*variable' to the variable associated with input `index`.
+ Status GetVariableInput(int index, XlaVariable** variable);
+
+ // Sets output 'index' to be a reference to variable 'variable'. Used
+ // to propagate resource variables through the compilation.
+ void SetVariableOutput(int index, XlaVariable* variable);
+
// Returns a human-readable debug string describing 'variable_index'.
string VariableDebugString(int variable_index);
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 8f6a70ffff..64e58e32fb 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -1740,6 +1740,7 @@ Status Literal::Populate(
stride_config.dimensions, stride_config.step,
init_function);
} else {
+ // For scalars.
data.at(0) = generator({});
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index aaab36dc8c..50ea286b53 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -806,7 +806,9 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> layout;
} populate_data[] = {
{{}, {}},
+ {{0}, {0}},
{{16}, {0}},
+ {{2, 0}, {1, 0}},
{{4, 16}, {1, 0}},
{{21, 12}, {0, 1}},
{{6, 11, 17}, {2, 0, 1}},
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 590c6a2491..cc147f5062 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -104,10 +104,12 @@ cc_test(
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
@@ -866,9 +868,10 @@ cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
- "//tensorflow/core:test_main",
+ "//tensorflow/core:test",
],
)
@@ -1489,7 +1492,9 @@ cc_library(
hdrs = ["hlo_constant_folding.h"],
deps = [
":hlo",
+ ":hlo_evaluator",
":hlo_pass",
+ ":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index e0915e3526..19583433db 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
namespace op = xla::testing::opcode_matchers;
@@ -59,7 +61,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@@ -82,7 +84,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@@ -105,7 +107,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
@@ -127,7 +129,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
@@ -149,7 +151,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
@@ -171,7 +173,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, div);
@@ -199,7 +201,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, add);
@@ -225,7 +227,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -250,7 +252,7 @@ TEST_F(AlgebraicSimplifierTest, LnExp) {
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
@@ -279,7 +281,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
builder.AddInstruction(
HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -304,7 +306,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
@@ -329,7 +331,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
@@ -359,7 +361,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
@@ -382,7 +384,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
@@ -405,7 +407,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
@@ -434,7 +436,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
auto computation = builder.Build();
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
EXPECT_THAT(module->entry_computation()->root_instruction(),
@@ -455,7 +457,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
@@ -476,7 +478,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
@@ -497,7 +499,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
builder.AddInstruction(
HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
@@ -527,7 +529,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(
@@ -558,7 +560,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -581,7 +583,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
// Set to different layouts.
@@ -608,7 +610,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
HloInstruction* copy = builder.AddInstruction(
HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
// Set to same layouts.
@@ -640,7 +642,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
*reshape->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
@@ -686,7 +688,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
builder.AddInstruction(HloInstruction::CreateTuple(
{transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -716,7 +718,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
HloOpcode::kMaximum, movable_reshape, zero));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -744,7 +746,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({0, 1, 2, 3});
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
@@ -771,7 +773,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
*transpose->mutable_shape()->mutable_layout() =
LayoutUtil::MakeLayout({3, 1, 2, 0});
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
@@ -797,7 +799,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -825,7 +827,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
HloOpcode::kCopy, copy1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
@@ -850,7 +852,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
@@ -874,7 +876,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -897,7 +899,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -919,7 +921,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -942,7 +944,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -966,7 +968,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -992,7 +994,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
@@ -1697,7 +1699,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEmbeddedComputation(std::move(dot_computation));
module->AddEntryComputation(call_builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -1707,3 +1709,20 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
} // namespace
} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index cb0a99d773..762ceebf39 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -24,230 +24,57 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
-namespace {
-
-template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-static std::unique_ptr<Literal> ConvertIfTypesMatch(
- const Literal& src_literal) {
- CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
- return LiteralUtil::Convert<
- typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
-}
-
-template <PrimitiveType primitive_src_type>
-static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type) {
- switch (primitive_dest_type) {
-#define CONVERT_IF_TYPES_MATCH(type) \
- case (type): \
- return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
- CONVERT_IF_TYPES_MATCH(PRED)
- CONVERT_IF_TYPES_MATCH(S8)
- CONVERT_IF_TYPES_MATCH(S32)
- CONVERT_IF_TYPES_MATCH(S64)
- CONVERT_IF_TYPES_MATCH(U8)
- CONVERT_IF_TYPES_MATCH(U32)
- CONVERT_IF_TYPES_MATCH(U64)
- CONVERT_IF_TYPES_MATCH(F32)
- CONVERT_IF_TYPES_MATCH(F64)
-#undef CONVERT_IF_TYPES_MATCH
- // Other types are not yet supported.
- default:
- LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
- << PrimitiveType_Name(src_literal.shape().element_type());
- }
-}
-
-static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type) {
- switch (src_literal.shape().element_type()) {
-#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
- case (type): \
- return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
- CONVERT_IF_DEST_TYPE_MATCHES(PRED)
- CONVERT_IF_DEST_TYPE_MATCHES(S8)
- CONVERT_IF_DEST_TYPE_MATCHES(S32)
- CONVERT_IF_DEST_TYPE_MATCHES(S64)
- CONVERT_IF_DEST_TYPE_MATCHES(U8)
- CONVERT_IF_DEST_TYPE_MATCHES(U32)
- CONVERT_IF_DEST_TYPE_MATCHES(U64)
- CONVERT_IF_DEST_TYPE_MATCHES(F32)
- CONVERT_IF_DEST_TYPE_MATCHES(F64)
-#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
- default:
- LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
- << PrimitiveType_Name(src_literal.shape().element_type());
- }
-}
-
-} // namespace
-
-// ConstantFolderVisitor traverses the HLO computation and reduces certain
-// constant graph sections, to literals.
-class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
- public:
- // Default visitor action is to do nothing and return OK.
- Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
- return Status::OK();
- }
-
- Status HandleConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
-
- Status HandleConvert(HloInstruction* convert,
- HloInstruction* operand) override;
-
- Status HandleReshape(HloInstruction* reshape) override;
-
- Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
-
- Status HandleTranspose(HloInstruction* transpose) override;
-
- // Returns whether a constant folding operation has occurred.
- const bool changed() const { return changed_; }
-
- // Runs the visitor on a computation and returns whether any changes were
- // performed.
- static StatusOr<bool> Run(HloComputation* computation);
-
- private:
- ConstantFolderVisitor() = default;
-
- // Replaces the existing HLO instruction old_instruction, with a literal,
- // and marks the optimizer status as changed.
- // Returns the Status representing the result of the replace operation.
- Status ReplaceWithConstant(HloInstruction* old_instruction,
- std::unique_ptr<Literal> literal) {
- TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
- old_instruction, HloInstruction::CreateConstant(std::move(literal))));
- changed_ = true;
- return Status::OK();
- }
-
- // Whether any constant folding operations have occurred.
- bool changed_ = false;
-};
-
-StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
- ConstantFolderVisitor visitor;
- TF_RETURN_IF_ERROR(computation->Accept(&visitor));
- return visitor.changed();
-}
StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
+ auto evaluator = MakeUnique<HloEvaluator>();
+
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
bool changed = false;
- for (auto& comp : module->computations()) {
- TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
- changed = changed || result;
- }
- XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
- return changed;
-}
-
-Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
- if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
- TF_ASSIGN_OR_RETURN(
- auto reshaped_literal,
- LiteralUtil::Reshape(reshape->operand(0)->literal(),
- AsInt64Slice(reshape->shape().dimensions())));
- return ReplaceWithConstant(reshape, std::move(reshaped_literal));
- }
- return Status::OK();
-}
-Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
- if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
- auto transposed_literal = LiteralUtil::Transpose(
- transpose->operand(0)->literal(), transpose->dimensions());
- return ReplaceWithConstant(transpose, std::move(transposed_literal));
- }
- return Status::OK();
-}
+ for (auto& computation : module->computations()) {
+ for (auto instruction : computation->MakeInstructionPostOrder()) {
+ // Skip dead code.
+ if (instruction->user_count() == 0 &&
+ computation->root_instruction() != instruction) {
+ continue;
+ }
+ // Skip Constant and Parameter operation.
+ if (instruction->opcode() == HloOpcode::kParameter ||
+ instruction->opcode() == HloOpcode::kConstant) {
+ continue;
+ }
+ // Skip instructions with non-constant operands.
+ if (!hlo_query::AllOperandsAreConstants(*instruction)) {
+ continue;
+ }
-Status ConstantFolderVisitor::HandleConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- if (operands[0]->opcode() == HloOpcode::kConstant) {
- // If all the operands of a concatenate are constant, fold them into a
- // single constant tensor.
- // The result concatenate dimension is going to be the sum of all the
- // concatenate dimensions of the arrays taking part of the operation.
- int64 concat_dim = concatenate->dimensions()[0];
- const Shape& reference_shape = operands[0]->shape();
- CHECK(!ShapeUtil::IsTuple(reference_shape));
- int64 rank = ShapeUtil::Rank(reference_shape);
- std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
- reference_shape.dimensions().end());
- if (concat_dim < 0) {
- concat_dim += rank;
- }
- for (int64 i = 1; i < operands.size(); ++i) {
- const Shape& operand_shape = operands[i]->shape();
- CHECK(!ShapeUtil::IsTuple(operand_shape));
- if (operands[i]->opcode() != HloOpcode::kConstant) {
- return Status::OK();
+ std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ // Currently we skip unimplemented operations.
+ // TODO(b/35975797): Fold constant computations for more operations.
+ if (result == nullptr) {
+ VLOG(2) << "Constant folding failed for instruction: "
+ << instruction->ToString();
+ continue;
}
- // Accumulate the concat dimension from all tensors taking part to the
- // operation.
- concat_dimensions[concat_dim] +=
- ShapeUtil::GetDimension(operand_shape, concat_dim);
- }
- auto literal = LiteralUtil::CreateFromDimensions(
- reference_shape.element_type(), concat_dimensions);
- std::vector<int64> source_indices(rank, 0);
- std::vector<int64> dest_indices(concat_dimensions.size(), 0);
- for (auto operand : operands) {
- const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(LiteralUtil::Copy(
- operand->literal(), source_indices, literal.get(), dest_indices,
- AsInt64Slice(operand_shape.dimensions())));
- dest_indices[concat_dim] +=
- ShapeUtil::GetDimension(operand_shape, concat_dim);
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ instruction, HloInstruction::CreateConstant(std::move(result))));
+ changed = true;
}
- return ReplaceWithConstant(concatenate, std::move(literal));
- }
- return Status::OK();
-}
-
-Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
- HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kConstant) {
- const Shape& shape = slice->shape();
- auto literal = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
- TF_RETURN_IF_ERROR(LiteralUtil::Copy(
- operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
- AsInt64Slice(shape.dimensions())));
- TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
}
- return Status::OK();
-}
-
-Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
- HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kConstant) {
- const Literal& src_literal = operand->literal();
- std::unique_ptr<Literal> new_constant =
- ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
- return ReplaceWithConstant(convert, std::move(new_constant));
- }
- return Status::OK();
+ XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
+ return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e0447d69aa..3e7f5b1f3d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -46,6 +46,89 @@ limitations under the License.
namespace xla {
+namespace {
+
+template <typename OperandT>
+StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
+ const Literal& lhs_literal,
+ const Literal& rhs_literal) {
+ std::function<bool(OperandT, OperandT)> compare_op;
+ switch (opcode) {
+ case HloOpcode::kEq:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el == rhs_el;
+ };
+ break;
+ case HloOpcode::kNe:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el != rhs_el;
+ };
+ break;
+ case HloOpcode::kGe:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el >= rhs_el;
+ };
+ break;
+ case HloOpcode::kGt:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el > rhs_el;
+ };
+ break;
+ case HloOpcode::kLe:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el <= rhs_el;
+ };
+ break;
+ case HloOpcode::kLt:
+ compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+ return lhs_el < rhs_el;
+ };
+ break;
+ default:
+ LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
+ << HloOpcodeString(opcode);
+ }
+
+ auto result = LiteralUtil::CreateFromShape(shape);
+ TF_RETURN_IF_ERROR(LiteralUtil::Populate<bool>(
+ result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return compare_op(LiteralUtil::Get<OperandT>(lhs_literal, multi_index),
+ LiteralUtil::Get<OperandT>(rhs_literal, multi_index));
+ }));
+
+ return std::move(result);
+}
+
+template <typename ReturnT, typename NativeT>
+StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ HloInstruction* instruction,
+ const std::function<ReturnT(NativeT)>& unary_op,
+ const Literal& operand_literal) {
+ const auto shape = instruction->shape();
+ const auto* operand = instruction->operand(0);
+
+ // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+ // removed.
+ if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
+ return Unimplemented(
+ "Implicit broadcasting is currently unsupported in HLO evaluator "
+ "Shape Mismatch: %s vs %s",
+ ShapeUtil::HumanString(shape).c_str(),
+ ShapeUtil::HumanString(operand->shape()).c_str());
+ }
+
+ auto result = LiteralUtil::CreateFromShape(shape);
+
+ TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+ result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return unary_op(
+ LiteralUtil::Get<NativeT>(operand_literal, multi_index));
+ }));
+ return std::move(result);
+}
+
+} // namespace
+
template <typename ReturnT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
public:
@@ -68,7 +151,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return elem_operand;
}));
return Status::OK();
- };
+ }
template <
typename NativeT,
@@ -79,7 +162,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::abs(elem_operand);
}));
return Status::OK();
- };
+ }
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override {
return HandleAbs<ReturnT>(abs, operand);
@@ -101,6 +184,45 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
+ template <PrimitiveType src_type, PrimitiveType dest_type>
+ std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
+ DCHECK_EQ(src_type, src_literal.shape().element_type());
+ return LiteralUtil::Convert<
+ typename primitive_util::PrimitiveTypeToNative<src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(
+ src_literal);
+ }
+
+ Status HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) override {
+ auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+
+ switch (operand->shape().element_type()) {
+#define CONVERT_IF_TYPES_MATCH(src_type) \
+ case (src_type): \
+ parent_->evaluated_[convert] = LiteralUtil::Convert< \
+ typename primitive_util::PrimitiveTypeToNative<src_type>::type, \
+ ReturnT>(operand_literal); \
+ break;
+ CONVERT_IF_TYPES_MATCH(PRED)
+ CONVERT_IF_TYPES_MATCH(S8)
+ CONVERT_IF_TYPES_MATCH(S32)
+ CONVERT_IF_TYPES_MATCH(S64)
+ CONVERT_IF_TYPES_MATCH(U8)
+ CONVERT_IF_TYPES_MATCH(U32)
+ CONVERT_IF_TYPES_MATCH(U64)
+ CONVERT_IF_TYPES_MATCH(F32)
+ CONVERT_IF_TYPES_MATCH(F64)
+#undef CONVERT_IF_TYPES_MATCH
+ // Other types are not yet supported.
+ default:
+ LOG(FATAL) << "unimplemented operand type for HandleCovert: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+ }
+
Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
@@ -117,15 +239,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
- Status HandleIsFinite(HloInstruction* is_finite,
- HloInstruction* operand) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite],
- ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) {
- return std::isfinite(elem_operand);
- }));
- return Status::OK();
- };
-
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
@@ -209,77 +322,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
- Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
- HloInstruction* lhs, HloInstruction* rhs) override {
- std::function<bool(ReturnT, ReturnT)> compare_op;
- switch (opcode) {
- case HloOpcode::kEq:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el == rhs_el;
- };
- break;
- case HloOpcode::kNe:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el != rhs_el;
- };
- break;
- case HloOpcode::kGe:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el >= rhs_el;
- };
- break;
- case HloOpcode::kGt:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el > rhs_el;
- };
- break;
- case HloOpcode::kLe:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el <= rhs_el;
- };
- break;
- case HloOpcode::kLt:
- compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
- return lhs_el < rhs_el;
- };
- break;
- default:
- LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
- << HloOpcodeString(opcode);
- }
-
- // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
- // removed.
- if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
- ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
- return Unimplemented(
- "Compare operation with mismatched dimensions, likely due to "
- "broadcasting is unsupported.");
- }
-
- const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
- const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
-
- auto result = LiteralUtil::CreateFromShape(compare->shape());
- std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
- do {
- LiteralUtil::Set<bool>(
- result.get(), multi_index,
- compare_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
- LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
- } while (IndexUtil::BumpIndices(result->shape(), &multi_index));
-
- parent_->evaluated_[compare] = std::move(result);
-
- return Status::OK();
- };
-
Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
HloInstruction* rhs) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[maximum],
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
- return std::max(lhs, rhs);
+ return std::fmax(lhs, rhs);
}));
return Status::OK();
};
@@ -289,7 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[minimum],
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
- return std::min(lhs_el, rhs_el);
+ return std::fmin(lhs_el, rhs_el);
}));
return Status::OK();
};
@@ -309,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
- return std::remainder(lhs_el, rhs_el);
+ return std::fmod(lhs_el, rhs_el);
}));
return Status::OK();
};
@@ -338,7 +386,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
HloInstruction* arg, HloInstruction* max) override {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
[](ReturnT low, ReturnT high, ReturnT value) {
- return std::max(low, std::min(value, high));
+ return std::fmax(low, std::fmin(value, high));
};
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
@@ -370,32 +418,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ReturnT(ReturnT)>& unary_op) {
- const auto shape = instruction->shape();
- const auto* operand = instruction->operand(0);
-
- // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
- // removed.
- if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
- return Unimplemented(
- "Implicit broadcasting is currently unsupported in HLO evaluator "
- "Shape Mismatch: %s vs %s",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(operand->shape()).c_str());
- }
-
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
-
- auto result = LiteralUtil::CreateFromShape(shape);
-
- std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
- do {
- LiteralUtil::Set<ReturnT>(
- result.get(), multi_index,
- unary_op(LiteralUtil::Get<ReturnT>(operand_literal, multi_index)));
- } while (IndexUtil::BumpIndices(result->shape(), &multi_index));
-
- return std::move(result);
- };
+ const Literal& operand_literal =
+ parent_->GetEvaluatedLiteralFor(instruction->operand(0));
+ return ElementWiseUnaryOpImpl<ReturnT, ReturnT>(instruction, unary_op,
+ operand_literal);
+ }
StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
HloInstruction* instruction,
@@ -420,16 +447,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = LiteralUtil::CreateFromShape(shape);
- std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
- do {
- LiteralUtil::Set<ReturnT>(
- result.get(), multi_index,
- binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
- LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
- } while (IndexUtil::BumpIndices(result->shape(), &multi_index));
+ TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+ result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
+ LiteralUtil::Get<ReturnT>(rhs_literal, multi_index));
+ }));
return std::move(result);
- };
+ }
template <typename LhsType, typename RhsType, typename EhsType>
StatusOr<std::unique_ptr<Literal>> ElementWiseTernaryOp(
@@ -459,17 +484,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
auto result = LiteralUtil::CreateFromShape(shape);
- std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
- do {
- LiteralUtil::Set<ReturnT>(
- result.get(), multi_index,
- ternary_op(LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
- LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
- LiteralUtil::Get<EhsType>(ehs_literal, multi_index)));
- } while (IndexUtil::BumpIndices(result->shape(), &multi_index));
+
+ TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+ result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return ternary_op(
+ LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
+ LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
+ LiteralUtil::Get<EhsType>(ehs_literal, multi_index));
+ }));
return std::move(result);
- };
+ }
HloEvaluator* parent_;
};
@@ -493,6 +518,12 @@ HloEvaluator::HloEvaluator() {
});
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
+ typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented("unhandled primitive type: TUPLE.");
+ });
+ typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented("unhandled primitive type: OPAQUE.");
+ });
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -502,15 +533,15 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
evaluated_.clear();
TF_RETURN_IF_ERROR(computation->Accept(this));
- return std::move(FindOrDie(evaluated_, computation->root_instruction()));
+ return MakeUnique<Literal>(
+ GetEvaluatedLiteralFor(computation->root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> operands) {
- DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
- Shape shape = instruction->shape();
- TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
+ TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_ = operands;
evaluated_.clear();
@@ -525,13 +556,34 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
evaluated_[operand] = MakeUnique<Literal>(*input_literal);
- } else if (operand->opcode() == HloOpcode::kConstant) {
- evaluated_[operand] = MakeUnique<Literal>(operand->literal());
}
}
TF_RETURN_IF_ERROR(instruction->Visit(this));
- return std::move(FindOrDie(evaluated_, instruction));
+ return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+ HloInstruction* instruction) {
+ TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction));
+ TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter);
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
+
+ arg_literals_.clear();
+ evaluated_.clear();
+ TF_RETURN_IF_ERROR(instruction->Visit(this));
+ return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+}
+
+std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
+ HloInstruction* instruction) {
+ auto result_or = Evaluate(instruction);
+ if (!result_or.ok()) {
+ VLOG(1) << "TryEvaluate failed:" << result_or.status();
+ return nullptr;
+ }
+
+ return result_or.ConsumeValueOrDie();
}
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
@@ -548,9 +600,191 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
Status HloEvaluator::HandleConstant(HloInstruction* constant,
const Literal& literal) {
VLOG(2) << "HandleConstant: " << constant->ToString();
- DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape()));
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[reshape],
+ LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)),
+ AsInt64Slice(reshape->shape().dimensions())));
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
+ evaluated_[transpose] = LiteralUtil::Transpose(
+ GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions());
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ // The result concatenate dimension is going to be the sum of all concatenate
+ // dimensions of the operands taking part of the operation.
+ const Shape& reference_shape = operands[0]->shape();
+ CHECK(!ShapeUtil::IsTuple(reference_shape));
+ const int64 rank = ShapeUtil::Rank(reference_shape);
+ const int64 concat_dim = concatenate->dimensions()[0];
+ CHECK_GE(concat_dim, 0);
+ CHECK_LT(concat_dim, rank);
+
+ DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
+ reference_shape.dimensions().end());
+
+ for (int64 i = 1; i < operands.size(); ++i) {
+ const Shape& operand_shape = operands[i]->shape();
+ CHECK(!ShapeUtil::IsTuple(operand_shape));
+ // Accumulate the concat dimension from all tensors taking part to the
+ // operation.
+ concat_dimensions[concat_dim] +=
+ ShapeUtil::GetDimension(operand_shape, concat_dim);
+ }
+
+ auto result_literal = LiteralUtil::CreateFromDimensions(
+ reference_shape.element_type(), concat_dimensions);
+ DimensionVector source_indices(rank, 0);
+ DimensionVector dest_indices(concat_dimensions.size(), 0);
+
+ for (auto operand : operands) {
+ const Shape& operand_shape = operand->shape();
+ TF_RETURN_IF_ERROR(LiteralUtil::Copy(
+ GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(),
+ dest_indices, AsInt64Slice(operand_shape.dimensions())));
+ dest_indices[concat_dim] +=
+ ShapeUtil::GetDimension(operand_shape, concat_dim);
+ }
+
+ evaluated_[concatenate] = std::move(result_literal);
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
+ HloInstruction* operand) {
+ if (!ShapeUtil::ElementIsFloating(operand->shape())) {
+ return InvalidArgument(
+ "expected element type in shape to be float for IsFinite op, got: %s",
+ PrimitiveType_Name(operand->shape().element_type()).c_str());
+ }
+
+ switch (operand->shape().element_type()) {
+ case F16:
+ return Unimplemented("unhandled primitive type: F16.");
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<bool, float>(
+ is_finite,
+ [](float elem_operand) { return std::isfinite(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<bool, double>(
+ is_finite,
+ [](double elem_operand) { return std::isfinite(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "unknown/unhandled primitive type.";
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+ // removed.
+ if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
+ ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
+ return Unimplemented(
+ "Implicit broadcasting is currently unsupported in HLO evaluator "
+ "Shape Mismatch: %s vs %s vs %s",
+ ShapeUtil::HumanString(compare->shape()).c_str(),
+ ShapeUtil::HumanString(lhs->shape()).c_str(),
+ ShapeUtil::HumanString(rhs->shape()).c_str());
+ }
+
+ TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
+
+ const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
+ const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
+
+ // Note here we switch on the operand's type.
+ switch (lhs->shape().element_type()) {
+ case PRED: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case U8: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case U16:
+ return Unimplemented("unhandled primitive type: U16.");
+ case U32: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case U64: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case S8: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case S16:
+ return Unimplemented("unhandled primitive type: S16.");
+ case S32: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case S64: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case F16:
+ return Unimplemented("unhandled primitive type: F16.");
+ case F32: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ case F64: {
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[compare],
+ Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
+ } break;
+ default:
+ LOG(FATAL) << "unknown primitive type.";
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleSlice(HloInstruction* slice,
+ HloInstruction* operand) {
+ const Shape& shape = slice->shape();
+ auto literal = LiteralUtil::CreateFromDimensions(
+ shape.element_type(), AsInt64Slice(shape.dimensions()));
+
+ DimensionVector dest_indices(slice->slice_starts().size(), 0);
+
+ TF_RETURN_IF_ERROR(LiteralUtil::Copy(
+ GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(),
+ dest_indices, AsInt64Slice(shape.dimensions())));
- evaluated_[constant] = MakeUnique<Literal>(literal);
+ evaluated_[slice] = std::move(literal);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 040fd3d73c..91fd56f54c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
// Precondition:
- // 1. argument literals are corresponds to the input instruction's
- // parameters in their post-orderring.
+ // 1. argument literals correspond to the input instruction's parameters in
+ // their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
// TODO(b/35950897): implement more ops other than element-wise ops.
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+ // Evaluates a single HLO instruction with constant operands.
+ // Returns the evaluated result as literal if successful.
+ // Precondition:
+ // 1. all operands of the input instruction are constants.
+ // 2. the instruction is not a Parameter operation.
+ StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+
+ // Same as Evaluate, except returning nullptr on error.
+ std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+
protected:
// Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
- // literal type of each evaluated Handle* method of a TypedVisitor. One
- // exception to this is HandleCompare, where the resulting literal type is
+ // literal type of each evaluated Handle* method of a TypedVisitor.
+ // There are however a few notable exceptions to this is rule, notably:
+ // - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
- // Note the forward declaration here is necessary to enable TypedVisitor to
- // access parent members.
+ // These operations are handled outside of the parent HloEvaluator handlers
+ // instead of from within TypedVisitor.
template <typename ReturnT>
class TypedVisitor;
@@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
}
+ // Operations that are type-agnostic.
+ //
Status HandleParameter(HloInstruction* parameter) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+
+ Status HandleReshape(HloInstruction* reshape) override;
+
+ Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+ Status HandleTranspose(HloInstruction* transpose) override;
+
+ Status HandleIsFinite(HloInstruction* is_finite,
+ HloInstruction* operand) override;
+
+ Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs, HloInstruction* rhs) override;
+
private:
// Returns the already-evaluated literal result for the instruction.
+ // A Constant instruction is considered evaluated and its literal will be
+ // returned directly without looking up the cache.
// Crash with log if the given instruction has not been evaluated previously.
const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
+ if (hlo->IsConstant()) {
+ return hlo->literal();
+ }
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 443e5ad4f4..b26ece28b7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -23,8 +23,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -143,7 +145,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) {
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbs) {
auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
+ const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
auto c1 = HloInstruction::CreateConstant(std::move(operand));
auto instruction =
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());
@@ -154,7 +156,29 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
-}
+
+ // For R0 literal.
+ const Shape& r0 = ShapeUtil::MakeShape(F32, {});
+ operand = LiteralUtil::CreateR0<float>(-1.0f);
+ c1 = HloInstruction::CreateConstant(std::move(operand));
+ instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get());
+ result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
+ expected = LiteralUtil::CreateR0<float>(1.0f);
+
+ EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+
+ // For R1 literal with dimension of size 0.
+ Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
+ operand = LiteralUtil::CreateR1<float>({});
+ c1 = HloInstruction::CreateConstant(std::move(operand));
+ instruction =
+ HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());
+
+ result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
+ expected = LiteralUtil::CreateR1<float>({});
+
+ EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+} // namespace
// Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
// constant operands.
@@ -187,5 +211,35 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) {
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
+// Verifies Reshape operation is correctly evaluated.
+TEST_F(HloEvaluatorTest, DoesReshape) {
+ HloComputation::Builder builder(
+ ::testing::UnitTest::GetInstance()->current_test_info()->name());
+
+ const int64 dimensions[] = {11, 8, 7, 5, 9};
+ TF_ASSIGN_OR_ASSERT_OK(auto literal,
+ LiteralTestUtil::CreateRandomLiteral<F32>(
+ ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+ auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+ HloInstruction* literal_instruction = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+ const int64 permutation[] = {1, 2, 0, 4, 3};
+ builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+
+ std::unique_ptr<Literal> result =
+ evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+ LiteralUtil::EachCell<NativeT>(
+ *result, [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_TRUE(value ==
+ LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
+ });
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d713d826fb..ecbf1dd1e5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1306,7 +1306,7 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
void HloInstruction::DetachFromOperands() {
CHECK_EQ(0, user_count());
- // An intruction may be repeated as an operand. To avoid calling RemoveUser
+ // An instruction may be repeated as an operand. To avoid calling RemoveUser
// twice on the same operand, keep a set of already detached operands.
std::set<HloInstruction*> detached_operands;
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -2162,6 +2162,70 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
return true;
}
+// A helper class for memoized, recursive computation of HloOpcode::kFusion
+// in HloInstruction::OperandElementUse below.
+class HloInstruction::FusionReusesParamElements {
+ public:
+ using UseKind = HloInstruction::UseKind;
+
+ // We could rather iterate backwards thru fused_instructions_ here, as it is
+ // in reverse postorder, and compute whether each fused instruction reuses the
+ // value of this parameter, which would save stack space but not allow us to
+ // finish early if we find a reuse.
+ static UseKind Compute(int64 i, const HloInstruction& hlo) {
+ tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
+ return ComputeInternal(i, hlo, &memoization_cache);
+ }
+
+ private:
+ static UseKind ComputeInternal(
+ int64 i, const HloInstruction& hlo,
+ tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
+ if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
+ return UseKind::kUse;
+ }
+
+ auto p = cache->emplace(&hlo, UseKind{});
+ auto value_it = p.first;
+ const bool key_is_new = p.second;
+
+ if (key_is_new) {
+ for (int64 j = 0; j < hlo.operands_.size(); ++j) {
+ UseKind old_val = value_it->second;
+
+ // The next operation invalidates iterators.
+ UseKind new_val =
+ Plus(old_val, std::min(hlo.OperandElementUse(j),
+ ComputeInternal(i, *hlo.operand(j), cache)));
+
+ // Re-acquire the iterator. We could work harder to do this only if
+ // absolutely necessary, but this code is not hot enough to warrant
+ // that.
+ value_it = cache->find(&hlo);
+ value_it->second = new_val;
+ }
+ }
+ return value_it->second;
+ }
+
+ // Fold operation for UseKinds.
+ static UseKind Plus(UseKind a, UseKind b) {
+ if (a == UseKind::kNoUse) {
+ return b;
+ } else if (b == UseKind::kNoUse) {
+ return a;
+ } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
+ return UseKind::kReuse;
+ } else if (a == UseKind::kUsePermutingElements ||
+ b == UseKind::kUsePermutingElements) {
+ return UseKind::kReuse;
+ } else {
+ CHECK(a == UseKind::kUse && b == UseKind::kUse);
+ return UseKind::kUse;
+ }
+ }
+};
+
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
switch (opcode_) {
case HloOpcode::kBitcast:
@@ -2176,46 +2240,9 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
// Pad reuses the padding value but not the padded array elements.
// Reduce reuses the init value but not the operand array elements.
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
- case HloOpcode::kFusion: {
- tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> cache;
- // We could rather iterate backwards thru fused_instructions_ here, as it
- // is in reverse postorder, and compute whether each fused instruction
- // reuses the value of this parameter, which would save stack space but
- // not allow us to finish early if we find a reuse.
- std::function<UseKind(const HloInstruction&)> reuses_parameter_elements =
- [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) {
- auto plus = [](const UseKind& a, const UseKind& b) {
- if (a == UseKind::kNoUse) {
- return b;
- } else if (b == UseKind::kNoUse) {
- return a;
- } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
- return UseKind::kReuse;
- } else if (a == UseKind::kUsePermutingElements ||
- b == UseKind::kUsePermutingElements) {
- return UseKind::kReuse;
- }
- CHECK(UseKind::kUse == a && UseKind::kUse == b);
- return UseKind::kUse;
- };
-
- if (hlo.opcode_ == HloOpcode::kParameter &&
- hlo.parameter_number_ == i) {
- return UseKind::kUse;
- }
- if (!ContainsKey(cache, &hlo)) {
- for (int64 j = 0; j < hlo.operands_.size(); ++j) {
- UseKind old = cache[&hlo];
- UseKind updated = plus(
- old, std::min(hlo.OperandElementUse(j),
- reuses_parameter_elements(*hlo.operand(j))));
- cache[&hlo] = updated;
- }
- }
- return cache[&hlo];
- };
- return reuses_parameter_elements(*fused_expression_root());
- }
+ case HloOpcode::kFusion:
+ // Uses the memoizing, recursive computation defined above.
+ return FusionReusesParamElements::Compute(i, *fused_expression_root());
default:
return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 3bf46341be..522414325e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -775,6 +775,9 @@ class HloInstruction {
private:
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
+ // Helper class for computing OperandElementUse for kFusion.
+ class FusionReusesParamElements;
+
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(
const Shape& shape, HloOpcode opcode,
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index da2c075c8c..ee49a9ae5f 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -787,27 +787,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
// and unmodified_dim_pair have size >1. Otherwise, returns true and appends
// the degerenate input/output dimensions in the gap to
// deleted_indices/inserted_indices respectively.
- auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices,
- &inserted_indices](
- std::pair<int64, int64> prior_unmodified_dim_pair,
- std::pair<int64, int64> unmodified_dim_pair) {
- for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
- modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) {
- if (shape_pre.dimensions(modified_input_dim) > 1) {
- return false;
- }
- deleted_indices.push_back(modified_input_dim);
- }
- for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
- modified_output_dim < unmodified_dim_pair.second;
- ++modified_output_dim) {
- if (shape_post.dimensions(modified_output_dim) > 1) {
- return false;
- }
- inserted_indices.push_back(modified_output_dim);
- }
- return true;
- };
+ auto check_modified_dims =
+ [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
+ std::pair<int64, int64> prior_unmodified_dim_pair,
+ std::pair<int64, int64> unmodified_dim_pair) {
+ for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
+ modified_input_dim < unmodified_dim_pair.first;
+ ++modified_input_dim) {
+ if (shape_pre.dimensions(modified_input_dim) > 1) {
+ return false;
+ }
+ deleted_indices.push_back(modified_input_dim);
+ }
+ for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
+ modified_output_dim < unmodified_dim_pair.second;
+ ++modified_output_dim) {
+ if (shape_post.dimensions(modified_output_dim) > 1) {
+ return false;
+ }
+ inserted_indices.push_back(modified_output_dim);
+ }
+ return true;
+ };
std::vector<std::pair<int64, int64>> unmodified_dims =
DimensionsUnmodifiedByReshape(shape_pre, shape_post);
@@ -1220,6 +1221,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const IndexVisitorFunction& visitor_function) {
+ if (ShapeUtil::HasZeroElements(shape)) {
+ return;
+ }
DCHECK_EQ(Rank(shape), base.size());
DCHECK_EQ(incr.size(), base.size());
DCHECK_EQ(count.size(), base.size());
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 8ac2e8345b..69ef6175cc 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -467,6 +467,34 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
}
+TEST(ShapeUtilTest, ForEachIndex) {
+ struct ShapeDimensionAndNumberInvocations {
+ std::vector<int64> dimensions;
+ int invocations;
+ } test_data[] = {
+ {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0},
+ {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610},
+ };
+
+ for (const auto& data : test_data) {
+ Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
+ // Increments at every invocation.
+ int invocations = 0;
+ auto increment_func = [&invocations](const std::vector<int64>& indexes) {
+ invocations++;
+ return true;
+ };
+
+ std::vector<int64> zero_base(data.dimensions.size(), 0);
+ std::vector<int64> step(data.dimensions.size(), 1);
+
+ ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step,
+ increment_func);
+
+ EXPECT_EQ(invocations, data.invocations);
+ }
+}
+
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 1971868a38..e60d38d0c6 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -93,6 +93,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 03552d7bbf..b96bb8f846 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -48,6 +48,15 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
: client_(GetOrCreateLocalClientOrDie(platform)) {
*(execution_options_.mutable_debug_options()) =
legacy_flags::GetDebugOptionsFromFlags();
+
+ // Disabling constant_folding so that tests (usually written using Constants)
+ // will exercise the intended code paths, instead of being constant folded.
+ //
+ // TODO(b/38354253): Constant folding is currently disabled. Change tests to
+ // use Parameters instead of Constants, and re-enable constant folding by
+ // default.
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
+ "constant_folding");
}
string ClientLibraryTestBase::TestName() const {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 871fbeb0a8..fbbb101ce9 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
@@ -54,6 +55,8 @@ struct HloTestBase::EigenThreadPoolWrapper {
HloTestBase::HloTestBase()
: backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
+ // TODO(b/62411181): get rid of this flag entirely when the usual debug flags
+ // are piped to all HLO tests.
test_hlo_dumper_ = [](const HloModule& module, const string& label) {
legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
if (flags->xla_hlo_test_generate_hlo_graph) {
@@ -73,6 +76,13 @@ HloTestBase::~HloTestBase() {
}
}
+std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
+ config);
+}
+
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index f8045b45b9..4fe0bbd55f 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -44,6 +44,12 @@ class HloTestBase : public ::testing::Test {
~HloTestBase() override;
+ // Creates a new HLO module for a test. The module created will have
+ // TestName() for its name; it will also automatically populate its debug
+ // options from command-line flags. It's recommended to use this method to
+ // create all HloModules for tests.
+ std::unique_ptr<HloModule> CreateNewModule();
+
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 4f98083033..a8b07a2c5d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 7286cce03c..b99933ff9b 100755
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -16,6 +16,7 @@ py_library(
"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
new file mode 100644
index 0000000000..34cdb2a132
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -0,0 +1,47 @@
+# Description: Operations defined for Cluster Resolvers
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+package(
+ default_visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
+
+py_library(
+ name = "cluster_resolver_py",
+ srcs = [
+ "python/training/__init__.py",
+ "python/training/cluster_resolver.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework",
+ ],
+)
+
+tf_py_test(
+ name = "cluster_resolver_py_test",
+ srcs = ["python/training/cluster_resolver_test.py"],
+ additional_deps = [
+ ":cluster_resolver_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ main = "python/training/cluster_resolver_test.py",
+)
diff --git a/tensorflow/contrib/cluster_resolver/README.md b/tensorflow/contrib/cluster_resolver/README.md
new file mode 100644
index 0000000000..6fe6871eb4
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/README.md
@@ -0,0 +1,5 @@
+# Cluster Resolvers
+
+Cluster Resolvers are a new way of specifying cluster information for distributed execution. Built on top of existing `ClusterSpec` framework, Cluster Resolvers allow users to simply specify a configuration and a cluster management service and a `ClusterResolver` will automatically fetch the relevant information from the service and populate `ClusterSpec`s.
+
+`ClusterResolvers` are designed to work well with `ManagedTrainingSession` and `ClusterSpec` propagation so that distributed training sessions remain robust in the face of node and network failures.
diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
new file mode 100644
index 0000000000..3520467bc6
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library Imports for Cluster Resolvers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
new file mode 100644
index 0000000000..87da24f22d
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
@@ -0,0 +1,171 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.training.server_lib import ClusterSpec
+
+
+class ClusterResolver(object):
+ """Abstract class for all implementations of ClusterResolvers.
+
+ This defines the skeleton for all implementations of ClusterResolvers.
+ ClusterResolvers are a way for TensorFlow to communicate with various cluster
+ management systems (e.g. GCE, AWS, etc...).
+
+ By letting TensorFlow communicate with these systems, we will be able to
+ automatically discover and resolve IP addresses for various TensorFlow
+ workers. This will eventually allow us to automatically recover from
+ underlying machine failures and scale TensorFlow worker clusters up and down.
+ """
+
+ @abc.abstractmethod
+ def cluster_spec(self):
+ """Retrieve the current state of the cluster and returns a ClusterSpec.
+
+ Returns:
+ A ClusterSpec representing the state of the cluster at the moment this
+ function is called.
+
+ Implementors of this function must take care in ensuring that the
+ ClusterSpec returned is up-to-date at the time of calling this function.
+ This usually means retrieving the information from the underlying cluster
+ management system every time this function is invoked and reconstructing
+ a cluster_spec, rather than attempting to cache anything.
+ """
+ raise NotImplementedError(
+ 'cluster_spec is not implemented for {}.'.format(self))
+
+
+class SimpleClusterResolver(ClusterResolver):
+ """Simple implementation of ClusterResolver that accepts a ClusterSpec."""
+
+ def __init__(self, cluster_spec):
+ """Creates a SimpleClusterResolver from a ClusterSpec."""
+ super(SimpleClusterResolver, self).__init__()
+
+ if not isinstance(cluster_spec, ClusterSpec):
+ raise TypeError('cluster_spec must be a ClusterSpec.')
+ self._cluster_spec = cluster_spec
+
+ def cluster_spec(self):
+ """Returns the ClusterSpec passed into the constructor."""
+ return self._cluster_spec
+
+
+class UnionClusterResolver(ClusterResolver):
+ """Performs a union on underlying ClusterResolvers.
+
+ This class performs a union given two or more existing ClusterResolvers. It
+ merges the underlying ClusterResolvers, and returns one unified ClusterSpec
+ when as_cluster_spec is called. The details of the merge function is
+ documented in the as_cluster_spec function.
+ """
+
+ def __init__(self, *args):
+ """Initializes a UnionClusterResolver with other ClusterResolvers.
+
+ Args:
+ *args: `ClusterResolver` objects to be unionized.
+
+ Raises:
+ TypeError: If any argument is not a subclass of `ClusterResolvers`.
+ """
+ super(UnionClusterResolver, self).__init__()
+
+ for cluster_resolver in args:
+ if not isinstance(cluster_resolver, ClusterResolver):
+ raise TypeError('All arguments must be a sub-class of '
+ '`ClusterResolver.`')
+ self._cluster_resolvers = args
+
+ def cluster_spec(self):
+ """Returns a union of all the ClusterSpecs from the ClusterResolvers.
+
+ Returns:
+ A ClusterSpec containing host information merged from all the underlying
+ ClusterResolvers.
+
+ Raises:
+ KeyError: If there are conflicting keys detected when merging two or
+ more dictionaries, this exception is raised.
+
+ Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
+ same job name, we will merge the list/dict of workers.
+
+ If *all* underlying ClusterSpecs expose the set of workers as lists, we will
+ concatenate the lists of workers, starting with the list of workers from
+ the first ClusterResolver passed into the constructor.
+
+ If *any* of the ClusterSpecs expose the set of workers as a dict, we will
+ treat all the sets of workers as dicts (even if they are returned as lists)
+ and will only merge them into a dict if there is no conflicting keys. If
+ there is a conflicting key, we will raise a `KeyError`.
+ """
+
+ merged_cluster = {}
+
+ # We figure out whether it is all lists for a particular job, or whether
+ # there are dicts inside.
+ for cluster_resolver in self._cluster_resolvers:
+ cluster_spec = cluster_resolver.cluster_spec()
+ cluster_dict = cluster_spec.as_dict()
+
+ for job_name, tasks in cluster_dict.items():
+ if job_name in merged_cluster:
+ # If we see a dict, then we write a dict out regardless.
+ if isinstance(tasks, dict):
+ merged_cluster[job_name] = {}
+ else:
+ # We take whichever type is present.
+ if isinstance(tasks, list):
+ merged_cluster[job_name] = []
+ else:
+ merged_cluster[job_name] = {}
+
+ # We then do the merge as appropriate in merged_cluster[job].
+ for cluster_resolver in self._cluster_resolvers:
+ cluster_spec = cluster_resolver.cluster_spec()
+ cluster_dict = cluster_spec.as_dict()
+
+ for job_name, tasks in cluster_dict.items():
+ if isinstance(merged_cluster[job_name], list):
+ # We all have lists, we can just concatenate and be done.
+ merged_cluster[job_name].extend(tasks)
+ else:
+ if isinstance(tasks, list):
+ # We convert to a dictionary if the type is a list.
+ task_dict = dict(zip(range(0, len(tasks)), tasks))
+ else:
+ # We can simply make a copy (for update) and be done.
+ task_dict = tasks.copy()
+
+ # We detect if there are duplicates, and raise an error if so.
+ task_keys = set(task_dict)
+ merged_keys = set(merged_cluster[job_name].keys())
+ intersected_keys = task_keys.intersection(merged_keys)
+ if intersected_keys:
+ raise KeyError('Duplicate keys detected when merging two '
+ 'ClusterSpecs: %s' % repr(intersected_keys))
+
+ # We do the merge after all the processing.
+ merged_cluster[job_name].update(task_dict)
+
+ return ClusterSpec(merged_cluster)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
new file mode 100644
index 0000000000..dbfb77723c
--- /dev/null
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
@@ -0,0 +1,238 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Cluster Resolvers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+class UnionClusterResolverTest(test.TestCase):
+ # TODO(frankchn): Transform to parameterized test after it is included in the
+ # TF open source codebase.
+
+ def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
+ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto,
+ server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+ self.assertProtoEquals(
+ expected_proto,
+ server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
+
+ def testSingleClusterResolver(self):
+ base_cluster_spec = server_lib.ClusterSpec({
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ })
+ simple_resolver = SimpleClusterResolver(base_cluster_spec)
+ union_resolver = UnionClusterResolver(simple_resolver)
+
+ expected_proto = """
+ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+ tasks { key: 1 value: 'ps1:2222' } }
+ job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
+ tasks { key: 1 value: 'worker1:2222' }
+ tasks { key: 2 value: 'worker2:2222' } }
+ """
+ actual_cluster_spec = union_resolver.cluster_spec()
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+ def testTwoNonOverlappingJobMergedClusterResolver(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "ps": [
+ "ps0:2222",
+ "ps1:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": [
+ "worker0:2222",
+ "worker1:2222",
+ "worker2:2222"
+ ]
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+ tasks { key: 1 value: 'ps1:2222' } }
+ job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
+ tasks { key: 1 value: 'worker1:2222' }
+ tasks { key: 2 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testOverlappingJobMergedClusterResolver(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": [
+ "worker0:2222",
+ "worker1:2222",
+ "worker2:2222"
+ ]
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
+ tasks { key: 1 value: 'worker5:2222' }
+ tasks { key: 2 value: 'worker0:2222' }
+ tasks { key: 3 value: 'worker1:2222' }
+ tasks { key: 4 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testOverlappingSparseJobMergedClusterResolverThrowError(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": {
+ 7: "worker4:2222",
+ 9: "worker5:2222"
+ }
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ self.assertRaises(KeyError, union_cluster.cluster_spec)
+
+ def testOverlappingDictAndListThrowError(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 1: "worker0:2222",
+ 2: "worker1:2222",
+ 3: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ self.assertRaises(KeyError, union_cluster.cluster_spec)
+
+ def testOverlappingJobNonOverlappingKey(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": {
+ 5: "worker4:2222",
+ 9: "worker5:2222"
+ }
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 3 value: 'worker0:2222' }
+ tasks { key: 5 value: 'worker4:2222' }
+ tasks { key: 6 value: 'worker1:2222' }
+ tasks { key: 7 value: 'worker2:2222' }
+ tasks { key: 9 value: 'worker5:2222' }}
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testMixedModeNonOverlappingKey(self):
+ cluster_spec_1 = server_lib.ClusterSpec({
+ "worker": [
+ "worker4:2222",
+ "worker5:2222"
+ ]
+ })
+ cluster_spec_2 = server_lib.ClusterSpec({
+ "worker": {
+ 3: "worker0:2222",
+ 6: "worker1:2222",
+ 7: "worker2:2222"
+ }
+ })
+ cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1)
+ cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2)
+
+ union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 0 value: 'worker4:2222' }
+ tasks { key: 1 value: 'worker5:2222' }
+ tasks { key: 3 value: 'worker0:2222' }
+ tasks { key: 6 value: 'worker1:2222' }
+ tasks { key: 7 value: 'worker2:2222' }}
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+ def testRetainSparseJobWithNoMerging(self):
+ base_cluster_spec = server_lib.ClusterSpec({
+ "worker": {
+ 1: "worker0:2222",
+ 3: "worker1:2222",
+ 5: "worker2:2222"
+ }
+ })
+
+ base_cluster_resolver = SimpleClusterResolver(base_cluster_spec)
+ union_cluster = UnionClusterResolver(base_cluster_resolver)
+ cluster_spec = union_cluster.cluster_spec()
+
+ expected_proto = """
+ job { name: 'worker' tasks { key: 1 value: 'worker0:2222' }
+ tasks { key: 3 value: 'worker1:2222' }
+ tasks { key: 5 value: 'worker2:2222' } }
+ """
+ self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc b/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc
index 12a4d36bf1..a74ad98663 100644
--- a/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc
+++ b/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc
@@ -31,6 +31,29 @@ limitations under the License.
namespace tensorflow {
+namespace {
+// Returning a Status instead of using OP_REQUIRES directly since that doesn't
+// seem to work outside the main OpKernel functions.
+Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
+ std::vector<bool>* id_present,
+ std::unordered_map<int64, int64>* old_id_to_new_id) {
+ id_present->clear();
+ id_present->resize(remapping.size(), false);
+ for (int i = 0; i < remapping.size(); ++i) {
+ const int64 old_id = remapping(i);
+ if (old_id < 0) continue;
+ (*id_present)[i] = true;
+ if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
+ return errors::Unimplemented(
+ strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
+ old_id_to_new_id->at(old_id), " and ", i,
+ ", which is not supported."));
+ }
+ }
+ return Status::OK();
+}
+} // anonymous namespace
+
// This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
// swaps around the rows/columns according to row_remapping/col_remapping.
// "Missing" cells are initialized with values from initializing_values.
@@ -40,13 +63,15 @@ class LoadAndRemapMatrixOp : public OpKernel {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
}
void Compute(OpKernelContext* context) override {
// Checks what we're remapping and inverts the relevant remapping Tensors to
// be maps with key = old ID, value = new ID.
- std::vector<std::pair<int64, int64>> old_row_to_new_row_pairs;
- std::vector<bool> row_id_present(num_rows_);
+ std::unordered_map<int64, int64> old_row_to_new_row_map;
+ std::vector<bool> row_id_present;
const Tensor* row_remapping_t;
OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
const auto row_remapping = row_remapping_t->vec<int64>();
@@ -54,16 +79,27 @@ class LoadAndRemapMatrixOp : public OpKernel {
errors::InvalidArgument(strings::StrCat(
"Size of row_remapping is ", row_remapping.size(),
" intead of being equal to num_rows=", num_rows_)));
- old_row_to_new_row_pairs.reserve(num_rows_);
+ OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
+ &old_row_to_new_row_map));
+
+ // Calculates the min/max old row ID that we need to read, to save us from
+ // reading some unnecessary slices of the old tensor.
+ int64 min_old_row = -1;
+ int64 max_old_row = -1;
for (int i = 0; i < row_remapping.size(); ++i) {
- if (row_remapping(i) < 0) continue;
- row_id_present[i] = true;
- old_row_to_new_row_pairs.push_back(std::make_pair(row_remapping(i), i));
+ if (min_old_row < 0 ||
+ (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
+ min_old_row = row_remapping(i);
+ }
+ if (max_old_row < 0 ||
+ (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
+ max_old_row = row_remapping(i);
+ }
}
// Processes the remapping for columns.
std::unordered_map<int64, int64> old_col_to_new_col_map;
- std::vector<bool> col_id_present(num_cols_);
+ std::vector<bool> col_id_present;
const Tensor* col_remapping_t;
OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
const auto col_remapping = col_remapping_t->vec<int64>();
@@ -77,19 +113,8 @@ class LoadAndRemapMatrixOp : public OpKernel {
errors::InvalidArgument(strings::StrCat(
"Provided col_remapping, but its size is ", col_remapping.size(),
" instead of being equal to num_cols=", num_cols_)));
- for (int i = 0; i < col_remapping.size(); ++i) {
- const int64 old_col = col_remapping(i);
- if (old_col < 0) continue;
- col_id_present[i] = true;
- OP_REQUIRES(
- context,
- gtl::InsertIfNotPresent(&old_col_to_new_col_map, old_col, i),
- errors::Unimplemented(strings::StrCat(
- "Old column ID ", old_col, " is mapped to both new column ID ",
- old_col_to_new_col_map[old_col], " and ", i,
- ", which is not currently supported - but could be "
- "implemented.")));
- }
+ OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
+ &old_col_to_new_col_map));
} else {
col_id_present.clear();
col_id_present.resize(num_cols_, true);
@@ -139,29 +164,27 @@ class LoadAndRemapMatrixOp : public OpKernel {
" instead of being equal to num_cols=", num_cols_)));
}
- // Uses TensorSlice to selectively read rows of interest from the old
- // tensor. Given BundleReader's use of RandomAccessFile and InputBuffer,
- // there shouldn't too many more additional disk seeks when compared to
- // loading the old tensor in chunks, once we sort the row IDs. Even if there
- // are locality concerns with some reading patterns, that just means if we
- // had read it in chunks, then we would have had to read, copy, and process
- // then discard many redundant rows - so we should come out ahead this way.
- // In addition, this frees us from having to hold the entire old tensor in
- // memory.
- std::sort(old_row_to_new_row_pairs.begin(), old_row_to_new_row_pairs.end());
+ // Uses TensorSlice to potentially load the old tensor in chunks in case
+ // memory usage is a concern.
std::vector<TensorSlice> tensor_slices;
- tensor_slices.reserve(old_row_to_new_row_pairs.size());
TensorSlice slice(tensor_shape.dims());
- for (const auto& pair : old_row_to_new_row_pairs) {
- OP_REQUIRES(
- context, pair.first < tensor_shape.dim_size(0),
- errors::InvalidArgument(strings::StrCat(
- "Trying to read row ", pair.first, " from tensor ",
- old_tensor_name, ", which only has ", tensor_shape.dim_size(0),
- " rows (with shape ", tensor_shape.DebugString(), ").")));
- slice.set_start(0, pair.first);
- slice.set_length(0, 1);
- tensor_slices.push_back(slice);
+ if (min_old_row >= 0 && max_old_row >= 0) {
+ int64 row_start = min_old_row;
+ // TODO(weiho): Given the list of old row IDs of interest (the keys of
+ // old_row_to_new_row_map), we could also try something smarter to
+ // find some minimal set of covering ranges for the list of old row IDs
+ // such that the size of each range is less than max_rows_in_memory_.
+ while (row_start <= max_old_row) {
+ const int64 slice_length =
+ max_rows_in_memory_ <= 0
+ // If max_rows_in_memory_ <= 0, we just load the entire chunk.
+ ? max_old_row - row_start + 1
+ : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
+ slice.set_start(0, row_start);
+ slice.set_length(0, slice_length);
+ tensor_slices.push_back(slice);
+ row_start += slice_length;
+ }
}
// Allocates the output matrix.
@@ -174,52 +197,72 @@ class LoadAndRemapMatrixOp : public OpKernel {
// Iterates through tensor slices and copies over values from the old tensor
// to the output matrix.
- Tensor loaded_tensor_t(DT_FLOAT,
- TensorShape({1, tensor_shape.dim_size(1)}));
- for (int i = 0; i < tensor_slices.size(); ++i) {
- const int64 new_row = old_row_to_new_row_pairs[i].second;
- if (i % 500000 == 0) {
- LOG(INFO) << "Processing slice " << i << " of " << tensor_slices.size()
- << " - corresponding to old row "
- << old_row_to_new_row_pairs[i].first << " of "
- << tensor_shape.dim_size(0);
- }
+ int64 row_index = min_old_row;
+ int64 rows_copied = 0;
+ Tensor loaded_tensor_t;
+ for (const TensorSlice& tensor_slice : tensor_slices) {
+ LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
+ TensorShape slice_shape;
OP_REQUIRES_OK(context,
- reader.LookupSlice(old_tensor_name, tensor_slices[i],
- &loaded_tensor_t));
+ tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
+ // Potentially re-allocates the tensor buffer since the last slice may
+ // have fewer rows than the other slices.
+ if (loaded_tensor_t.shape() != slice_shape) {
+ loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
+ }
+ OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
+ &loaded_tensor_t));
- // Copies over the row element-by-element, in case remapping is needed
- // along the column axis.
- const auto& loaded_tensor = loaded_tensor_t.flat<float>();
- for (int old_col = 0; old_col < loaded_tensor.size(); ++old_col) {
- int64 new_col = old_col;
- if (remap_cols) {
- const int64* new_col_ptr =
- gtl::FindOrNull(old_col_to_new_col_map, old_col);
- if (new_col_ptr == nullptr) {
- // Column remapping is specified, but this column is not found in
- // old_col_to_new_col_map, so we leave it uninitialized, to be
- // filled in with initializing_values later.
- continue;
- }
- new_col = *new_col_ptr;
+ // Iterates through the old loaded tensor slice row-by-row.
+ for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
+ if (row_index % 500000 == min_old_row) {
+ LOG(INFO) << "Processing old row " << row_index;
+ }
+
+ // If the old row ID is not found in old_row_to_new_row_map, continue
+ // to the next row; otherwise, copy it to the output matrix.
+ const int64* new_row_ptr =
+ gtl::FindOrNull(old_row_to_new_row_map, row_index);
+ if (new_row_ptr == nullptr) {
+ continue;
}
+ ++rows_copied;
+ const int64 new_row = *new_row_ptr;
- OP_REQUIRES(context,
- new_row < num_rows_ && new_col < num_cols_ &&
- new_row >= 0 && new_col >= 0,
- errors::Internal(strings::StrCat(
- "new_row=", new_row, " and new_col=", new_col,
- " should have been less than num_rows_=", num_rows_,
- " and num_cols_=", num_cols_,
- " and non-negative. This should never have happened "
- "if the code were correct. Please file a bug.")));
- output_matrix(new_row, new_col) = loaded_tensor(old_col);
+ // Copies over the row element-by-element, in case remapping is needed
+ // along the column axis.
+ const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
+ for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
+ ++old_col) {
+ int64 new_col = old_col;
+ if (remap_cols) {
+ const int64* new_col_ptr =
+ gtl::FindOrNull(old_col_to_new_col_map, old_col);
+ if (new_col_ptr == nullptr) {
+ // Column remapping is specified, but this column is not found in
+ // old_col_to_new_col_map, so we leave it uninitialized, to be
+ // filled in with initializing_values later.
+ continue;
+ }
+ new_col = *new_col_ptr;
+ }
+
+ OP_REQUIRES(context,
+ new_row < num_rows_ && new_col < num_cols_ &&
+ new_row >= 0 && new_col >= 0,
+ errors::Internal(strings::StrCat(
+ "new_row=", new_row, " and new_col=", new_col,
+ " should have been less than num_rows_=", num_rows_,
+ " and num_cols_=", num_cols_,
+ " and non-negative. This should never have happened "
+ "if the code were correct. Please file a bug.")));
+ output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
+ }
}
}
- LOG(INFO) << "Copied " << tensor_slices.size()
- << " rows from old matrix (with " << tensor_shape.dim_size(0)
- << " rows) to new matrix (with " << num_rows_ << " rows).";
+ LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
+ << tensor_shape.dim_size(0) << " rows) to new matrix (with "
+ << num_rows_ << " rows).";
// At this point, there are potentially whole rows/columns uninitialized
// (corresponding to the indices where row_id_present/col_id_present are
@@ -232,10 +275,14 @@ class LoadAndRemapMatrixOp : public OpKernel {
int64 initializing_values_index = 0;
for (int i = 0; i < num_rows_; ++i) {
for (int j = 0; j < num_cols_; ++j) {
- if (!row_id_present[i] || !col_id_present[j]) {
- output_matrix(i, j) = initializing_values(initializing_values_index);
- ++initializing_values_index;
- }
+ if (row_id_present[i] && col_id_present[j]) continue;
+ OP_REQUIRES(
+ context, initializing_values_index < initializing_values.size(),
+ errors::InvalidArgument(
+ "initializing_values contained ", initializing_values.size(),
+ " elements, but more missing values remain."));
+ output_matrix(i, j) = initializing_values(initializing_values_index);
+ ++initializing_values_index;
}
}
@@ -251,6 +298,7 @@ class LoadAndRemapMatrixOp : public OpKernel {
private:
int64 num_rows_;
int64 num_cols_;
+ int64 max_rows_in_memory_;
};
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
diff --git a/tensorflow/contrib/framework/ops/checkpoint_ops.cc b/tensorflow/contrib/framework/ops/checkpoint_ops.cc
index 09d487dd64..b49d7b4d40 100644
--- a/tensorflow/contrib/framework/ops/checkpoint_ops.cc
+++ b/tensorflow/contrib/framework/ops/checkpoint_ops.cc
@@ -83,6 +83,7 @@ REGISTER_OP("LoadAndRemapMatrix")
.Input("initializing_values: float")
.Attr("num_rows: int >= 0")
.Attr("num_cols: int >= 1")
+ .Attr("max_rows_in_memory: int = -1")
.Output("output_matrix: float")
// TODO(b/30502450): Setting the op as being stateful prevents it from being
// executed more often than expected (possibly due to stateful ops not being
@@ -154,6 +155,9 @@ initializing_values: A float `Tensor` containing values to fill in for cells
exactly the same as the number of missing / new cells.
num_rows: Number of rows (length of the 1st dimension) in the output matrix.
num_cols: Number of columns (length of the 2nd dimension) in the output matrix.
+max_rows_in_memory: The maximum number of rows to load from the checkpoint at
+ once. If less than or equal to 0, the entire matrix will be loaded into
+ memory. Setting this arg trades increased disk reads for lower memory usage.
output_matrix: Output matrix containing existing values loaded from the
checkpoint, and with any missing values filled in from initializing_values.
)doc");
diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py
index fdb834f46b..92228f8916 100644
--- a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py
+++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py
@@ -46,7 +46,8 @@ def _load_and_remap_matrix(ckpt_path,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
- num_col_oov_buckets=0):
+ num_col_oov_buckets=0,
+ max_rows_in_memory=-1):
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
Generates 1D-remappings for rows and columns using the
@@ -99,6 +100,10 @@ def _load_and_remap_matrix(ckpt_path,
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
+ max_rows_in_memory: `int` specifying the maximum number of rows to load from
+ the checkpoint at once. If less than or equal to 0, the entire matrix will
+ be loaded into memory. Setting this arg trades increased disk reads for
+ lower memory usage.
Returns:
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
@@ -177,7 +182,8 @@ def _load_and_remap_matrix(ckpt_path,
col_remapping=col_remapping,
initializing_values=init_vals,
num_rows=num_rows_to_load,
- num_cols=new_col_vocab_size)
+ num_cols=new_col_vocab_size,
+ max_rows_in_memory=max_rows_in_memory)
# Add OOV row(s) and column(s).
if num_row_oov_buckets > 0:
@@ -204,7 +210,8 @@ def load_and_remap_matrix_initializer(ckpt_path,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
- initializer=None):
+ initializer=None,
+ max_rows_in_memory=-1):
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
The returned initializer loads a 2-D (matrix) `Tensor` with name
@@ -297,6 +304,10 @@ def load_and_remap_matrix_initializer(ckpt_path,
initializer: Initializer function to initialize missing values. Accepts a
1-D tensor as the arg to specify the shape of the returned tensor. If
`None`, defaults to using `zeros_initializer()`.
+ max_rows_in_memory: `int` specifying the maximum number of rows to load from
+ the checkpoint at once. If less than or equal to 0, the entire matrix will
+ be loaded into memory. Setting this arg trades increased disk reads for
+ lower memory usage.
Returns:
A variable initializer function that should be used to initialize a
@@ -378,7 +389,8 @@ def load_and_remap_matrix_initializer(ckpt_path,
old_col_vocab_file=old_col_vocab_file,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=row_oov_buckets_to_use,
- num_col_oov_buckets=num_col_oov_buckets)
+ num_col_oov_buckets=num_col_oov_buckets,
+ max_rows_in_memory=max_rows_in_memory)
return _initializer
@@ -390,7 +402,8 @@ def load_embedding_initializer(ckpt_path,
old_vocab_file,
new_vocab_file,
num_oov_buckets=0,
- initializer=None):
+ initializer=None,
+ max_rows_in_memory=-1):
"""Returns a variable initializer for loading pre-trained embeddings.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
@@ -416,6 +429,10 @@ def load_embedding_initializer(ckpt_path,
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`truncated_normal_initializer()`.
+ max_rows_in_memory: `int` specifying the maximum number of rows to load from
+ the checkpoint at once. If less than or equal to 0, the entire matrix will
+ be loaded into memory. Setting this arg trades increased disk reads for
+ lower memory usage.
Returns:
A variable initializer function.
@@ -437,7 +454,8 @@ def load_embedding_initializer(ckpt_path,
new_col_vocab_file=None,
num_row_oov_buckets=num_oov_buckets,
num_col_oov_buckets=0,
- initializer=initializer)
+ initializer=initializer,
+ max_rows_in_memory=max_rows_in_memory)
def load_linear_multiclass_bias_initializer(ckpt_path,
@@ -446,7 +464,8 @@ def load_linear_multiclass_bias_initializer(ckpt_path,
old_class_vocab_file,
new_class_vocab_file,
num_class_oov_buckets=0,
- initializer=None):
+ initializer=None,
+ max_rows_in_memory=-1):
"""Loads pre-trained multi-class biases for linear models from checkpoint.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
@@ -469,6 +488,10 @@ def load_linear_multiclass_bias_initializer(ckpt_path,
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`zeros_initializer()`.
+ max_rows_in_memory: `int` specifying the maximum number of rows to load from
+ the checkpoint at once. If less than or equal to 0, the entire matrix will
+ be loaded into memory. Setting this arg trades increased disk reads for
+ lower memory usage.
Returns:
A variable initializer function.
@@ -488,7 +511,8 @@ def load_linear_multiclass_bias_initializer(ckpt_path,
new_col_vocab_file=None,
num_row_oov_buckets=num_class_oov_buckets,
num_col_oov_buckets=0,
- initializer=initializer)
+ initializer=initializer,
+ max_rows_in_memory=max_rows_in_memory)
def load_variable_slot_initializer(ckpt_path,
@@ -502,7 +526,8 @@ def load_variable_slot_initializer(ckpt_path,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
- initializer=None):
+ initializer=None,
+ max_rows_in_memory=-1):
"""Loads pre-trained multi-class slots for linear models from checkpoint.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
@@ -549,6 +574,10 @@ def load_variable_slot_initializer(ckpt_path,
initializer: Initializer function to initialize missing values. Accepts a
1-D tensor as the arg to specify the shape of the returned tensor. If
`None`, defaults to using `zeros_initializer()`.
+ max_rows_in_memory: `int` specifying the maximum number of rows to load from
+ the checkpoint at once. If less than or equal to 0, the entire matrix will
+ be loaded into memory. Setting this arg trades increased disk reads for
+ lower memory usage.
Returns:
A variable initializer function that should be used to initialize a
@@ -570,7 +599,8 @@ def load_variable_slot_initializer(ckpt_path,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=num_row_oov_buckets,
num_col_oov_buckets=num_col_oov_buckets,
- initializer=initializer)
+ initializer=initializer,
+ max_rows_in_memory=max_rows_in_memory)
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
del partition_info # Unused by this override.
diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
index 321375ddfc..911c5a210c 100644
--- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
@@ -118,7 +118,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
# No column remapping, new weight matrix has second row, then first row.
row_remapping = [1, 0]
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
@@ -128,12 +128,12 @@ class LoadAndRemapMatrixTest(test.TestCase):
num_cols=self.old_num_cols)
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping],
- remapped_weight_matrix.eval())
+ remapped_matrix.eval())
# No row remapping, new weight matrix has third col, then first col.
row_remapping = list(range(self.old_num_rows))
col_remapping = [2, 0]
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
@@ -143,12 +143,12 @@ class LoadAndRemapMatrixTest(test.TestCase):
num_cols=len(col_remapping))
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
- remapped_weight_matrix.eval())
+ remapped_matrix.eval())
# Both row and column remappings.
row_remapping = [1, 0, 4]
col_remapping = [1, 15]
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=row_remapping,
@@ -158,12 +158,12 @@ class LoadAndRemapMatrixTest(test.TestCase):
num_cols=len(col_remapping))
with self.test_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
- remapped_weight_matrix.eval())
+ remapped_matrix.eval())
def test_load_and_remap_with_init(self):
"""Tests the op's load and remap where there are missing entries."""
init_val = 42
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[2, -1, 0],
@@ -172,18 +172,17 @@ class LoadAndRemapMatrixTest(test.TestCase):
num_rows=3,
num_cols=2)
- expected_remapped_weight_matrix = np.reshape(
+ expected_remapped_matrix = np.reshape(
[33, init_val, init_val, init_val, 1, init_val], [3, 2])
with self.test_session():
- self.assertAllClose(expected_remapped_weight_matrix,
- remapped_weight_matrix.eval())
+ self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_all_missing_rows(self):
"""Tests when all the rows are missing and need to be initialized."""
num_rows = 7
initializing_values = [42] * num_rows * self.old_num_cols
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[-1] * num_rows,
@@ -194,14 +193,14 @@ class LoadAndRemapMatrixTest(test.TestCase):
with self.test_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, self.old_num_cols)),
- remapped_weight_matrix.eval())
+ remapped_matrix.eval())
def test_load_and_remap_all_missing_rows_and_cols(self):
"""Tests when all the rows & cols are missing and need to be initialized."""
num_rows = 7
num_cols = 4
initializing_values = [42] * num_rows * num_cols
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=[-1] * num_rows,
@@ -212,42 +211,216 @@ class LoadAndRemapMatrixTest(test.TestCase):
with self.test_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, num_cols)),
- remapped_weight_matrix.eval())
+ remapped_matrix.eval())
- def test_load_and_remap_duplicate_row_remapping(self):
- """Tests when an old row maps to multiple new rows.
+ def test_load_and_remap_invalid_remapping(self):
+ """Tests that errors are raised when an ID maps to multiple new IDs.
(This should usually not happen when using public APIs).
"""
- row_remapping = [1, 0, 0, 0, 1, 2]
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ invalid_remapping = [1, 0, 0, 0, 1, 2]
+
+ # Invalid row remapping.
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
- row_remapping=row_remapping,
+ row_remapping=invalid_remapping,
col_remapping=[],
initializing_values=[],
- num_rows=len(row_remapping),
+ num_rows=len(invalid_remapping),
num_cols=self.old_num_cols)
- with self.test_session():
- self.assertAllClose(self.matrix_value[row_remapping],
- remapped_weight_matrix.eval())
-
- def test_load_and_remap_invalid_col_remapping(self):
- """Tests that an error is raised when an old col maps to multiple new cols.
+ with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ remapped_matrix.eval()
- (This should usually not happen when using public APIs).
- """
- col_remapping = [1, 0, 0, 0, 1, 2]
- remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ # Invalid column remapping.
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=[self.bundle_file],
old_tensor_name=self.old_tensor_name,
row_remapping=list(range(self.old_num_rows)),
- col_remapping=col_remapping,
+ col_remapping=invalid_remapping,
initializing_values=[],
num_rows=self.old_num_rows,
- num_cols=len(col_remapping))
+ num_cols=len(invalid_remapping))
with self.test_session(), self.assertRaises(errors.UnimplementedError):
- remapped_weight_matrix.eval()
+ remapped_matrix.eval()
+
+ def test_load_and_remap_incorrect_initializing_values(self):
+ """Tests that errors are raised with incorrect number of init values."""
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ ckpt_path=[self.bundle_file],
+ old_tensor_name=self.old_tensor_name,
+ row_remapping=[2, -1, 0],
+ col_remapping=[1, -1],
+ # Too few initializing values - there should be 4. For some reason,
+ # initializing_values must contain no element (instead of 3 or fewer) to
+ # ensure that a seg fault would reliably occur if the check raising the
+ # InvalidArgumentError were not present.
+ initializing_values=[],
+ num_rows=3,
+ num_cols=2)
+ with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ remapped_matrix.eval()
+
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ ckpt_path=[self.bundle_file],
+ old_tensor_name=self.old_tensor_name,
+ row_remapping=[2, -1, 0],
+ col_remapping=[1, -1],
+ # Too many initializing values - there should be 4.
+ initializing_values=[0] * 5,
+ num_rows=3,
+ num_cols=2)
+ with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ remapped_matrix.eval()
+
+
+class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase):
+ """Tests for the load_and_remap_matrix() op.
+
+ (Specifically focused on the max_rows_in_memory arg and its effects on
+ TensorBundle's BundleReader and TensorSlice logic).
+ """
+
+ def _test_loading_variable_with_max_rows(self, np_value, partitioner,
+ max_rows_in_memory):
+ """Helper function for various tests using max_rows_in_memory."""
+ ops.reset_default_graph()
+ old_tensor_name = 'matrix_to_load_and_remap'
+ matrix = variable_scope.get_variable(
+ old_tensor_name,
+ dtype=dtypes.float32,
+ initializer=constant_op.constant(np_value, dtype=dtypes.float32),
+ partitioner=partitioner)
+
+ with self.test_session() as sess:
+ ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
+ save = saver.Saver([matrix])
+ variables.global_variables_initializer().run()
+ save.save(sess, ckpt_path)
+ num_rows, num_cols = np_value.shape
+
+ # Tests loading the entire tensor (except reversed).
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ ckpt_path=ckpt_path,
+ old_tensor_name=old_tensor_name,
+ # Simply reverses the rows of the matrix.
+ row_remapping=list(range(num_rows - 1, -1, -1)),
+ col_remapping=[],
+ initializing_values=[],
+ num_rows=num_rows,
+ num_cols=num_cols,
+ max_rows_in_memory=max_rows_in_memory)
+ self.assertAllClose(np_value[::-1], remapped_matrix.eval())
+
+ # Tests loading the tensor (except for the first and last rows), with
+ # uninitialized values. Requires num_rows to be at least 3 since we're
+ # skipping the first and last rows.
+ self.assertGreater(num_rows, 2)
+ prefix_rows = 2
+ suffix_rows = 3
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ ckpt_path=ckpt_path,
+ old_tensor_name=old_tensor_name,
+ # Reverses the rows of the matrix, then prepends and appends
+ # uninitialized rows.
+ row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) +
+ [-1] * suffix_rows),
+ col_remapping=[],
+ initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols,
+ num_rows=num_rows - 2 + prefix_rows + suffix_rows,
+ num_cols=num_cols,
+ max_rows_in_memory=max_rows_in_memory)
+ self.assertAllClose(
+ np.vstack([
+ np.tile(42, [prefix_rows, num_cols]), np_value[1:-1],
+ np.tile(42, [suffix_rows, num_cols])
+ ]), remapped_matrix.eval())
+
+ # Tests when everything is taken from initializing_values.
+ new_rows = 7
+ initializing_values = [42] * new_rows * num_cols
+ remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
+ ckpt_path=ckpt_path,
+ old_tensor_name=old_tensor_name,
+ # Nothing is loaded from the old tensor.
+ row_remapping=[-1] * new_rows,
+ col_remapping=[],
+ initializing_values=initializing_values,
+ num_rows=new_rows,
+ num_cols=num_cols,
+ max_rows_in_memory=max_rows_in_memory)
+ self.assertAllClose(
+ np.reshape(initializing_values, (new_rows, num_cols)),
+ remapped_matrix.eval())
+
+ def test_loading_rows_divisible_by_max_rows(self):
+ """Tests loading normal var when rows are evenly divisible by max_rows."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=None,
+ # 9 is evenly divisible by 3.
+ max_rows_in_memory=3)
+
+ def test_loading_rows_not_divisible_by_max_rows(self):
+ """Tests loading normal var when rows aren't divisible by max_rows."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=None,
+ # 9 is not evenly divisible by 4.
+ max_rows_in_memory=4)
+
+ def test_loading_rows_less_than_max_rows(self):
+ """Tests loading normal var as a single slice.
+
+ (When the specified max_rows_in_memory is larger than the number of rows)
+ """
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=None,
+ # 10 > 9.
+ max_rows_in_memory=10)
+
+ def test_loading_no_max_rows(self):
+ """Tests loading normal var as a single slice with no valid max_rows."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 18)), (6, 3)),
+ partitioner=None,
+ max_rows_in_memory=-1)
+
+ def test_loading_partitions_equals_max_rows(self):
+ """Tests loading partitioned var sliced on partition boundary."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=partitioned_variables.fixed_size_partitioner(3),
+ # With a tensor of shape [9, 3] and 3 partitions, each partition has
+ # exactly 3 rows.
+ max_rows_in_memory=3)
+
+ def test_loading_partitions_greater_than_max_rows(self):
+ """Tests loading partitioned var with more slices than partitions."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=partitioned_variables.fixed_size_partitioner(3),
+ # Even though each partition has 3 rows, we'll only load the tensor one
+ # row at a time.
+ max_rows_in_memory=1)
+
+ def test_loading_partitions_less_than_max_rows(self):
+ """Tests loading partitioned var as a single slice.
+
+ (When the specified max_rows_in_memory is larger than the number of rows)
+ """
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=partitioned_variables.fixed_size_partitioner(3),
+ max_rows_in_memory=10)
+
+ def test_loading_partitions_no_max_rows(self):
+ """Tests loading partitioned var as single slice with no valid max_rows."""
+ self._test_loading_variable_with_max_rows(
+ np_value=np.reshape(list(range(0, 36)), (9, 4)),
+ partitioner=partitioned_variables.fixed_size_partitioner(3),
+ max_rows_in_memory=-1)
class LoadAndRemapWrappersTest(test.TestCase):
diff --git a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj
deleted file mode 100644
index e9d783e49d..0000000000
--- a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj
+++ /dev/null
@@ -1,431 +0,0 @@
-// !$*UTF8*$!
-{
- archiveVersion = 1;
- classes = {
- };
- objectVersion = 46;
- objects = {
-
-/* Begin PBXBuildFile section */
- 591D3EC51CFF7F130059011C /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EC41CFF7F120059011C /* AVFoundation.framework */; };
- 591D3ECB1CFF7F5F0059011C /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */; };
- 591D3ECD1CFF7F9F0059011C /* AssetsLibrary.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */; };
- 591D3ECF1CFF7FCE0059011C /* ImageIO.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */; };
- 591D3ED21CFF85C30059011C /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 591D3ED11CFF85C30059011C /* ios_image_load.mm */; };
- 591D3ED51CFF85FD0059011C /* tensorflow_utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */; };
- 591D3EDB1CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */; };
- 591D3EDC1CFFA83A0059011C /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */; };
- 591D3EDF1CFFAD230059011C /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */; };
- 591D3EE01CFFAD230059011C /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EDE1CFFAD230059011C /* libprotobuf.a */; };
- 592FF8B918ECBD7600C164F8 /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 592FF8B818ECBD7600C164F8 /* Foundation.framework */; };
- 592FF8BB18ECBD7600C164F8 /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */; };
- 592FF90218ECC66200C164F8 /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 592FF90118ECC66200C164F8 /* main.mm */; };
- 592FF90D18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */; };
- 592FF92518EE240200C164F8 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */; };
- 592FF92618EE240200C164F8 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */; };
- 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7711D5D4E980048CE6A /* Accelerate.framework */; };
-/* End PBXBuildFile section */
-
-/* Begin PBXFileReference section */
- 591D3EC41CFF7F120059011C /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/AVFoundation.framework; sourceTree = DEVELOPER_DIR; };
- 591D3EC61CFF7F370059011C /* CoreFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreFoundation.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreFoundation.framework; sourceTree = DEVELOPER_DIR; };
- 591D3EC81CFF7F500059011C /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreImage.framework; sourceTree = DEVELOPER_DIR; };
- 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreMedia.framework; sourceTree = DEVELOPER_DIR; };
- 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AssetsLibrary.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/AssetsLibrary.framework; sourceTree = DEVELOPER_DIR; };
- 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = ImageIO.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/ImageIO.framework; sourceTree = DEVELOPER_DIR; };
- 591D3ED01CFF85C30059011C /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = SOURCE_ROOT; };
- 591D3ED11CFF85C30059011C /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = SOURCE_ROOT; };
- 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = tensorflow_utils.mm; sourceTree = SOURCE_ROOT; };
- 591D3ED41CFF85FD0059011C /* tensorflow_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = tensorflow_utils.h; sourceTree = SOURCE_ROOT; };
- 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = "<group>"; };
- 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = "<group>"; };
- 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = "<group>"; };
- 591D3EDE1CFFAD230059011C /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = "<group>"; };
- 592FF8B518ECBD7600C164F8 /* CameraExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = CameraExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
- 592FF8B818ECBD7600C164F8 /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; };
- 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
- 592FF90118ECC66200C164F8 /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = SOURCE_ROOT; };
- 592FF90318ECCB8300C164F8 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = SOURCE_ROOT; };
- 592FF90B18EDD0DA00C164F8 /* en */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = en; path = MainStoryboard_iPhone.storyboard; sourceTree = "<group>"; };
- 592FF92118EE240200C164F8 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = SOURCE_ROOT; };
- 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = SOURCE_ROOT; };
- 592FF92318EE240200C164F8 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = SOURCE_ROOT; };
- 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = SOURCE_ROOT; };
- 5993C7711D5D4E980048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.3.sdk/System/Library/Frameworks/Accelerate.framework; sourceTree = DEVELOPER_DIR; };
-/* End PBXFileReference section */
-
-/* Begin PBXFrameworksBuildPhase section */
- 592FF8B218ECBD7600C164F8 /* Frameworks */ = {
- isa = PBXFrameworksBuildPhase;
- buildActionMask = 2147483647;
- files = (
- 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */,
- 591D3EDF1CFFAD230059011C /* libprotobuf-lite.a in Frameworks */,
- 591D3EE01CFFAD230059011C /* libprotobuf.a in Frameworks */,
- 591D3ECF1CFF7FCE0059011C /* ImageIO.framework in Frameworks */,
- 591D3ECD1CFF7F9F0059011C /* AssetsLibrary.framework in Frameworks */,
- 591D3ECB1CFF7F5F0059011C /* CoreMedia.framework in Frameworks */,
- 591D3EC51CFF7F130059011C /* AVFoundation.framework in Frameworks */,
- 592FF8BB18ECBD7600C164F8 /* CoreGraphics.framework in Frameworks */,
- 592FF8B918ECBD7600C164F8 /* Foundation.framework in Frameworks */,
- );
- runOnlyForDeploymentPostprocessing = 0;
- };
-/* End PBXFrameworksBuildPhase section */
-
-/* Begin PBXGroup section */
- 591D3ED61CFFA83A0059011C /* data */ = {
- isa = PBXGroup;
- children = (
- 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */,
- 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */,
- );
- path = data;
- sourceTree = SOURCE_ROOT;
- };
- 592FF8AA18ECBD3600C164F8 = {
- isa = PBXGroup;
- children = (
- 592FF8BE18ECBD7600C164F8 /* CameraExample */,
- 592FF8B718ECBD7600C164F8 /* Frameworks */,
- 592FF8B618ECBD7600C164F8 /* Products */,
- );
- sourceTree = "<group>";
- };
- 592FF8B618ECBD7600C164F8 /* Products */ = {
- isa = PBXGroup;
- children = (
- 592FF8B518ECBD7600C164F8 /* CameraExample.app */,
- );
- name = Products;
- sourceTree = "<group>";
- };
- 592FF8B718ECBD7600C164F8 /* Frameworks */ = {
- isa = PBXGroup;
- children = (
- 5993C7711D5D4E980048CE6A /* Accelerate.framework */,
- 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */,
- 591D3EDE1CFFAD230059011C /* libprotobuf.a */,
- 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */,
- 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */,
- 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */,
- 591D3EC81CFF7F500059011C /* CoreImage.framework */,
- 591D3EC61CFF7F370059011C /* CoreFoundation.framework */,
- 591D3EC41CFF7F120059011C /* AVFoundation.framework */,
- 592FF8B818ECBD7600C164F8 /* Foundation.framework */,
- 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */,
- );
- name = Frameworks;
- sourceTree = "<group>";
- };
- 592FF8BE18ECBD7600C164F8 /* CameraExample */ = {
- isa = PBXGroup;
- children = (
- 591D3ED61CFFA83A0059011C /* data */,
- 592FF90718EDD0DA00C164F8 /* en.lproj */,
- 592FF92118EE240200C164F8 /* CameraExampleAppDelegate.h */,
- 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */,
- 592FF92318EE240200C164F8 /* CameraExampleViewController.h */,
- 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */,
- 592FF90318ECCB8300C164F8 /* Info.plist */,
- 591D3ED01CFF85C30059011C /* ios_image_load.h */,
- 591D3ED11CFF85C30059011C /* ios_image_load.mm */,
- 592FF90118ECC66200C164F8 /* main.mm */,
- 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */,
- 591D3ED41CFF85FD0059011C /* tensorflow_utils.h */,
- );
- name = CameraExample;
- path = SimpleExample;
- sourceTree = "<group>";
- };
- 592FF90718EDD0DA00C164F8 /* en.lproj */ = {
- isa = PBXGroup;
- children = (
- 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */,
- );
- path = en.lproj;
- sourceTree = SOURCE_ROOT;
- };
-/* End PBXGroup section */
-
-/* Begin PBXNativeTarget section */
- 592FF8B418ECBD7600C164F8 /* CameraExample */ = {
- isa = PBXNativeTarget;
- buildConfigurationList = 592FF8E318ECBD7600C164F8 /* Build configuration list for PBXNativeTarget "CameraExample" */;
- buildPhases = (
- 592FF8B118ECBD7600C164F8 /* Sources */,
- 592FF8B218ECBD7600C164F8 /* Frameworks */,
- 592FF8B318ECBD7600C164F8 /* Resources */,
- );
- buildRules = (
- );
- dependencies = (
- );
- name = CameraExample;
- productName = SimpleExample;
- productReference = 592FF8B518ECBD7600C164F8 /* CameraExample.app */;
- productType = "com.apple.product-type.application";
- };
-/* End PBXNativeTarget section */
-
-/* Begin PBXProject section */
- 592FF8AB18ECBD3600C164F8 /* Project object */ = {
- isa = PBXProject;
- attributes = {
- LastUpgradeCheck = 0720;
- };
- buildConfigurationList = 592FF8AE18ECBD3600C164F8 /* Build configuration list for PBXProject "camera_example" */;
- compatibilityVersion = "Xcode 3.2";
- developmentRegion = English;
- hasScannedForEncodings = 0;
- knownRegions = (
- en,
- );
- mainGroup = 592FF8AA18ECBD3600C164F8;
- productRefGroup = 592FF8B618ECBD7600C164F8 /* Products */;
- projectDirPath = "";
- projectRoot = "";
- targets = (
- 592FF8B418ECBD7600C164F8 /* CameraExample */,
- );
- };
-/* End PBXProject section */
-
-/* Begin PBXResourcesBuildPhase section */
- 592FF8B318ECBD7600C164F8 /* Resources */ = {
- isa = PBXResourcesBuildPhase;
- buildActionMask = 2147483647;
- files = (
- 591D3EDC1CFFA83A0059011C /* tensorflow_inception_graph.pb in Resources */,
- 592FF90D18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard in Resources */,
- 591D3EDB1CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt in Resources */,
- );
- runOnlyForDeploymentPostprocessing = 0;
- };
-/* End PBXResourcesBuildPhase section */
-
-/* Begin PBXSourcesBuildPhase section */
- 592FF8B118ECBD7600C164F8 /* Sources */ = {
- isa = PBXSourcesBuildPhase;
- buildActionMask = 2147483647;
- files = (
- 592FF90218ECC66200C164F8 /* main.mm in Sources */,
- 591D3ED21CFF85C30059011C /* ios_image_load.mm in Sources */,
- 592FF92618EE240200C164F8 /* CameraExampleViewController.mm in Sources */,
- 592FF92518EE240200C164F8 /* CameraExampleAppDelegate.m in Sources */,
- 591D3ED51CFF85FD0059011C /* tensorflow_utils.mm in Sources */,
- );
- runOnlyForDeploymentPostprocessing = 0;
- };
-/* End PBXSourcesBuildPhase section */
-
-/* Begin PBXVariantGroup section */
- 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */ = {
- isa = PBXVariantGroup;
- children = (
- 592FF90B18EDD0DA00C164F8 /* en */,
- );
- name = MainStoryboard_iPhone.storyboard;
- sourceTree = "<group>";
- };
-/* End PBXVariantGroup section */
-
-/* Begin XCBuildConfiguration section */
- 592FF8AF18ECBD3600C164F8 /* Debug */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- CLANG_WARN_BOOL_CONVERSION = YES;
- CLANG_WARN_CONSTANT_CONVERSION = YES;
- CLANG_WARN_EMPTY_BODY = YES;
- CLANG_WARN_ENUM_CONVERSION = YES;
- CLANG_WARN_INFINITE_RECURSION = YES;
- CLANG_WARN_INT_CONVERSION = YES;
- CLANG_WARN_SUSPICIOUS_MOVE = YES;
- CLANG_WARN_UNREACHABLE_CODE = YES;
- CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
- ENABLE_STRICT_OBJC_MSGSEND = YES;
- ENABLE_TESTABILITY = YES;
- GCC_NO_COMMON_BLOCKS = YES;
- GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_RETURN_TYPE = YES;
- GCC_WARN_UNDECLARED_SELECTOR = YES;
- GCC_WARN_UNINITIALIZED_AUTOS = YES;
- GCC_WARN_UNUSED_FUNCTION = YES;
- GCC_WARN_UNUSED_VARIABLE = YES;
- ONLY_ACTIVE_ARCH = YES;
- };
- name = Debug;
- };
- 592FF8B018ECBD3600C164F8 /* Release */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- CLANG_WARN_BOOL_CONVERSION = YES;
- CLANG_WARN_CONSTANT_CONVERSION = YES;
- CLANG_WARN_EMPTY_BODY = YES;
- CLANG_WARN_ENUM_CONVERSION = YES;
- CLANG_WARN_INFINITE_RECURSION = YES;
- CLANG_WARN_INT_CONVERSION = YES;
- CLANG_WARN_SUSPICIOUS_MOVE = YES;
- CLANG_WARN_UNREACHABLE_CODE = YES;
- CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
- ENABLE_STRICT_OBJC_MSGSEND = YES;
- GCC_NO_COMMON_BLOCKS = YES;
- GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_RETURN_TYPE = YES;
- GCC_WARN_UNDECLARED_SELECTOR = YES;
- GCC_WARN_UNINITIALIZED_AUTOS = YES;
- GCC_WARN_UNUSED_FUNCTION = YES;
- GCC_WARN_UNUSED_VARIABLE = YES;
- };
- name = Release;
- };
- 592FF8DF18ECBD7600C164F8 /* Debug */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- ALWAYS_SEARCH_USER_PATHS = NO;
- ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
- ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
- CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
- CLANG_CXX_LIBRARY = "compiler-default";
- CLANG_ENABLE_MODULES = YES;
- CLANG_ENABLE_OBJC_ARC = YES;
- CLANG_WARN_BOOL_CONVERSION = YES;
- CLANG_WARN_CONSTANT_CONVERSION = YES;
- CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
- CLANG_WARN_EMPTY_BODY = YES;
- CLANG_WARN_ENUM_CONVERSION = YES;
- CLANG_WARN_INT_CONVERSION = YES;
- CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
- CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
- "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
- COPY_PHASE_STRIP = NO;
- ENABLE_BITCODE = NO;
- FRAMEWORK_SEARCH_PATHS = "$(inherited)";
- GCC_C_LANGUAGE_STANDARD = gnu99;
- GCC_DYNAMIC_NO_PIC = NO;
- GCC_OPTIMIZATION_LEVEL = 0;
- GCC_PRECOMPILE_PREFIX_HEADER = YES;
- GCC_PREFIX_HEADER = "";
- GCC_PREPROCESSOR_DEFINITIONS = (
- "DEBUG=1",
- "$(inherited)",
- );
- GCC_SYMBOLS_PRIVATE_EXTERN = NO;
- GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
- GCC_WARN_UNDECLARED_SELECTOR = YES;
- GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
- GCC_WARN_UNUSED_FUNCTION = YES;
- GCC_WARN_UNUSED_VARIABLE = YES;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../../..",
- );
- INFOPLIST_FILE = "$(SRCROOT)/Info.plist";
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/lib",
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- );
- ONLY_ACTIVE_ARCH = NO;
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- );
- PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
- PRODUCT_NAME = "$(TARGET_NAME)";
- SDKROOT = iphoneos;
- TARGETED_DEVICE_FAMILY = "1,2";
- VALID_ARCHS = "arm64 armv7 armv7s";
- WRAPPER_EXTENSION = app;
- };
- name = Debug;
- };
- 592FF8E018ECBD7600C164F8 /* Release */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- ALWAYS_SEARCH_USER_PATHS = NO;
- ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
- ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
- CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
- CLANG_CXX_LIBRARY = "compiler-default";
- CLANG_ENABLE_MODULES = YES;
- CLANG_ENABLE_OBJC_ARC = YES;
- CLANG_WARN_BOOL_CONVERSION = YES;
- CLANG_WARN_CONSTANT_CONVERSION = YES;
- CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
- CLANG_WARN_EMPTY_BODY = YES;
- CLANG_WARN_ENUM_CONVERSION = YES;
- CLANG_WARN_INT_CONVERSION = YES;
- CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
- CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
- "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
- COPY_PHASE_STRIP = YES;
- ENABLE_BITCODE = NO;
- ENABLE_NS_ASSERTIONS = NO;
- FRAMEWORK_SEARCH_PATHS = "$(inherited)";
- GCC_C_LANGUAGE_STANDARD = gnu99;
- GCC_PRECOMPILE_PREFIX_HEADER = YES;
- GCC_PREFIX_HEADER = "";
- GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
- GCC_WARN_UNDECLARED_SELECTOR = YES;
- GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
- GCC_WARN_UNUSED_FUNCTION = YES;
- GCC_WARN_UNUSED_VARIABLE = YES;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../../..",
- );
- INFOPLIST_FILE = "$(SRCROOT)/Info.plist";
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/lib",
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- );
- ONLY_ACTIVE_ARCH = NO;
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- );
- PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
- PRODUCT_NAME = "$(TARGET_NAME)";
- SDKROOT = iphoneos;
- TARGETED_DEVICE_FAMILY = "1,2";
- VALIDATE_PRODUCT = YES;
- VALID_ARCHS = "arm64 armv7 armv7s";
- WRAPPER_EXTENSION = app;
- };
- name = Release;
- };
-/* End XCBuildConfiguration section */
-
-/* Begin XCConfigurationList section */
- 592FF8AE18ECBD3600C164F8 /* Build configuration list for PBXProject "camera_example" */ = {
- isa = XCConfigurationList;
- buildConfigurations = (
- 592FF8AF18ECBD3600C164F8 /* Debug */,
- 592FF8B018ECBD3600C164F8 /* Release */,
- );
- defaultConfigurationIsVisible = 0;
- defaultConfigurationName = Release;
- };
- 592FF8E318ECBD7600C164F8 /* Build configuration list for PBXNativeTarget "CameraExample" */ = {
- isa = XCConfigurationList;
- buildConfigurations = (
- 592FF8DF18ECBD7600C164F8 /* Debug */,
- 592FF8E018ECBD7600C164F8 /* Release */,
- );
- defaultConfigurationIsVisible = 0;
- defaultConfigurationName = Release;
- };
-/* End XCConfigurationList section */
- };
- rootObject = 592FF8AB18ECBD3600C164F8 /* Project object */;
-}
diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
index 219473153b..72df272af8 100644
--- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
+++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc
@@ -41,13 +41,7 @@ class ColumnInterface {
virtual int64 FeatureCount(int64 batch) const = 0;
// Returns the fingerprint of nth feature from the specified batch.
- InternalType Feature(int64 batch, int64 n) const {
- InternalType not_used = InternalType();
- return DoFeature(batch, n, not_used);
- }
-
- virtual InternalType DoFeature(int64 batch, int64 n,
- InternalType not_used) const = 0;
+ virtual InternalType Feature(int64 batch, int64 n) const = 0;
virtual ~ColumnInterface() {}
};
@@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
return feature_counts_[batch];
}
- // InternalType is int64 only when using HashCrosser.
- int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
- const int64 start = feature_start_indices_[batch];
- if (DT_STRING == values_.dtype())
- return Fingerprint64(values_.vec<string>().data()[start + n]);
- return values_.vec<int64>().data()[start + n];
- }
-
- // InternalType is string or StringPiece when using StringCrosser.
- string DoFeature(int64 batch, int64 n, string not_used) const {
- const int64 start = feature_start_indices_[batch];
- if (DT_STRING == values_.dtype())
- return values_.vec<string>().data()[start + n];
- return std::to_string(values_.vec<int64>().data()[start + n]);
- }
-
- StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
- const int64 start = feature_start_indices_[batch];
- return values_.vec<string>().data()[start + n];
- }
+ InternalType Feature(int64 batch, int64 n) const override;
~SparseTensorColumn() override {}
@@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
std::vector<int64> feature_start_indices_;
};
+// InternalType is int64 only when using HashCrosser.
+template <>
+int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ if (DT_STRING == values_.dtype())
+ return Fingerprint64(values_.vec<string>().data()[start + n]);
+ return values_.vec<int64>().data()[start + n];
+}
+
+// InternalType is string or StringPiece when using StringCrosser.
+template <>
+string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ if (DT_STRING == values_.dtype())
+ return values_.vec<string>().data()[start + n];
+ return std::to_string(values_.vec<int64>().data()[start + n]);
+}
+
+template <>
+StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
+ int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ return values_.vec<string>().data()[start + n];
+}
+
// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
@@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
- // InternalType is int64 only when using HashCrosser.
- int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
- if (DT_STRING == tensor_.dtype())
- return Fingerprint64(tensor_.matrix<string>()(batch, n));
- return tensor_.matrix<int64>()(batch, n);
- }
-
- // Internal type is string or StringPiece when using StringCrosser.
- string DoFeature(int64 batch, int64 n, string not_used) const {
- if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
- return std::to_string(tensor_.matrix<int64>()(batch, n));
- }
-
- StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
- return tensor_.matrix<string>()(batch, n);
- }
+ InternalType Feature(int64 batch, int64 n) const override;
~DenseTensorColumn() override {}
@@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
const Tensor& tensor_;
};
+// InternalType is int64 only when using HashCrosser.
+template <>
+int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
+ if (DT_STRING == tensor_.dtype())
+ return Fingerprint64(tensor_.matrix<string>()(batch, n));
+ return tensor_.matrix<int64>()(batch, n);
+}
+
+// Internal type is string or StringPiece when using StringCrosser.
+template <>
+string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
+ if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
+ return std::to_string(tensor_.matrix<int64>()(batch, n));
+}
+
+template <>
+StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
+ int64 n) const {
+ return tensor_.matrix<string>()(batch, n);
+}
+
// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index ddd3d087e7..b87b75d5c4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -209,7 +209,9 @@ def _get_replica_device_setter(config):
"""
ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
- 'MutableHashTableOfTensors', 'MutableDenseHashTable'
+ 'MutableHashTableV2', 'MutableHashTableOfTensors',
+ 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
+ 'MutableDenseHashTableV2'
]
if config.task_type:
@@ -955,6 +957,7 @@ class BaseEstimator(
self._check_inputs(features, labels)
model_fn_ops = self._get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
+ all_hooks.extend(hooks)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
@@ -964,7 +967,6 @@ class BaseEstimator(
},
every_n_iter=100)
])
- all_hooks.extend(hooks)
scaffold = model_fn_ops.scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 7600d30539..d5d413c56a 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -18,807 +18,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import functools
-
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import lookup_ops
+# pylint: disable=unused-import
+from tensorflow.python.ops.lookup_ops import FastHashSpec
+from tensorflow.python.ops.lookup_ops import HasherSpec
+from tensorflow.python.ops.lookup_ops import HashTable
+from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets
+from tensorflow.python.ops.lookup_ops import index_table_from_file
+from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file
+from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase
+from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer
+from tensorflow.python.ops.lookup_ops import LookupInterface
+from tensorflow.python.ops.lookup_ops import StrongHashSpec
+from tensorflow.python.ops.lookup_ops import TableInitializerBase
+from tensorflow.python.ops.lookup_ops import TextFileIdTableInitializer
+from tensorflow.python.ops.lookup_ops import TextFileIndex
+from tensorflow.python.ops.lookup_ops import TextFileInitializer
+from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer
+# pylint: enable=unused-import
from tensorflow.python.training.saver import BaseSaverBuilder
-from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
-class LookupInterface(object):
- """Represent a lookup table that persists across different steps."""
-
- def __init__(self, key_dtype, value_dtype, name):
- """Construct a lookup table interface.
-
- Args:
- key_dtype: The table key type.
- value_dtype: The table value type.
- name: A name for the operation (optional).
- """
- self._key_dtype = dtypes.as_dtype(key_dtype)
- self._value_dtype = dtypes.as_dtype(value_dtype)
- self._name = name
-
- @property
- def key_dtype(self):
- """The table key dtype."""
- return self._key_dtype
-
- @property
- def value_dtype(self):
- """The table value dtype."""
- return self._value_dtype
-
- @property
- def name(self):
- """The name of the table."""
- return self._name
-
- @property
- def init(self):
- """The table initialization op."""
- raise NotImplementedError
-
- def size(self, name=None):
- """Compute the number of elements in this table."""
- raise NotImplementedError
-
- def lookup(self, keys, name=None):
- """Looks up `keys` in a table, outputs the corresponding values."""
- raise NotImplementedError
-
- def check_table_dtypes(self, key_dtype, value_dtype):
- """Check that the given key_dtype and value_dtype matches the table dtypes.
-
- Args:
- key_dtype: The key data type to check.
- value_dtype: The value data type to check.
-
- Raises:
- TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
- types.
- """
- if key_dtype != self.key_dtype:
- raise TypeError("Invalid key dtype, expected %s but got %s." %
- (self.key_dtype, key_dtype))
- if value_dtype != self.value_dtype:
- raise TypeError("Invalid value dtype, expected %s but got %s." %
- (self.value_dtype, value_dtype))
-
-
-class InitializableLookupTableBase(LookupInterface):
- """Initializable lookup table interface.
-
- An initializable lookup tables persist across different steps.
- """
-
- def __init__(self, table_ref, default_value, initializer):
- """Construct a table object from a table reference.
-
- If requires a table initializer object (subclass of `TableInitializerBase`).
- It provides the table key and value types, as well as the op to initialize
- the table. The caller is responsible to execute the initialization op.
-
- Args:
- table_ref: The table reference, i.e. the output of the lookup table ops.
- default_value: The value to use if a key is missing in the table.
- initializer: The table initializer to use.
- """
- super(InitializableLookupTableBase, self).__init__(
- initializer.key_dtype, initializer.value_dtype,
- table_ref.op.name.split("/")[-1])
- self._table_ref = table_ref
- self._default_value = ops.convert_to_tensor(default_value,
- dtype=self._value_dtype)
- self._default_value.get_shape().merge_with(tensor_shape.scalar())
- self._init = initializer.initialize(self)
-
- @property
- def table_ref(self):
- """Get the underlying table reference."""
- return self._table_ref
-
- @property
- def default_value(self):
- """The default value of the table."""
- return self._default_value
-
- @property
- def init(self):
- """The table initialization op."""
- return self._init
-
- def size(self, name=None):
- """Compute the number of elements in this table.
-
- Args:
- name: A name for the operation (optional).
-
- Returns:
- A scalar tensor containing the number of elements in this table.
- """
- with ops.name_scope(name, "%s_Size" % self._name,
- [self._table_ref]) as scope:
- # pylint: disable=protected-access
- return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope)
- # pylint: enable=protected-access
-
- def lookup(self, keys, name=None):
- """Looks up `keys` in a table, outputs the corresponding values.
-
- The `default_value` is used for keys not present in the table.
-
- Args:
- keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
- name: A name for the operation (optional).
-
- Returns:
- A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
-
- Raises:
- TypeError: when `keys` or `default_value` doesn't match the table data
- types.
- """
- key_tensor = keys
- if isinstance(keys, sparse_tensor.SparseTensor):
- key_tensor = keys.values
-
- if keys.dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
- with ops.name_scope(
- name, "%s_Lookup" % self._name,
- (self._table_ref, key_tensor, self._default_value)) as scope:
- # pylint: disable=protected-access
- values = gen_lookup_ops._lookup_table_find(
- self._table_ref, key_tensor, self._default_value, name=scope)
- # pylint: enable=protected-access
-
- values.set_shape(key_tensor.get_shape())
- if isinstance(keys, sparse_tensor.SparseTensor):
- return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
- else:
- return values
-
-
-class HashTable(InitializableLookupTableBase):
- """A generic hash table implementation.
-
- Example usage:
-
- ```python
- table = tf.contrib.lookup.HashTable(
- tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)
- out = table.lookup(input_tensor).
- table.init.run()
- print out.eval()
- ```
- """
-
- def __init__(self, initializer, default_value, shared_name=None, name=None):
- """Creates a non-initialized `HashTable` object.
-
- Creates a table, the type of its keys and values are specified by the
- initializer.
- Before using the table you will have to initialize it. After initialization
- the table will be immutable.
-
- Args:
- initializer: The table initializer to use. See `HashTable` kernel for
- supported key and value types.
- default_value: The value to use if a key is missing in the table.
- shared_name: If non-empty, this table will be shared under
- the given name across multiple sessions.
- name: A name for the operation (optional).
-
- Returns:
- A `HashTable` object.
- """
- with ops.name_scope(
- name, "hash_table", (initializer, default_value)) as scope:
- # pylint: disable=protected-access
- table_ref = gen_lookup_ops._hash_table(
- shared_name=shared_name,
- key_dtype=initializer.key_dtype,
- value_dtype=initializer.value_dtype,
- name=scope)
- # pylint: enable=protected-access
-
- super(HashTable, self).__init__(table_ref, default_value, initializer)
-
-
-class TableInitializerBase(object):
- """Base class for lookup table initializers."""
-
- def __init__(self, key_dtype, value_dtype):
- """Construct a table initializer object.
-
- Args:
- key_dtype: Type of the table keys.
- value_dtype: Type of the table values.
- """
- self._key_dtype = dtypes.as_dtype(key_dtype)
- self._value_dtype = dtypes.as_dtype(value_dtype)
-
- @property
- def key_dtype(self):
- """The expected table key dtype."""
- return self._key_dtype
-
- @property
- def value_dtype(self):
- """The expected table value dtype."""
- return self._value_dtype
-
- def initialize(self, table):
- """Returns the table initialization op."""
- raise NotImplementedError
-
-
-class KeyValueTensorInitializer(TableInitializerBase):
- """Table initializers given `keys` and `values` tensors."""
-
- def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
- """Constructs a table initializer object based on keys and values tensors.
-
- Args:
- keys: The tensor for the keys.
- values: The tensor for the values.
- key_dtype: The `keys` data type. Used when `keys` is a python array.
- value_dtype: The `values` data type. Used when `values` is a python array.
- name: A name for the operation (optional).
- """
- with ops.name_scope(name, "key_value_init", [keys, values]) as scope:
- self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
- self._values = ops.convert_to_tensor(values,
- dtype=value_dtype,
- name="values")
- self._name = scope
-
- super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
- self._values.dtype)
-
- def initialize(self, table):
- """Initializes the given `table` with `keys` and `values` tensors.
-
- Args:
- table: The table to initialize.
-
- Returns:
- The operation that initializes the table.
-
- Raises:
- TypeError: when the keys and values data types do not match the table
- key and value data types.
- """
- table.check_table_dtypes(self._keys.dtype, self._values.dtype)
- with ops.name_scope(
- self._name,
- values=(table.table_ref, self._keys, self._values)) as scope:
- # pylint: disable=protected-access
- init_op = gen_lookup_ops._initialize_table(
- table.table_ref, self._keys, self._values, name=scope)
- # pylint: enable=protected-access
- ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
- return init_op
-
-
-class TextFileIndex(object):
- WHOLE_LINE = -2
- LINE_NUMBER = -1
-
-
-class TextFileInitializer(TableInitializerBase):
- """Table initializers from a text file.
-
- This initializer assigns one entry in the table for each line in the file.
-
- The key and value type of the table to initialize is given by `key_dtype` and
- `value_dtype`.
-
- The key and value content to get from each line is specified by
- the `key_index` and `value_index`.
-
- * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
- expects data type int64.
- * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
- type string.
- * A value `>=0` means use the index (starting at zero) of the split line based
- on `delimiter`.
-
- For example if we have a file with the following content:
-
- ```
- emerson 10
- lake 20
- palmer 30
- ```
-
- The following snippet initializes a table with the first column as keys and
- second column as values:
-
- * `emerson -> 10`
- * `lake -> 20`
- * `palmer -> 30`
-
- ```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
- ...
- table.init.run()
- ```
-
- Similarly to initialize the whole line as keys and the line number as values.
-
- * `emerson 10 -> 0`
- * `lake 20 -> 1`
- * `palmer 30 -> 2`
-
- ```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
- tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
- ...
- table.init.run()
- ```
- """
-
- def __init__(self,
- filename,
- key_dtype,
- key_index,
- value_dtype,
- value_index,
- vocab_size=None,
- delimiter="\t",
- name=None):
- """Constructs a table initializer object to populate from a text file.
-
- It generates one key-value pair per line. The type of table key and
- value are specified by `key_dtype` and `value_dtype`, respectively.
- Similarly the content of the key and value are specified by the key_index
- and value_index.
-
- - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
- expects data type int64.
- - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
- type string.
- - A value >=0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
- Args:
- filename: The filename of the text file to be used for initialization.
- The path must be accessible from wherever the graph is initialized
- (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
- key_dtype: The `key` data type.
- key_index: the index that represents information of a line to get the
- table 'key' values from.
- value_dtype: The `value` data type.
- value_index: the index that represents information of a line to get the
- table 'value' values from.'
- vocab_size: The number of elements in the file, if known.
- delimiter: The delimiter to separate fields in a line.
- name: A name for the operation (optional).
-
- Raises:
- ValueError: when the filename is empty, or when the table key and value
- data types do not match the expected data types.
- """
- if not isinstance(filename, ops.Tensor) and not filename:
- raise ValueError("Filename required for %s." % name)
-
- key_dtype = dtypes.as_dtype(key_dtype)
- value_dtype = dtypes.as_dtype(value_dtype)
-
- if key_index < -2:
- raise ValueError("Invalid key index %s." % (key_index))
-
- if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
- raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
- (dtypes.int64, key_dtype))
- if ((key_index == TextFileIndex.WHOLE_LINE) and
- (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
- raise ValueError(
- "Signature mismatch. Keys must be integer or string, got %s." %
- key_dtype)
- if value_index < -2:
- raise ValueError("Invalid value index %s." % (value_index))
-
- if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
- raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
- (dtypes.int64, value_dtype))
- if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
- raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
- (dtypes.string, value_dtype))
-
- if (vocab_size is not None) and (vocab_size <= 0):
- raise ValueError("Invalid vocab_size %s." % vocab_size)
-
- self._filename = filename
- self._key_index = key_index
- self._value_index = value_index
- self._vocab_size = vocab_size
- self._delimiter = delimiter
- self._name = name
-
- super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
-
- def initialize(self, table):
- """Initializes the table from a text file.
-
- Args:
- table: The table to be initialized.
-
- Returns:
- The operation that initializes the table.
-
- Raises:
- TypeError: when the keys and values data types do not match the table
- key and value data types.
- """
- table.check_table_dtypes(self.key_dtype, self.value_dtype)
- with ops.name_scope(
- self._name, "text_file_init", (table.table_ref,)) as scope:
- filename = ops.convert_to_tensor(self._filename,
- dtypes.string,
- name="asset_filepath")
- # pylint: disable=protected-access
- init_op = gen_lookup_ops._initialize_table_from_text_file(
- table.table_ref,
- filename,
- self._key_index,
- self._value_index,
- -1 if self._vocab_size is None else self._vocab_size,
- self._delimiter,
- name=scope)
- # pylint: enable=protected-access
- ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
- # If the filename tensor is anything other than a string constant (e.g., if
- # it is a placeholder) then it does not make sense to track it as an asset.
- if constant_op.is_constant(filename):
- ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
- return init_op
-
-
-class TextFileStringTableInitializer(TextFileInitializer):
- """Table initializer for `int64` IDs to string tables from a text file."""
-
- def __init__(self,
- filename,
- key_column_index=TextFileIndex.LINE_NUMBER,
- value_column_index=TextFileIndex.WHOLE_LINE,
- vocab_size=None,
- delimiter="\t",
- name="text_file_string_table_init"):
- """Constructs an initializer for an id-to-string table from a text file.
-
- It populates a table that its key and value types are int64 and string,
- respectively. It generates one key-value pair per line.
- The content of the key and value are specified by `key_column_index`
- and `value_column_index`.
-
- - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
- expects data type int64.
- - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
- type string.
- - A value >=0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
- Args:
- filename: The filename of the text file to be used for initialization.
- The path must be accessible from wherever the graph is initialized
- (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
- key_column_index: The column index from the text file to get the keys
- from. The default is 0 that represents the whole line content.
- value_column_index: The column index from the text file to get the
- values from. The default is to use the line number, starting from zero.
- vocab_size: The number of elements in the file, if known.
- delimiter: The delimiter to separate fields in a line.
- name: Optional name for the op.
-
- Raises:
- TypeError: when the filename is empty, or when the table key and value
- data types do not match the expected data types.
- """
- super(TextFileStringTableInitializer, self).__init__(filename,
- dtypes.int64,
- key_column_index,
- dtypes.string,
- value_column_index,
- vocab_size=vocab_size,
- delimiter=delimiter,
- name=name)
-
-
-class TextFileIdTableInitializer(TextFileInitializer):
- """Table initializer for string to `int64` IDs tables from a text file."""
-
- def __init__(self,
- filename,
- key_column_index=TextFileIndex.WHOLE_LINE,
- value_column_index=TextFileIndex.LINE_NUMBER,
- vocab_size=None,
- delimiter="\t",
- name="text_file_id_table_init",
- key_dtype=dtypes.string):
- """Constructs an initializer for an string-to-id table from a text file.
-
- It populates a table that its key and value types are string and int64,
- respectively. It generates one key-value pair per line.
- The content of the key and value are specified by the key_index
- and value_index.
-
- - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
- expects data type int64.
- - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
- type string.
- - A value >=0 means use the index (starting at zero) of the split line based
- on `delimiter`.
-
- Args:
- filename: The filename of the text file to be used for initialization.
- The path must be accessible from wherever the graph is initialized
- (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
- key_column_index: The column index from the text file to get the `key`
- values from. The default is to use the line number, starting from zero.
- value_column_index: The column index from the text file ro get the `value`
- values from. The default is 0 that represents the whole line content.
- vocab_size: The number of elements in the file, if known.
- delimiter: The delimiter to separate fields in a line.
- name: Optional name for the op.
- key_dtype: The `key` data type.
-
- Raises:
- TypeError: when the filename is empty, or when the table key and value
- data types do not match the expected data types.
- """
- super(TextFileIdTableInitializer, self).__init__(filename,
- key_dtype,
- key_column_index,
- dtypes.int64,
- value_column_index,
- vocab_size=vocab_size,
- delimiter=delimiter,
- name=name)
-
-
-class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
- """A structure for the spec of the hashing function to use for hash buckets.
-
- `hasher` is the name of the hashing function to use (eg. "fasthash",
- "stronghash").
- `key` is optional and specify the key to use for the hash function if
- supported, currently only used by a strong hash.
-
- Fields:
- hasher: The hasher name to use.
- key: The key to be used by the hashing function, if required.
- """
- __slots__ = ()
-
-
-FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name
-
-
-class StrongHashSpec(HasherSpec):
- """A structure to specify a key of the strong keyed hash spec.
-
- The strong hash requires a `key`, which is a list of 2 unsigned integer
- numbers. These should be non-zero; random numbers generated from random.org
- would be a fine choice.
-
- Fields:
- key: The key to be used by the keyed hashing function.
- """
- __slots__ = ()
-
- def __new__(cls, key):
- if len(key) != 2:
- raise ValueError("key must have size 2, got %s." % len(key))
-
- if not isinstance(key[0], compat.integral_types) or not isinstance(
- key[1], compat.integral_types):
- raise TypeError("Invalid key %s. Must be unsigned integer values." % key)
-
- return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)
-
-
-def _as_string(tensor):
- if dtypes.string == tensor.dtype.base_dtype:
- return tensor
- return string_ops.as_string(tensor)
-
-
-class IdTableWithHashBuckets(LookupInterface):
- """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
-
- For example, if an instance of `IdTableWithHashBuckets` is initialized with a
- string-to-id table that maps:
-
- - emerson -> 0
- - lake -> 1
- - palmer -> 2
-
- The `IdTableWithHashBuckets` object will performs the following mapping:
-
- - emerson -> 0
- - lake -> 1
- - palmer -> 2
- - <other term> -> bucket id between 3 and 3 + num_oov_buckets, calculated by:
- hash(<term>) % num_oov_buckets + vocab_size
-
- If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
- the lookup result is `[0, 1, 2, 4, 7]`
-
- If `table` is None, only out-of-vocabulary buckets are used.
-
- Example usage:
-
- ```python
- num_oov_buckets = 3
- input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
- table = tf.IdTableWithHashBuckets(
- tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),
- num_oov_buckets)
- out = table.lookup(input_tensor).
- table.init.run()
- print out.eval()
- ```
-
- The hash function used for generating out-of-vocabulary buckets ID is handled
- by `hasher_spec`.
- """
-
- def __init__(self,
- table,
- num_oov_buckets,
- hasher_spec=FastHashSpec,
- name=None,
- key_dtype=None):
- """Construct a `IdTableWithHashBuckets` object.
-
- Args:
- table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
- num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
- hasher_spec: A `HasherSpec` to specify the hash function to use for
- assignation of out-of-vocabulary buckets (optional).
- name: A name for the operation (optional).
- key_dtype: Data type of keys passed to `lookup`. Defaults to
- `table.key_dtype` if `table` is specified, otherwise `tf.string`.
- Must be string or integer, and must be castable to `table.key_dtype`.
-
- Raises:
- ValueError: when `table` in None and `num_oov_buckets` is not positive.
- TypeError: when `hasher_spec` is invalid.
- """
- # If a name ends with a '/' it is a "name scope", remove all trailing '/'
- # characters to use as table name.
- if name:
- name = name.rstrip("/")
- if table:
- if key_dtype is None:
- key_dtype = table.key_dtype
- supported_table_key_dtypes = (dtypes.int64, dtypes.string)
- if table.key_dtype not in supported_table_key_dtypes:
- raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
- (supported_table_key_dtypes, key_dtype))
- if table.key_dtype.is_integer != key_dtype.is_integer:
- raise TypeError("Invalid key dtype, expected %s but got %s." %
- ("integer" if key_dtype.is_integer else "non-integer",
- table.key_dtype))
- if table.value_dtype != dtypes.int64:
- raise TypeError("Invalid value dtype, expected %s but got %s." %
- (dtypes.int64, table.value_dtype))
- self._table = table
- name = name or self._table.name
- else:
- if num_oov_buckets <= 0:
- raise ValueError("oov_buckets must be > 0 if no table is supplied.")
- key_dtype = dtypes.string if key_dtype is None else key_dtype
- self._table = None
- name = name or "hash_bucket"
- if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
- raise TypeError(
- "Invalid key_dtype, expected integer or string, got %s." % key_dtype)
- self._num_oov_buckets = num_oov_buckets
-
- if not isinstance(hasher_spec, HasherSpec):
- raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
- hasher_spec)
- self._hasher_spec = hasher_spec
- super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
- name.split("/")[-1])
-
- @property
- def init(self):
- """The table initialization op."""
- if self._table:
- return self._table.init
- with ops.name_scope(None, "init"):
- return control_flow_ops.no_op()
-
- def size(self, name=None):
- """Compute the number of elements in this table."""
- with ops.name_scope(name, "%s_Size" % self.name) as scope:
- if self._table:
- tsize = self._table.size(scope)
- else:
- tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
- return tsize + self._num_oov_buckets
-
- def _get_string_to_hash_bucket_fn(self, hasher_spec):
- """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
- if not isinstance(hasher_spec, HasherSpec):
- raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
- if hasher_spec.hasher == "fasthash":
- return string_ops.string_to_hash_bucket_fast
- if hasher_spec.hasher == "legacy":
- return string_ops.string_to_hash_bucket
- if hasher_spec.hasher == "stronghash":
- return functools.partial(
- string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
- raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
-
- def lookup(self, keys, name=None):
- """Looks up `keys` in the table, outputs the corresponding values.
-
- It assigns out-of-vocabulary keys to buckets based in their hashes.
-
- Args:
- keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
- name: Optional name for the op.
-
- Returns:
- A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
-
- Raises:
- TypeError: when `keys` doesn't match the table key data type.
- """
- if keys.dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
- values = keys
- if isinstance(keys, sparse_tensor.SparseTensor):
- values = keys.values
- if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
- values = math_ops.to_int64(values)
-
- if self._num_oov_buckets == 0:
- ids = self._table.lookup(values, name=name)
- else:
- # TODO(yleon): Consider moving this functionality to its own kernel.
- with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
- str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
- self._hasher_spec)
- buckets = str_to_hash_bucket(
- _as_string(values),
- num_buckets=self._num_oov_buckets,
- name="hash_bucket")
- if self._table:
- ids = self._table.lookup(values)
- buckets = math_ops.add(buckets, self._table.size())
- is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
- ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
- else:
- ids = buckets
- if isinstance(keys, sparse_tensor.SparseTensor):
- return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
- return ids
-
-
@deprecated("2017-04-10", "Use `index_table_from_file`.")
def string_to_index_table_from_file(vocabulary_file=None,
num_oov_buckets=0,
@@ -831,113 +56,6 @@ def string_to_index_table_from_file(vocabulary_file=None,
key_dtype=dtypes.string, name=name)
-def index_table_from_file(vocabulary_file=None,
- num_oov_buckets=0,
- vocab_size=None,
- default_value=-1,
- hasher_spec=FastHashSpec,
- key_dtype=dtypes.string,
- name=None):
- """Returns a lookup table that converts a string tensor into int64 IDs.
-
- This operation constructs a lookup table to convert tensor of strings into
- int64 IDs. The mapping can be initialized from a vocabulary file specified in
- `vocabulary_file`, where the whole line is the key and the zero-based line
- number is the ID.
-
- Any lookup of an out-of-vocabulary token will return a bucket ID based on its
- hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
- `default_value`.
- The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`.
-
- The underlying table must be initialized by calling
- `tf.tables_initializer.run()` or `table.init.run()` once.
-
- Sample Usages:
-
- If we have a vocabulary file "test.txt" with the following content:
-
- ```
- emerson
- lake
- palmer
- ```
-
- ```python
- features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.index_table_from_file(
- vocabulary_file="test.txt", num_oov_buckets=1)
- ids = table.lookup(features)
- ...
- tf.tables_initializer().run()
-
- ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket
- ```
-
- Args:
- vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
- num_oov_buckets: The number of out-of-vocabulary buckets.
- vocab_size: Number of the elements in the vocabulary, if known.
- default_value: The value to use for out-of-vocabulary feature values.
- Defaults to -1.
- hasher_spec: A `HasherSpec` to specify the hash function to use for
- assignation of out-of-vocabulary buckets.
- key_dtype: The `key` data type.
- name: A name for this op (optional).
-
- Returns:
- The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.
-
- Raises:
- ValueError: If `vocabulary_file` is not set.
- ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
- than zero.
- """
- if vocabulary_file is None or (
- isinstance(vocabulary_file, str) and not vocabulary_file):
- raise ValueError("vocabulary_file must be specified and must not be empty.")
- if num_oov_buckets < 0:
- raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
- % num_oov_buckets)
- if vocab_size is not None and vocab_size < 1:
- raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
- if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
- raise TypeError("Only integer and string keys are supported.")
-
- with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
- table = None
- shared_name = ""
- with ops.name_scope(None, "hash_table") as hash_table_scope:
- if vocab_size:
- # Keep the shared_name:
- # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
- shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size,
- TextFileIndex.WHOLE_LINE,
- TextFileIndex.LINE_NUMBER)
- else:
- # Keep the shared_name
- # <table_type>_<filename>_<key_index>_<value_index>
- shared_name = "hash_table_%s_%s_%s" % (vocabulary_file,
- TextFileIndex.WHOLE_LINE,
- TextFileIndex.LINE_NUMBER)
- init = TextFileIdTableInitializer(
- vocabulary_file, vocab_size=vocab_size,
- key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
- name="table_init")
-
- table = HashTable(
- init, default_value, shared_name=shared_name, name=hash_table_scope)
- if num_oov_buckets:
- table = IdTableWithHashBuckets(
- table,
- num_oov_buckets=num_oov_buckets,
- hasher_spec=hasher_spec,
- name=feat_to_id_scope,
- key_dtype=key_dtype)
-
- return table
-
-
@deprecated("2017-04-10", "Use `index_table_from_tensor`.")
def string_to_index_table_from_tensor(mapping,
num_oov_buckets=0,
@@ -1011,41 +129,13 @@ def index_table_from_tensor(mapping,
"""
if mapping is None:
raise ValueError("mapping must be specified.")
-
- if num_oov_buckets < 0:
- raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
- % num_oov_buckets)
-
- if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
- raise TypeError("Only integer and string keys are supported.")
-
- with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
- keys = ops.convert_to_tensor(mapping)
- if keys.dtype.is_integer != dtype.is_integer:
- raise ValueError("Expected %s, got %s." % (
- "integer" if dtype.is_integer else "non-integer", keys.dtype))
- if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
- raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
- num_elements = array_ops.size(keys)
- values = math_ops.to_int64(math_ops.range(num_elements))
-
- shared_name = ""
- with ops.name_scope(None, "hash_table") as hash_table_scope:
- table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
- init = KeyValueTensorInitializer(
- table_keys, values, table_keys.dtype.base_dtype, dtypes.int64,
- name="table_init")
- table = HashTable(
- init, default_value, shared_name=shared_name, name=hash_table_scope)
- if num_oov_buckets:
- table = IdTableWithHashBuckets(
- table,
- num_oov_buckets=num_oov_buckets,
- hasher_spec=hasher_spec,
- name=feat_to_id_scope,
- key_dtype=dtype)
-
- return table
+ return lookup_ops.index_table_from_tensor(
+ vocabulary_list=mapping,
+ num_oov_buckets=num_oov_buckets,
+ default_value=default_value,
+ hasher_spec=hasher_spec,
+ dtype=dtype,
+ name=name)
@deprecated(
@@ -1098,83 +188,6 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
return table.lookup(tensor)
-def index_to_string_table_from_file(vocabulary_file,
- vocab_size=None,
- default_value="UNK",
- name=None):
- """Returns a lookup table that maps a `Tensor` of indices into strings.
-
- This operation constructs a lookup table to map int64 indices into string
- values. The table is initialized from a vocabulary file specified in
- `vocabulary_file`, where the whole line is the value and the
- zero-based line number is the index.
-
- Any input which does not have a corresponding index in the vocabulary file
- (an out-of-vocabulary entry) is assigned the `default_value`
-
- The underlying table must be initialized by calling
- `tf.tables_initializer.run()` or `table.init.run()` once.
-
- Sample Usages:
-
- If we have a vocabulary file "test.txt" with the following content:
-
- ```
- emerson
- lake
- palmer
- ```
-
- ```python
- indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_file(
- vocabulary_file="test.txt", default_value="UNKNOWN")
- values = table.lookup(indices)
- ...
- tf.tables_initializer().run()
-
- values.eval() ==> ["lake", "UNKNOWN"]
- ```
-
- Args:
- vocabulary_file: The vocabulary filename.
- vocab_size: Number of the elements in the vocabulary, if known.
- default_value: The value to use for out-of-vocabulary indices.
- name: A name for this op (optional).
-
- Returns:
- The lookup table to map a string values associated to a given index `int64`
- `Tensors`.
-
- Raises:
- ValueError: when `vocabulary_file` is empty.
- ValueError: when `vocab_size` is invalid.
- """
- if not vocabulary_file:
- raise ValueError("vocabulary_file must be specified.")
- if vocab_size is not None and vocab_size < 1:
- raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
-
- with ops.name_scope(name, "index_to_string") as scope:
- shared_name = ""
- if vocab_size:
- # Keep a shared_name
- # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
- shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size,
- TextFileIndex.LINE_NUMBER,
- TextFileIndex.WHOLE_LINE)
- else:
- # Keep a shared_name <table_type>_<filename>_<key_index>_<value_index>
- shared_name = "hash_table_%s_%s_%s" % (vocabulary_file,
- TextFileIndex.LINE_NUMBER,
- TextFileIndex.WHOLE_LINE)
- init = TextFileStringTableInitializer(
- vocabulary_file, vocab_size=vocab_size, name="table_init")
-
- # TODO(yleon): Use a more effienct structure.
- return HashTable(init, default_value, shared_name=shared_name, name=scope)
-
-
def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
"""Returns a lookup table that maps a `Tensor` of indices into strings.
@@ -1223,16 +236,8 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
if mapping is None:
raise ValueError("mapping must be specified.")
- with ops.name_scope(name, "index_to_string") as scope:
- values = ops.convert_to_tensor(mapping, dtypes.string)
- num_elements = array_ops.size(values)
- keys = math_ops.to_int64(math_ops.range(num_elements))
-
- shared_name = ""
- init = KeyValueTensorInitializer(
- keys, values, dtypes.int64, dtypes.string, name="table_init")
- # TODO(yleon): Use a more effienct structure.
- return HashTable(init, default_value, shared_name=shared_name, name=scope)
+ return lookup_ops.index_to_string_table_from_tensor(
+ vocabulary_list=mapping, default_value=default_value, name=name)
@deprecated(
@@ -1338,14 +343,14 @@ class MutableHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
# pylint: disable=protected-access
if self._default_value.get_shape().ndims == 0:
- self._table_ref = gen_lookup_ops._mutable_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_hash_table_v2(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
value_dtype=value_dtype,
name=name)
else:
- self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors(
+ self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors_v2(
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
key_dtype=key_dtype,
@@ -1372,8 +377,10 @@ class MutableHashTable(LookupInterface):
"""
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
- # pylint: disable=protected-access
- return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
+ with ops.colocate_with(self._table_ref):
+
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1398,11 +405,12 @@ class MutableHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
- # pylint: disable=protected-access
- values = gen_lookup_ops._lookup_table_find(
- self._table_ref, keys, self._default_value, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ values = gen_lookup_ops._lookup_table_find_v2(
+ self._table_ref, keys, self._default_value, name=name)
- values.set_shape(keys.get_shape().concatenate(self._value_shape))
+ values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -1422,13 +430,16 @@ class MutableHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- self.check_table_dtypes(keys.dtype, values.dtype)
+ # pylint: disable=protected-access
+ lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
+ # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
- # pylint: disable=protected-access
- op = gen_lookup_ops._lookup_table_insert(
- self._table_ref, keys, values, name=name)
- return op
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ op = gen_lookup_ops._lookup_table_insert_v2(
+ self._table_ref, keys, values, name=name)
+ return op
def export(self, name=None):
"""Returns tensors of all keys and values in the table.
@@ -1442,9 +453,10 @@ class MutableHashTable(LookupInterface):
"""
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
- # pylint: disable=protected-access
- exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
- self._table_ref, self._key_dtype, self._value_dtype, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
self._value_shape))
@@ -1464,8 +476,9 @@ class MutableHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_lookup_ops._lookup_table_import(
- self.op._table_ref, restored_tensors[0], restored_tensors[1])
+ with ops.colocate_with(self.op._table_ref):
+ return gen_lookup_ops._lookup_table_import_v2(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])
class MutableDenseHashTable(LookupInterface):
@@ -1539,7 +552,7 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
# pylint: disable=protected-access
- self._table_ref = gen_lookup_ops._mutable_dense_hash_table(
+ self._table_ref = gen_lookup_ops._mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
use_node_name_sharing=use_node_name_sharing,
@@ -1566,8 +579,9 @@ class MutableDenseHashTable(LookupInterface):
"""
with ops.name_scope(name, "%s_Size" % self._name,
[self._table_ref]) as name:
- # pylint: disable=protected-access
- return gen_lookup_ops._lookup_table_size(self._table_ref, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name)
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
@@ -1592,9 +606,10 @@ class MutableDenseHashTable(LookupInterface):
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
- # pylint: disable=protected-access
- values = gen_lookup_ops._lookup_table_find(
- self._table_ref, keys, self._default_value, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ values = gen_lookup_ops._lookup_table_find_v2(
+ self._table_ref, keys, self._default_value, name=name)
if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
values.set_shape(
@@ -1619,12 +634,15 @@ class MutableDenseHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- self.check_table_dtypes(keys.dtype, values.dtype)
+ # pylint: disable=protected-access
+ lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
+ # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
- # pylint: disable=protected-access
- op = gen_lookup_ops._lookup_table_insert(
- self._table_ref, keys, values, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ op = gen_lookup_ops._lookup_table_insert_v2(
+ self._table_ref, keys, values, name=name)
return op
def export(self, name=None):
@@ -1639,9 +657,10 @@ class MutableDenseHashTable(LookupInterface):
"""
with ops.name_scope(name, "%s_lookup_table_export_values" % self._name,
[self._table_ref]) as name:
- # pylint: disable=protected-access
- exported_keys, exported_values = gen_lookup_ops._lookup_table_export(
- self._table_ref, self._key_dtype, self._value_dtype, name=name)
+ with ops.colocate_with(self._table_ref):
+ # pylint: disable=protected-access
+ exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
exported_values.set_shape(exported_keys.get_shape().concatenate(
self._value_shape))
@@ -1661,5 +680,6 @@ class MutableDenseHashTable(LookupInterface):
def restore(self, restored_tensors, unused_restored_shapes):
# pylint: disable=protected-access
- return gen_lookup_ops._lookup_table_import(
- self.op._table_ref, restored_tensors[0], restored_tensors[1])
+ with ops.colocate_with(self.op._table_ref):
+ return gen_lookup_ops._lookup_table_import_v2(
+ self.op._table_ref, restored_tensors[0], restored_tensors[1])
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 112dacc9ab..09aa30a20b 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase):
inputs_ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
+ # Verify emit shapes may be unknown by feeding a placeholder that
+ # determines an emit shape.
+ unknown_dim = array_ops.placeholder(dtype=dtypes.int32)
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
@@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase):
if cell_output is None:
emit_output = (array_ops.zeros(
[2, 3], dtype=dtypes.int32), array_ops.zeros(
- [1], dtype=dtypes.int64))
+ [unknown_dim], dtype=dtypes.int64))
next_state = cell.zero_state(batch_size, dtypes.float32)
else:
emit_output = (array_ops.ones(
[batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones(
- [batch_size, 1], dtype=dtypes.int64))
+ [batch_size, unknown_dim], dtype=dtypes.int64))
next_state = cell_state
elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
finished = math_ops.reduce_all(elements_finished)
@@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase):
self.assertEqual([dtypes.int32, dtypes.int64],
[ta.dtype for ta in output_ta])
output = [ta.stack() for ta in output_ta]
- output_vals = sess.run(output)
+ output_vals = sess.run(output, feed_dict={unknown_dim: 1})
self.assertAllEqual(
np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
self.assertAllEqual(
diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py
index de57e7d81e..92beae35dd 100644
--- a/tensorflow/contrib/rnn/python/ops/gru_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py
@@ -98,7 +98,7 @@ class GRUBlockCell(rnn_cell_impl.RNNCell):
r"""Block GRU cell implementation.
The implementation is based on: http://arxiv.org/abs/1406.1078
- Computes the LSTM cell forward propagation for 1 time step.
+ Computes the GRU cell forward propagation for 1 time step.
This kernel op implements the following mathematical equations:
diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc
index ff6eff0249..e06774fc41 100644
--- a/tensorflow/core/grappler/costs/virtual_placer.cc
+++ b/tensorflow/core/grappler/costs/virtual_placer.cc
@@ -36,11 +36,20 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) : has_gpu_(false) {
}
const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
+ string device = get_canonical_device_name(node);
+ if (device.empty()) {
+ return unknown_device_;
+ }
+ auto it = devices_.find(device);
+ DCHECK(it != devices_.end());
+ return it->second;
+}
+
+string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const {
string device;
if (!node.device().empty()) {
- auto it = devices_.find(node.device());
- if (it != devices_.end()) {
- return it->second;
+ if (devices_.find(node.device()) != devices_.end()) {
+ return node.device();
}
DeviceNameUtils::ParsedName parsed_name;
bool parsed = DeviceNameUtils::ParseFullName(node.device(), &parsed_name);
@@ -57,7 +66,7 @@ const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
}
}
if (!parsed) {
- return unknown_device_;
+ return "";
} else {
device = strings::StrCat(
"/job:", parsed_name.job, "/replica:", parsed_name.replica,
@@ -71,11 +80,10 @@ const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
device = "/job:localhost/replica:0/task:0/cpu:0";
}
}
- auto it = devices_.find(device);
- if (it == devices_.end()) {
- return unknown_device_;
+ if (devices_.find(device) == devices_.end()) {
+ return "";
}
- return it->second;
+ return device;
}
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h
index 40cd64e37c..85bd502c67 100644
--- a/tensorflow/core/grappler/costs/virtual_placer.h
+++ b/tensorflow/core/grappler/costs/virtual_placer.h
@@ -33,6 +33,11 @@ class VirtualPlacer {
const DeviceProperties& get_device(const NodeDef& node) const;
+ // Returns canonical device name that has a corresponding device in the
+ // cluster; returns empty string if no device found or the node.device() can
+ // not be parsed.
+ string get_canonical_device_name(const NodeDef& node) const;
+
private:
std::unordered_map<string, DeviceProperties> devices_;
bool has_gpu_;
diff --git a/tensorflow/core/grappler/costs/virtual_placer_test.cc b/tensorflow/core/grappler/costs/virtual_placer_test.cc
index 037c52713d..bc8d0e38ba 100644
--- a/tensorflow/core/grappler/costs/virtual_placer_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_placer_test.cc
@@ -37,12 +37,18 @@ TEST(VirtualPlacerTest, LocalDevices) {
NodeDef node;
node.set_op("Conv2D");
EXPECT_EQ("GPU", placer.get_device(node).type());
+ EXPECT_EQ("/job:localhost/replica:0/task:0/gpu:0",
+ placer.get_canonical_device_name(node));
node.set_device("CPU");
EXPECT_EQ("CPU", placer.get_device(node).type());
+ EXPECT_EQ("/job:localhost/replica:0/task:0/cpu:0",
+ placer.get_canonical_device_name(node));
node.set_device("GPU:0");
EXPECT_EQ("GPU", placer.get_device(node).type());
+ EXPECT_EQ("/job:localhost/replica:0/task:0/gpu:0",
+ placer.get_canonical_device_name(node));
}
TEST(VirtualPlacerTest, RemoteDevices) {
@@ -60,24 +66,32 @@ TEST(VirtualPlacerTest, RemoteDevices) {
node.set_op("Conv2D");
// There is no local device available
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
+ EXPECT_EQ("", placer.get_canonical_device_name(node));
node.set_device("/job:my_job/replica:0/task:0/cpu:0");
EXPECT_EQ("CPU", placer.get_device(node).type());
+ EXPECT_EQ("/job:my_job/replica:0/task:0/cpu:0",
+ placer.get_canonical_device_name(node));
node.set_device("/job:my_job/replica:0/task:0/gpu:0");
EXPECT_EQ("GPU", placer.get_device(node).type());
+ EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
+ placer.get_canonical_device_name(node));
// There is no local CPU available
node.set_device("CPU");
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
+ EXPECT_EQ("", placer.get_canonical_device_name(node));
node.set_device("GPU:0");
// There is no local GPU available
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
+ EXPECT_EQ("", placer.get_canonical_device_name(node));
// This isn't a valid name
node.set_device("/job:my_job/replica:0/task:0");
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
+ EXPECT_EQ("", placer.get_canonical_device_name(node));
}
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 8bf6a081e3..33b0e16093 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -53,7 +54,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
class DeviceSimple : public DeviceBase {
public:
- DeviceSimple() : DeviceBase(nullptr) {
+ DeviceSimple() : DeviceBase(Env::Default()) {
eigen_worker_threads_.num_threads = 1;
eigen_worker_threads_.workers = new thread::ThreadPool(
Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index e4d7c3d11e..319dbb68e6 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -1308,9 +1308,11 @@ __global__ void __launch_bounds__(640, 2)
// a partial convolution for two elements, one each in the lower and upper half
// of a tile. The intermediate result of 4 consecutive columns are then
// accumulated and written to shared memory. Finally, the values in shared
-// memory are warp-accumulated (in chunks of 32 elements) and summed up in
-// global memory using atomics.
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
+// up in global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ // Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
+ int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1321,7 +1323,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int batches = args.batch;
const int in_rows = args.in_rows;
- const int in_cols = args.in_cols;
+ const int in_cols = blockDim.y; // slower (see b/62280718): args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
@@ -1352,8 +1354,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int tensor_offset = block_rows * in_row_size;
// The accumulator has a fixed number of pixels that can be reduced by one
// warp. Pixels beyond block_pixels/4 are never written.
- const int accum_pixels = 32;
- const int accum_increment = accum_pixels * block_slices;
+ const int accum_increment = kAccumPixels * block_slices;
const int accum_size = filter_pixels * accum_increment;
const int thread_depth = threadIdx.x;
@@ -1383,7 +1384,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Position in accumulator (1 per 4 threads, depth major).
const int accum_pix = thread_pix / 4;
- const int accum_idx = thread_depth * accum_pixels + accum_pix;
+ const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_depth = in_depth - thread_depth;
const int accum_offset = tile_size + accum_idx;
@@ -1438,19 +1439,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
- const int filter_idx = i / accum_pixels;
+ const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / block_slices;
const int filter_depth = filter_idx % block_slices + start_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
- // Sum up the 32 pixels of the same depth from the accumulator.
- val += CudaShuffleDown(val, 16);
- val += CudaShuffleDown(val, 8);
- val += CudaShuffleDown(val, 4);
- val += CudaShuffleDown(val, 2);
- val += CudaShuffleDown(val, 1);
- if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+ // Warp-accumulate the pixels of the same depth from the accumulator.
+ for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
+ val += CudaShuffleDown(val, delta);
+ }
+ if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
}
@@ -1567,9 +1566,11 @@ __global__ void __launch_bounds__(640, 2)
// a partial convolution for two elements, one each in the lower and upper half
// of a tile. The intermediate result of 4 consecutive columns are then
// accumulated and written to shared memory. Finally, the values in shared
-// memory are warp-accumulated (in chunks of 32 elements) and summed up in
-// global memory using atomics.
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
+// up in global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ // Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
+ int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1580,7 +1581,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const int batches = args.batch;
const int in_rows = args.in_rows;
- const int in_cols = args.in_cols;
+ const int in_cols = blockDim.x; // slower (see b/62280718): args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
@@ -1610,8 +1611,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const int in_blocks = (in_slices + block_slices - 1) / block_slices;
// The accumulator has a fixed number of pixels that can be reduced by one
// warp. Pixels beyond block_pixels/4 are never written.
- const int accum_pixels = 32;
- const int accum_increment = accum_pixels * block_slices;
+ const int accum_increment = kAccumPixels * block_slices;
const int accum_size = filter_pixels * accum_increment;
const int thread_col = threadIdx.x;
@@ -1640,7 +1640,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Position in accumulator (1 per 4 threads, depth major).
const int accum_pix = thread_pix / 4;
- const int accum_idx = thread_depth * accum_pixels + accum_pix;
+ const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_slice = in_slices - thread_depth;
const int accum_offset = tile_size + accum_idx;
@@ -1692,19 +1692,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
- const int filter_idx = i / accum_pixels;
+ const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / block_slices;
const int filter_depth = (slice + filter_idx % block_slices) % in_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
- // Sum up 32 pixels of the same depth from the accumulator.
- val += CudaShuffleDown(val, 16);
- val += CudaShuffleDown(val, 8);
- val += CudaShuffleDown(val, 4);
- val += CudaShuffleDown(val, 2);
- val += CudaShuffleDown(val, 1);
- if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+ // Warp-accumulate pixels of the same depth from the accumulator.
+ for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
+ val += CudaShuffleDown(val, delta);
+ }
+ if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
}
@@ -1712,7 +1710,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
}
}
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ int kAccumPixels>
void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
const GpuDevice& d, const DepthwiseArgs args, int block_rows,
int shared_memory_size, const T* out_backprop, const T* input,
@@ -1724,22 +1723,22 @@ void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
- DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>,
+ DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
- DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>
+ DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else if (data_format == FORMAT_NCHW) {
dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
- DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>,
+ DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
- DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>
+ DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else {
@@ -1759,21 +1758,39 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
return false;
}
+ const int in_pixels = args.in_rows * args.in_cols;
+ int accum_pixels = 8;
+ while (accum_pixels * 8 < in_pixels) {
+ accum_pixels *= 2;
+ }
+
const int block_slices = 8;
const int tile_cols = args.in_cols + args.filter_cols - 1;
const int tile_rows = block_rows * 2 + args.filter_rows - 1;
const int tile_pixels = tile_rows * tile_cols;
- const int accum_size = args.filter_rows * args.filter_cols * 32;
+ const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
- block_slices * (tile_pixels + accum_size) * sizeof(T);
+ block_slices * (tile_pixels + filter_pixels * accum_pixels) * sizeof(T);
if (shared_memory_size > d.sharedMemPerBlock()) {
return false;
}
- LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>(
- d, args, block_rows, shared_memory_size, out_backprop, input,
- filter_backprop, data_format);
+ if (accum_pixels == 8) {
+ LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 8>(
+ d, args, block_rows, shared_memory_size, out_backprop, input,
+ filter_backprop, data_format);
+ } else if (accum_pixels == 16) {
+ LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 16>(
+ d, args, block_rows, shared_memory_size, out_backprop, input,
+ filter_backprop, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 32>(
+ d, args, block_rows, shared_memory_size, out_backprop, input,
+ filter_backprop, data_format);
+ }
return true;
}
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index 6bffe48b61..593fa487c9 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -39,30 +39,47 @@ class FFTBase : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& in = ctx->input(0);
- const TensorShape& shape = in.shape();
+ const TensorShape& input_shape = in.shape();
const int fft_rank = Rank();
OP_REQUIRES(
- ctx, shape.dims() >= fft_rank,
+ ctx, input_shape.dims() >= fft_rank,
errors::InvalidArgument("Input must have rank of at least ", fft_rank,
- " but got: ", shape.DebugString()));
+ " but got: ", input_shape.DebugString()));
Tensor* out;
- TensorShape output_shape = shape;
+ TensorShape output_shape = input_shape;
uint64 fft_shape[3] = {0, 0, 0};
// In R2C or C2R mode, we use a second input to specify the FFT length
// instead of inferring it from the input shape.
if (IsReal()) {
const Tensor& fft_length = ctx->input(1);
- OP_REQUIRES(ctx, fft_length.shape().dims() == 1 &&
- fft_length.shape().dim_size(0) == fft_rank,
- errors::InvalidArgument("fft_length must have shape [",
+ OP_REQUIRES(ctx,
+ fft_length.shape().dims() == 1 &&
+ fft_length.shape().dim_size(0) == fft_rank,
+ errors::InvalidArgument("fft_length must have shape [",
fft_rank, "]"));
auto fft_length_as_vec = fft_length.vec<int32>();
for (int i = 0; i < fft_rank; ++i) {
fft_shape[i] = fft_length_as_vec(i);
- uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
+ // Each input dimension must have length of at least fft_shape[i]. For
+ // IRFFTs, the inner-most input dimension must have length of at least
+ // fft_shape[i] / 2 + 1.
+ bool inner_most = (i == fft_rank - 1);
+ uint64 min_input_dim_length =
+ !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
+ auto input_index = input_shape.dims() - fft_rank + i;
+ OP_REQUIRES(
+ ctx,
+ // We pass through empty tensors, so special case them here.
+ input_shape.dim_size(input_index) == 0 ||
+ input_shape.dim_size(input_index) >= min_input_dim_length,
+ errors::InvalidArgument(
+ "Input dimension ", input_index,
+ " must have length of at least ", min_input_dim_length,
+ " but got: ", input_shape.dim_size(input_index)));
+ uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
? fft_shape[i] / 2 + 1
: fft_shape[i];
output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
@@ -75,7 +92,7 @@ class FFTBase : public OpKernel {
}
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
- if (shape.num_elements() == 0) {
+ if (input_shape.num_elements() == 0) {
return;
}
@@ -119,20 +136,32 @@ class FFTCPU : public FFTBase {
} else {
if (IsForward()) {
auto input = (Tensor(in)).flat_inner_dims<float, FFTRank + 1>();
+ auto input_dims = input.dimensions();
+
+ // Slice input to fft_shape on its inner-most dimensions.
+ Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+ input_slice_sizes[0] = input_dims[0];
+ TensorShape temp_shape{input_dims[0]};
+ for (int i = 1; i <= FFTRank; ++i) {
+ input_slice_sizes[i] = fft_shape[i - 1];
+ temp_shape.AddDim(fft_shape[i - 1]);
+ }
+
auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
- Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices;
+ const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
// Compute the full FFT using a temporary tensor.
Tensor temp;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
- in.shape(), &temp));
+ temp_shape, &temp));
auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
full_fft.device(device) =
- input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+ input.slice(zero_start_indices, input_slice_sizes)
+ .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
// Slice away the negative frequency components.
output.device(device) =
- full_fft.slice(startIndices, output.dimensions());
+ full_fft.slice(zero_start_indices, output.dimensions());
} else {
// Reconstruct the full fft and take the inverse.
auto input = ((Tensor)in).flat_inner_dims<complex64, FFTRank + 1>();
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index 51f11f9d2e..ed350d9833 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -307,7 +307,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
core::ScopedUnref unref_iterator(iterator);
std::vector<Tensor> components;
- bool end_of_sequence;
+ bool end_of_sequence = false;
IteratorContext::Params params;
params.env = ctx->env();
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index ed93caad33..c7bf250fad 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -41,13 +41,7 @@ class ColumnInterface {
virtual int64 FeatureCount(int64 batch) const = 0;
// Returns the fingerprint of nth feature from the specified batch.
- InternalType Feature(int64 batch, int64 n) const {
- InternalType not_used = InternalType();
- return DoFeature(batch, n, not_used);
- }
-
- virtual InternalType DoFeature(int64 batch, int64 n,
- InternalType not_used) const = 0;
+ virtual InternalType Feature(int64 batch, int64 n) const = 0;
virtual ~ColumnInterface() {}
};
@@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
return feature_counts_[batch];
}
- // InternalType is int64 only when using HashCrosser.
- int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
- const int64 start = feature_start_indices_[batch];
- if (DT_STRING == values_.dtype())
- return Fingerprint64(values_.vec<string>().data()[start + n]);
- return values_.vec<int64>().data()[start + n];
- }
-
- // InternalType is string or StringPiece when using StringCrosser.
- string DoFeature(int64 batch, int64 n, string not_used) const {
- const int64 start = feature_start_indices_[batch];
- if (DT_STRING == values_.dtype())
- return values_.vec<string>().data()[start + n];
- return std::to_string(values_.vec<int64>().data()[start + n]);
- }
-
- StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
- const int64 start = feature_start_indices_[batch];
- return values_.vec<string>().data()[start + n];
- }
+ InternalType Feature(int64 batch, int64 n) const override;
~SparseTensorColumn() override {}
@@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface<InternalType> {
std::vector<int64> feature_start_indices_;
};
+// InternalType is int64 only when using HashCrosser.
+template <>
+int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ if (DT_STRING == values_.dtype())
+ return Fingerprint64(values_.vec<string>().data()[start + n]);
+ return values_.vec<int64>().data()[start + n];
+}
+
+// InternalType is string or StringPiece when using StringCrosser.
+template <>
+string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ if (DT_STRING == values_.dtype())
+ return values_.vec<string>().data()[start + n];
+ return std::to_string(values_.vec<int64>().data()[start + n]);
+}
+
+template <>
+StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
+ int64 n) const {
+ const int64 start = feature_start_indices_[batch];
+ return values_.vec<string>().data()[start + n];
+}
+
// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
@@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
- // InternalType is int64 only when using HashCrosser.
- int64 DoFeature(int64 batch, int64 n, int64 not_used) const {
- if (DT_STRING == tensor_.dtype())
- return Fingerprint64(tensor_.matrix<string>()(batch, n));
- return tensor_.matrix<int64>()(batch, n);
- }
-
- // Internal type is string or StringPiece when using StringCrosser.
- string DoFeature(int64 batch, int64 n, string not_used) const {
- if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
- return std::to_string(tensor_.matrix<int64>()(batch, n));
- }
-
- StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const {
- return tensor_.matrix<string>()(batch, n);
- }
+ InternalType Feature(int64 batch, int64 n) const override;
~DenseTensorColumn() override {}
@@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface<InternalType> {
const Tensor& tensor_;
};
+// InternalType is int64 only when using HashCrosser.
+template <>
+int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
+ if (DT_STRING == tensor_.dtype())
+ return Fingerprint64(tensor_.matrix<string>()(batch, n));
+ return tensor_.matrix<int64>()(batch, n);
+}
+
+// Internal type is string or StringPiece when using StringCrosser.
+template <>
+string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
+ if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
+ return std::to_string(tensor_.matrix<int64>()(batch, n));
+}
+
+template <>
+StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
+ int64 n) const {
+ return tensor_.matrix<string>()(batch, n);
+}
+
// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index fd13c7bebf..f0fcd02835 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1221,7 +1221,7 @@ of the forward TensorArray is known when this operation is called.
TensorArray gradient calls use an accumulator TensorArray object. If
multiple gradients are calculated and run in the same session, the multiple
-gradient nodes may accidentally flow throuth the same accumulator TensorArray.
+gradient nodes may accidentally flow through the same accumulator TensorArray.
This double counts and generally breaks the TensorArray gradient flow.
The solution is to identify which gradient call this particular
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 59c3cf5e3d..c02bd4b8a0 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -9268,7 +9268,7 @@ op {
type: DT_FLOAT
}
summary: "Inverse real-valued fast Fourier transform."
- description: "Computes the inverse 1-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly."
+ description: "Computes the inverse 1-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller\nthan the corresponding dimension of `input`, the dimension is cropped. If it is\nlarger, the dimension is padded with zeros."
}
op {
name: "IRFFT2D"
@@ -9288,7 +9288,7 @@ op {
type: DT_FLOAT
}
summary: "Inverse 2D real-valued fast Fourier transform."
- description: "Computes the inverse 2-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly."
+ description: "Computes the inverse 2-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong each axis `IRFFT2D` is computed on, if `fft_length` (or\n`fft_length / 2 + 1` for the inner-most dimension) is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros."
}
op {
name: "IRFFT3D"
@@ -9308,7 +9308,7 @@ op {
type: DT_FLOAT
}
summary: "Inverse 3D real-valued fast Fourier transform."
- description: "Computes the inverse 3-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly."
+ description: "Computes the inverse 3-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong each axis `IRFFT3D` is computed on, if `fft_length` (or\n`fft_length / 2 + 1` for the inner-most dimension) is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros."
}
op {
name: "Identity"
@@ -16180,7 +16180,7 @@ op {
type: DT_COMPLEX64
}
summary: "Real-valued fast Fourier transform."
- description: "Computes the 1-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms."
+ description: "Computes the 1-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms.\n\nAlong the axis `RFFT` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros."
}
op {
name: "RFFT2D"
@@ -16200,7 +16200,7 @@ op {
type: DT_COMPLEX64
}
summary: "2D real-valued fast Fourier transform."
- description: "Computes the 2-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms."
+ description: "Computes the 2-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms.\n\nAlong each axis `RFFT2D` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros."
}
op {
name: "RFFT3D"
@@ -16220,7 +16220,7 @@ op {
type: DT_COMPLEX64
}
summary: "3D real-valued fast Fourier transform."
- description: "Computes the 3-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms."
+ description: "Computes the 3-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms.\n\nAlong each axis `RFFT3D` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros."
}
op {
name: "RGBToHSV"
@@ -26648,7 +26648,7 @@ op {
description: "The gradient source string, used to decide which gradient TensorArray\nto return."
}
summary: "Creates a TensorArray for storing the gradients of values in the given handle."
- description: "If the given TensorArray gradient already exists, returns a reference to it.\n\nLocks the size of the original TensorArray by disabling its dynamic size flag.\n\n**A note about the input flow_in:**\n\nThe handle flow_in forces the execution of the gradient lookup to occur\nonly after certain other operations have occurred. For example, when\nthe forward TensorArray is dynamically sized, writes to this TensorArray\nmay resize the object. The gradient TensorArray is statically sized based\non the size of the forward TensorArray when this operation executes.\nFurthermore, the size of the forward TensorArray is frozen by this call.\nAs a result, the flow is used to ensure that the call to generate the gradient\nTensorArray only happens after all writes are executed.\n\nIn the case of dynamically sized TensorArrays, gradient computation should\nonly be performed on read operations that have themselves been chained via\nflow to occur only after all writes have executed. That way the final size\nof the forward TensorArray is known when this operation is called.\n\n**A note about the source attribute:**\n\nTensorArray gradient calls use an accumulator TensorArray object. If\nmultiple gradients are calculated and run in the same session, the multiple\ngradient nodes may accidentally flow throuth the same accumulator TensorArray.\nThis double counts and generally breaks the TensorArray gradient flow.\n\nThe solution is to identify which gradient call this particular\nTensorArray gradient is being called in. This is performed by identifying\na unique string (e.g. \"gradients\", \"gradients_1\", ...) from the input\ngradient Tensor\'s name. This string is used as a suffix when creating\nthe TensorArray gradient object here (the attribute `source`).\n\nThe attribute `source` is added as a suffix to the forward TensorArray\'s\nname when performing the creation / lookup, so that each separate gradient\ncalculation gets its own TensorArray accumulator."
+ description: "If the given TensorArray gradient already exists, returns a reference to it.\n\nLocks the size of the original TensorArray by disabling its dynamic size flag.\n\n**A note about the input flow_in:**\n\nThe handle flow_in forces the execution of the gradient lookup to occur\nonly after certain other operations have occurred. For example, when\nthe forward TensorArray is dynamically sized, writes to this TensorArray\nmay resize the object. The gradient TensorArray is statically sized based\non the size of the forward TensorArray when this operation executes.\nFurthermore, the size of the forward TensorArray is frozen by this call.\nAs a result, the flow is used to ensure that the call to generate the gradient\nTensorArray only happens after all writes are executed.\n\nIn the case of dynamically sized TensorArrays, gradient computation should\nonly be performed on read operations that have themselves been chained via\nflow to occur only after all writes have executed. That way the final size\nof the forward TensorArray is known when this operation is called.\n\n**A note about the source attribute:**\n\nTensorArray gradient calls use an accumulator TensorArray object. If\nmultiple gradients are calculated and run in the same session, the multiple\ngradient nodes may accidentally flow through the same accumulator TensorArray.\nThis double counts and generally breaks the TensorArray gradient flow.\n\nThe solution is to identify which gradient call this particular\nTensorArray gradient is being called in. This is performed by identifying\na unique string (e.g. \"gradients\", \"gradients_1\", ...) from the input\ngradient Tensor\'s name. This string is used as a suffix when creating\nthe TensorArray gradient object here (the attribute `source`).\n\nThe attribute `source` is added as a suffix to the forward TensorArray\'s\nname when performing the creation / lookup, so that each separate gradient\ncalculation gets its own TensorArray accumulator."
is_stateful: true
}
op {
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index 09b460fd14..592aaa25c3 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -201,6 +201,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the
`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
followed by the `fft_length / 2` positive-frequency terms.
+Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [1]. The FFT length.
output: A complex64 tensor of the same rank as `input`. The inner-most
@@ -230,6 +234,10 @@ dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
+than the corresponding dimension of `input`, the dimension is cropped. If it is
+larger, the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [1]. The FFT length.
output: A float32 tensor of the same rank as `input`. The inner-most
@@ -257,6 +265,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
+Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
output: A complex64 tensor of the same rank as `input`. The inner-most 2
@@ -287,6 +299,11 @@ from the size of the inner-most 2 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along each axis `IRFFT2D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
output: A float32 tensor of the same rank as `input`. The inner-most 2
@@ -314,6 +331,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
+Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A float32 tensor.
fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
output: A complex64 tensor of the same rank as `input`. The inner-most 3
@@ -344,6 +365,11 @@ from the size of the inner-most 3 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
+Along each axis `IRFFT3D` is computed on, if `fft_length` (or
+`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+corresponding dimension of `input`, the dimension is cropped. If it is larger,
+the dimension is padded with zeros.
+
input: A complex64 tensor.
fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
output: A float32 tensor of the same rank as `input`. The inner-most 3
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index cf66277b9d..d30d7819fc 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -88,10 +88,11 @@ limitations under the License.
// shapes, particularly when restoring a graph from GraphDef
// produced at version 22 or later. (04/10/2016)
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
+// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 23
+#define TF_GRAPH_DEF_VERSION 24
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/contrib/ios_examples/.gitignore b/tensorflow/examples/ios/.gitignore
index e572b3012c..e572b3012c 100644
--- a/tensorflow/contrib/ios_examples/.gitignore
+++ b/tensorflow/examples/ios/.gitignore
diff --git a/tensorflow/contrib/ios_examples/README.md b/tensorflow/examples/ios/README.md
index 6bac33c0ec..9832399d72 100644
--- a/tensorflow/contrib/ios_examples/README.md
+++ b/tensorflow/examples/ios/README.md
@@ -2,15 +2,15 @@
This folder contains examples of how to build applications for iOS devices using TensorFlow.
-## Building the Examples
+## Running the Samples using CocoaPod
+ - You'll need Xcode 7.3 or later.
- - You'll need Xcode 7.3 or later, with the command-line tools installed.
+ - There are currently three examples: simple, benchmark, and camera. For now,
+ you can download the sample code by cloning the main tensorflow repository
+ (we are planning to make the samples available as a separate repository
+ later).
- - Follow the instructions at
- [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile)
- under "iOS" to compile a static library containing the core TensorFlow code.
-
- - From the root of the Tensorflow folder, download
+ - From the root of the tensorflow folder, download
[Inception v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip),
and extract the label and graph files into the data folders inside both the
simple and camera examples:
@@ -25,8 +25,62 @@ cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/camera/data/
cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/
```
- - Load the Xcode project inside the `simple` subfolder, and press Command-R to
- build and run it on the simulator or your connected device.
+ - Change directory to one of the samples, download the TensorFlow-experimental
+ pod, and open the Xcode workspace. Observe: installing the pod can take a
+ long time since it is big (~450MB). For example, if you want to run the
+ simple example, then:
+```bash
+cd tensorflow/contrib/ios_examples/simple
+pod install
+open tf_simple_example.xcworkspace # obs, not the .xcodeproj directory
+```
+
+ - Run the simple app in the simulator. You should see a single-screen app with
+ a "Run Model" button. Tap that, and you should see some debug output appear
+ below indicating that the example Grace Hopper image in directory data has
+ been analyzed, with a military uniform recognized.
+
+ - Run the other samples using the same process. The camera example requires a
+ real device connected. Once you build and run that, you should get a live
+ camera view that you can point at objects to get real-time recognition
+ results.
+
+### Troubleshooting
+
+ - Make sure you use the TensorFlow-experimental pod (and not TensorFlow).
+
+ - The TensorFlow-experimental pod is current about ~450MB. The reason it is
+ so big is because we are bundling multiple platforms, and the pod includes
+ all TensorFlow functionality (e.g. operations). This is convenient during
+ development, but see below section on how you can build your own custom
+ TensorFlow library to reduce the size.
+
+### Creating Your own App
+
+ - Create your own app using Xcode then add a file named Podfile at the project
+ root directory with the following content:
+```bash
+target 'YourProjectName'
+ pod 'TensorFlow-experimental'
+```
+
+ - Then you run ```pod install``` to download and install the
+ TensorFlow-experimental pod, and finaly perform
+ ```open YourProjectName.xcworkspace``` and add your code.
+
+ - In your apps "Build Settings", make sure to add $(inherited) to sections
+ "Other Linker Flags", and "Header Search Paths".
+
+ - That's it. If you want to create your custom TensorFlow iOS library, for
+ example to reduce binary footprint, see below section.
+
+## Building the TensorFlow iOS libraries from source
+
+ - You'll need Xcode 7.3 or later, with the command-line tools installed.
+
+ - Follow the instructions at
+ [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile)
+ under "iOS" to compile a static library containing the core TensorFlow code.
- You should see a single-screen app with a "Run Model" button. Tap that, and
you should see some debug output appear below indicating that the example
@@ -36,8 +90,8 @@ cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/
open up the Xcode project in the `camera` subfolder. Once you build and run
that, you should get a live camera view that you can point at objects to get
real-time recognition results.
-
-## Troubleshooting
+
+### Troubleshooting
If you're hitting problems, here's a checklist of common things to investigate:
@@ -52,7 +106,7 @@ If you're hitting problems, here's a checklist of common things to investigate:
linked in properly. You'll have to make sure your project uses force_load, as
described below.
-## Creating your Own App
+### Creating your Own App from your source libraries
You'll need to update various settings in your app to link against
TensorFlow. You can view them in the example projects, but here's a full
@@ -96,7 +150,7 @@ rundown:
`-all_load` to avoid issues with Objective-C categories in static libraries,
you may be able to replace it with the `-ObjC` flag.
-## Reducing the binary size
+### Reducing the binary size
TensorFlow is a comparatively large library for a mobile device, so it will
increase the size of your app. Currently on iOS we see around a 11 MB binary
@@ -115,17 +169,17 @@ looking at the simple example to examine its size. Here's how you do that:
- Once the build's complete, open the Report Navigator and select the logs.
- - Near the bottom, you'll see a line saying "Touch tf_ios_makefile_example.app".
+ - Near the bottom, you'll see a line saying "Touch tf_simple_example.app".
- Expand that line using the icon on the right, and copy the first argument to
the Touch command.
- Go to the terminal, type `ls -lah ` and then paste the path you copied.
- - For example it might look like `ls -lah /Users/petewarden/Library/Developer/Xcode/DerivedData/tf_ios_makefile_example-etdbksqytcnzeyfgdwiihzkqpxwr/Build/Products/Debug-iphoneos/tf_ios_makefile_example.app`
+ - For example it might look like `ls -lah /Users/petewarden/Library/Developer/Xcode/DerivedData/tf_simple_example-etdbksqytcnzeyfgdwiihzkqpxwr/Build/Products/Debug-iphoneos/tf_simple_example.app`
- Running this command will show the size of the executable as the
- `tf_ios_makefile_example` line.
+ `tf_simple_example` line.
Right now you'll see a size of around 23 MB, since it's including two
architectures (armv7 and arm64). As a first step, you should make sure the size
diff --git a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.h b/tensorflow/examples/ios/benchmark/AppDelegate.h
index 94046d9728..94046d9728 100644
--- a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.h
+++ b/tensorflow/examples/ios/benchmark/AppDelegate.h
diff --git a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm b/tensorflow/examples/ios/benchmark/AppDelegate.mm
index 23ffba0f7b..23ffba0f7b 100644
--- a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm
+++ b/tensorflow/examples/ios/benchmark/AppDelegate.mm
diff --git a/tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist b/tensorflow/examples/ios/benchmark/Benchmark-Info.plist
index 8d17162b87..0cdbf28a31 100644
--- a/tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist
+++ b/tensorflow/examples/ios/benchmark/Benchmark-Info.plist
@@ -5,11 +5,11 @@
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
- <string>TF Benchmark</string>
+ <string>tf_benchmark_example</string>
<key>CFBundleExecutable</key>
- <string>benchmark</string>
+ <string>tf_benchmark_example</string>
<key>CFBundleIdentifier</key>
- <string>Google.Benchmark</string>
+ <string>com.google.tf_benchmark_example</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h b/tensorflow/examples/ios/benchmark/BenchmarkViewController.h
index c9cbc49280..c9cbc49280 100644
--- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h
+++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.h
diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm
index 4421c88651..cab7b36f17 100644
--- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm
+++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm
@@ -22,17 +22,17 @@
#include <sstream>
#include <string>
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/message_lite.h"
+//#include "google/protobuf/io/coded_stream.h"
+//#include "google/protobuf/io/zero_copy_stream_impl.h"
+//#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+//#include "google/protobuf/message_lite.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
+//#include "tensorflow/core/framework/tensor.h"
+//#include "tensorflow/core/framework/types.pb.h"
+//#include "tensorflow/core/platform/env.h"
+//#include "tensorflow/core/platform/logging.h"
+//#include "tensorflow/core/platform/mutex.h"
+//#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/stat_summarizer.h"
@@ -52,7 +52,7 @@ class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
return -1;
}
ifs_.read(static_cast<char*>(buffer), size);
- return ifs_.gcount();
+ return (int)ifs_.gcount();
}
private:
@@ -85,7 +85,7 @@ static void GetTopN(
std::greater<std::pair<float, int>>>
top_result_pq;
- const int count = prediction.size();
+ long count = prediction.size();
for (int i = 0; i < count; ++i) {
const float value = prediction(i);
@@ -178,7 +178,7 @@ tensorflow::Status BenchmarkInference(
stat_summarizer->PrintStepStats();
*average_time = total_time / iterations_count;
- NSLog(@"Took %f seconds", average_time);
+ NSLog(@"Took %f seconds", *average_time);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib b/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib
index 56c3708062..56c3708062 100644
--- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib
+++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib
diff --git a/tensorflow/examples/ios/benchmark/Podfile b/tensorflow/examples/ios/benchmark/Podfile
new file mode 100644
index 0000000000..e163d56e8d
--- /dev/null
+++ b/tensorflow/examples/ios/benchmark/Podfile
@@ -0,0 +1,5 @@
+platform :ios, '8.0'
+inhibit_all_warnings!
+
+target 'tf_benchmark_example'
+ pod 'TensorFlow-experimental'
diff --git a/tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg b/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg
index d2a427810f..d2a427810f 100644
--- a/tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg
+++ b/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.h b/tensorflow/examples/ios/benchmark/ios_image_load.h
index 78eaded8d7..78eaded8d7 100644
--- a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.h
+++ b/tensorflow/examples/ios/benchmark/ios_image_load.h
diff --git a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm b/tensorflow/examples/ios/benchmark/ios_image_load.mm
index 64d1ea21cf..64d1ea21cf 100644
--- a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm
+++ b/tensorflow/examples/ios/benchmark/ios_image_load.mm
diff --git a/tensorflow/contrib/ios_examples/benchmark/main.mm b/tensorflow/examples/ios/benchmark/main.mm
index d70550a730..d70550a730 100644
--- a/tensorflow/contrib/ios_examples/benchmark/main.mm
+++ b/tensorflow/examples/ios/benchmark/main.mm
diff --git a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj b/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj
index 5cd173b416..d61b65ba61 100644
--- a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj
+++ b/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj
@@ -7,33 +7,28 @@
objects = {
/* Begin PBXBuildFile section */
- 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; };
- 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; };
- 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */; };
- 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; };
- 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; };
- 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; };
- 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; };
- 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; };
- 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; };
- 59A3D00B1CF4E68100C4259F /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */; };
- 59A3D00C1CF4E68100C4259F /* BenchmarkViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */; };
- 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */; };
- 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0171CF4E86100C4259F /* UIKit.framework */; };
+ 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; };
+ 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; };
+ 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */; };
+ 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; };
+ 1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */; };
+ 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; };
+ 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; };
+ 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; };
+ 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */; };
+ 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */; };
+ 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
- 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = "<group>"; };
- 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = "<group>"; };
- 5911579B1CF4011C00C31E3A /* benchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = benchmark.app; sourceTree = BUILT_PRODUCTS_DIR; };
- 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
+ 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
+ 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_benchmark_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
+ 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreFoundation.framework; path = System/Library/Frameworks/CoreFoundation.framework; sourceTree = SDKROOT; };
+ 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = "<group>"; };
- 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = "<group>"; };
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = "<group>"; };
- 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_2012_challenge_label_map_proto.pbtxt; sourceTree = "<group>"; };
59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = "<group>"; };
- 59A3CFF81CF4E68100C4259F /* LICENSE */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = LICENSE; sourceTree = "<group>"; };
59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = "<group>"; };
59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = "<group>"; };
59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = "<group>"; };
@@ -42,36 +37,37 @@
59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BenchmarkViewController.h; sourceTree = "<group>"; };
59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = "<group>"; };
59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = BenchmarkViewController.xib; sourceTree = "<group>"; };
- 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
- 59A3D0151CF4E83D00C4259F /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; };
- 59A3D0171CF4E86100C4259F /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
+ 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.debug.xcconfig"; sourceTree = "<group>"; };
+ 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_benchmark_example.a"; sourceTree = BUILT_PRODUCTS_DIR; };
+ DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.release.xcconfig"; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
- 591157981CF4011C00C31E3A /* Frameworks */ = {
+ 1C8BA9011EC682E700CCCC8C /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
- 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */,
- 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */,
- 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */,
- 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */,
- 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */,
+ 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */,
+ 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */,
+ 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
+ 2BD56010B574F539C2070A57 /* Pods */ = {
+ isa = PBXGroup;
+ children = (
+ 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */,
+ DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */,
+ );
+ name = Pods;
+ sourceTree = "<group>";
+ };
591157921CF4011C00C31E3A = {
isa = PBXGroup;
children = (
- 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */,
- 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */,
- 590E7D871D02091F00DF5523 /* libprotobuf.a */,
- 59A3D0171CF4E86100C4259F /* UIKit.framework */,
- 59A3D0151CF4E83D00C4259F /* Foundation.framework */,
- 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */,
59A3CFF11CF4E68100C4259F /* AppDelegate.h */,
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */,
59A3CFF31CF4E68100C4259F /* data */,
@@ -83,13 +79,15 @@
59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */,
59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */,
5911579C1CF4011C00C31E3A /* Products */,
+ 2BD56010B574F539C2070A57 /* Pods */,
+ 76A25A27041EB307BDFF0DD1 /* Frameworks */,
);
sourceTree = "<group>";
};
5911579C1CF4011C00C31E3A /* Products */ = {
isa = PBXGroup;
children = (
- 5911579B1CF4011C00C31E3A /* benchmark.app */,
+ 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */,
);
name = Products;
sourceTree = "<group>";
@@ -97,34 +95,45 @@
59A3CFF31CF4E68100C4259F /* data */ = {
isa = PBXGroup;
children = (
- 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */,
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */,
- 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */,
59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */,
- 59A3CFF81CF4E68100C4259F /* LICENSE */,
59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */,
);
path = data;
sourceTree = "<group>";
};
+ 76A25A27041EB307BDFF0DD1 /* Frameworks */ = {
+ isa = PBXGroup;
+ children = (
+ 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */,
+ 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */,
+ 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */,
+ 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */,
+ );
+ name = Frameworks;
+ sourceTree = "<group>";
+ };
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
- 5911579A1CF4011C00C31E3A /* benchmark */ = {
+ 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */ = {
isa = PBXNativeTarget;
- buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "benchmark" */;
+ buildConfigurationList = 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */;
buildPhases = (
- 591157971CF4011C00C31E3A /* Sources */,
- 591157981CF4011C00C31E3A /* Frameworks */,
- 591157991CF4011C00C31E3A /* Resources */,
+ 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */,
+ 1C8BA8FC1EC682E700CCCC8C /* Sources */,
+ 1C8BA9011EC682E700CCCC8C /* Frameworks */,
+ 1C8BA9041EC682E700CCCC8C /* Resources */,
+ 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */,
+ A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */,
);
buildRules = (
);
dependencies = (
);
- name = benchmark;
+ name = tf_benchmark_example;
productName = benchmark;
- productReference = 5911579B1CF4011C00C31E3A /* benchmark.app */;
+ productReference = 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
@@ -133,16 +142,10 @@
591157931CF4011C00C31E3A /* Project object */ = {
isa = PBXProject;
attributes = {
- LastUpgradeCheck = 0720;
+ LastUpgradeCheck = 0830;
ORGANIZATIONNAME = Google;
- TargetAttributes = {
- 5911579A1CF4011C00C31E3A = {
- CreatedOnToolsVersion = 7.2;
- DevelopmentTeam = 85Z3VXS37U;
- };
- };
};
- buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "benchmark" */;
+ buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */;
compatibilityVersion = "Xcode 3.2";
developmentRegion = English;
hasScannedForEncodings = 0;
@@ -155,40 +158,123 @@
projectDirPath = "";
projectRoot = "";
targets = (
- 5911579A1CF4011C00C31E3A /* benchmark */,
+ 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
- 591157991CF4011C00C31E3A /* Resources */ = {
+ 1C8BA9041EC682E700CCCC8C /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
- 59A3D00C1CF4E68100C4259F /* BenchmarkViewController.xib in Resources */,
- 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */,
- 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */,
- 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */,
+ 1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */,
+ 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */,
+ 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */,
+ 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
+/* Begin PBXShellScriptBuildPhase section */
+ 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Check Pods Manifest.lock";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n";
+ showEnvVarsInLog = 0;
+ };
+ 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Embed Pods Frameworks";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-frameworks.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+ A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Copy Pods Resources";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-resources.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+/* End PBXShellScriptBuildPhase section */
+
/* Begin PBXSourcesBuildPhase section */
- 591157971CF4011C00C31E3A /* Sources */ = {
+ 1C8BA8FC1EC682E700CCCC8C /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
- 59A3D0091CF4E68100C4259F /* main.mm in Sources */,
- 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */,
- 59A3D00B1CF4E68100C4259F /* BenchmarkViewController.mm in Sources */,
- 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */,
+ 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */,
+ 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */,
+ 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */,
+ 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
+ 1C8BA90A1EC682E700CCCC8C /* Debug */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */;
+ buildSettings = {
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ ENABLE_BITCODE = NO;
+ HEADER_SEARCH_PATHS = "$(inherited)";
+ INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist";
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
+ LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
+ LIBRARY_SEARCH_PATHS = "";
+ OTHER_LDFLAGS = "$(inherited)";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example";
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ };
+ name = Debug;
+ };
+ 1C8BA90B1EC682E700CCCC8C /* Release */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */;
+ buildSettings = {
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ ENABLE_BITCODE = NO;
+ HEADER_SEARCH_PATHS = "$(inherited)";
+ INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist";
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
+ LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
+ LIBRARY_SEARCH_PATHS = "";
+ ONLY_ACTIVE_ARCH = YES;
+ OTHER_LDFLAGS = "$(inherited)";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example";
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ };
+ name = Release;
+ };
591157B01CF4011D00C31E3A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
@@ -202,8 +288,10 @@
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
@@ -225,7 +313,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
@@ -246,8 +334,10 @@
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
@@ -263,7 +353,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = NO;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
@@ -271,92 +361,23 @@
};
name = Release;
};
- 591157B31CF4011D00C31E3A /* Debug */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- CODE_SIGN_IDENTITY = "iPhone Developer";
- ENABLE_BITCODE = NO;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../../..",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- );
- INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist";
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
- LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- "$(SRCROOT)/../../makefile/gen/lib",
- );
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- "-Xlinker",
- "-S",
- "-Xlinker",
- "-x",
- "-Xlinker",
- "-dead_strip",
- );
- PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
- PRODUCT_NAME = "$(TARGET_NAME)";
- };
- name = Debug;
- };
- 591157B41CF4011D00C31E3A /* Release */ = {
- isa = XCBuildConfiguration;
- buildSettings = {
- CODE_SIGN_IDENTITY = "iPhone Developer";
- ENABLE_BITCODE = NO;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../../..",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- );
- INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist";
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
- LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- "$(SRCROOT)/../../makefile/gen/lib",
- );
- ONLY_ACTIVE_ARCH = YES;
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- "-Xlinker",
- "-S",
- "-Xlinker",
- "-x",
- "-Xlinker",
- "-dead_strip",
- );
- PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
- PRODUCT_NAME = "$(TARGET_NAME)";
- };
- name = Release;
- };
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
- 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "benchmark" */ = {
+ 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
- 591157B01CF4011D00C31E3A /* Debug */,
- 591157B11CF4011D00C31E3A /* Release */,
+ 1C8BA90A1EC682E700CCCC8C /* Debug */,
+ 1C8BA90B1EC682E700CCCC8C /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
- 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "benchmark" */ = {
+ 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
- 591157B31CF4011D00C31E3A /* Debug */,
- 591157B41CF4011D00C31E3A /* Release */,
+ 591157B01CF4011D00C31E3A /* Debug */,
+ 591157B11CF4011D00C31E3A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h
index 0039d5e7ca..0039d5e7ca 100644
--- a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h
+++ b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h
diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m
index d134c2b591..d134c2b591 100644
--- a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m
+++ b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m
diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h b/tensorflow/examples/ios/camera/CameraExampleViewController.h
index df744428a8..0aefbc6eed 100644
--- a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h
+++ b/tensorflow/examples/ios/camera/CameraExampleViewController.h
@@ -29,6 +29,7 @@
dispatch_queue_t videoDataOutputQueue;
AVCaptureStillImageOutput *stillImageOutput;
UIView *flashView;
+ UIImage *square;
BOOL isUsingFrontFacingCamera;
AVSpeechSynthesizer *synth;
NSMutableDictionary *oldPredictionValues;
diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm b/tensorflow/examples/ios/camera/CameraExampleViewController.mm
index 27df3d3d71..d113d50ff8 100644
--- a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm
+++ b/tensorflow/examples/ios/camera/CameraExampleViewController.mm
@@ -323,10 +323,10 @@ didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
auto predictions = output->flat<float>();
NSMutableDictionary *newValues = [NSMutableDictionary dictionary];
- for (int index = 0; index < predictions.size(); ++index) {
+ for (int index = 0; index < predictions.size(); index += 1) {
const float predictionValue = predictions(index);
if (predictionValue > 0.05f) {
- std::string label = labels[index];
+ std::string label = labels[index % predictions.size()];
NSString *labelObject = [NSString stringWithUTF8String:label.c_str()];
NSNumber *valueObject = [NSNumber numberWithFloat:predictionValue];
[newValues setObject:valueObject forKey:labelObject];
@@ -369,12 +369,17 @@ didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
isUsingFrontFacingCamera = !isUsingFrontFacingCamera;
}
+- (void)didReceiveMemoryWarning {
+ [super didReceiveMemoryWarning];
+}
+
- (void)viewDidLoad {
[super viewDidLoad];
+ square = [UIImage imageNamed:@"squarePNG"];
synth = [[AVSpeechSynthesizer alloc] init];
labelLayers = [[NSMutableArray alloc] init];
oldPredictionValues = [[NSMutableDictionary alloc] init];
-
+
tensorflow::Status load_status;
if (model_uses_memory_mapping) {
load_status = LoadMemoryMappedModel(
@@ -394,6 +399,26 @@ didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
[self setupAVCapture];
}
+- (void)viewDidUnload {
+ [super viewDidUnload];
+}
+
+- (void)viewWillAppear:(BOOL)animated {
+ [super viewWillAppear:animated];
+}
+
+- (void)viewDidAppear:(BOOL)animated {
+ [super viewDidAppear:animated];
+}
+
+- (void)viewWillDisappear:(BOOL)animated {
+ [super viewWillDisappear:animated];
+}
+
+- (void)viewDidDisappear:(BOOL)animated {
+ [super viewDidDisappear:animated];
+}
+
- (BOOL)shouldAutorotateToInterfaceOrientation:
(UIInterfaceOrientation)interfaceOrientation {
return (interfaceOrientation == UIInterfaceOrientationPortrait);
diff --git a/tensorflow/contrib/ios_examples/camera/Info.plist b/tensorflow/examples/ios/camera/Info.plist
index 82978ca278..772fb38dcc 100644
--- a/tensorflow/contrib/ios_examples/camera/Info.plist
+++ b/tensorflow/examples/ios/camera/Info.plist
@@ -5,7 +5,7 @@
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
- <string>${PRODUCT_NAME}</string>
+ <string>tf_camera_example</string>
<key>CFBundleExecutable</key>
<string>${EXECUTABLE_NAME}</string>
<key>CFBundleIdentifier</key>
diff --git a/tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard b/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard
index 0f10a22e41..0f10a22e41 100644
--- a/tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard
+++ b/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard
diff --git a/tensorflow/examples/ios/camera/Podfile b/tensorflow/examples/ios/camera/Podfile
new file mode 100644
index 0000000000..117828f071
--- /dev/null
+++ b/tensorflow/examples/ios/camera/Podfile
@@ -0,0 +1,5 @@
+platform :ios, '8.0'
+inhibit_all_warnings!
+
+target 'tf_camera_example'
+ pod 'TensorFlow-experimental'
diff --git a/tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg b/tensorflow/examples/ios/camera/data/grace_hopper.jpg
index d2a427810f..d2a427810f 100644
--- a/tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg
+++ b/tensorflow/examples/ios/camera/data/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/ios_examples/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h
index 87a847e145..87a847e145 100644
--- a/tensorflow/contrib/ios_examples/camera/ios_image_load.h
+++ b/tensorflow/examples/ios/camera/ios_image_load.h
diff --git a/tensorflow/contrib/ios_examples/camera/ios_image_load.mm b/tensorflow/examples/ios/camera/ios_image_load.mm
index 64d1ea21cf..64d1ea21cf 100644
--- a/tensorflow/contrib/ios_examples/camera/ios_image_load.mm
+++ b/tensorflow/examples/ios/camera/ios_image_load.mm
diff --git a/tensorflow/contrib/ios_examples/camera/main.mm b/tensorflow/examples/ios/camera/main.mm
index 42eff697ef..42eff697ef 100644
--- a/tensorflow/contrib/ios_examples/camera/main.mm
+++ b/tensorflow/examples/ios/camera/main.mm
diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.h b/tensorflow/examples/ios/camera/tensorflow_utils.h
index 78bdb82aae..78bdb82aae 100644
--- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.h
+++ b/tensorflow/examples/ios/camera/tensorflow_utils.h
diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm b/tensorflow/examples/ios/camera/tensorflow_utils.mm
index 43746882ee..56d1e53081 100644
--- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm
+++ b/tensorflow/examples/ios/camera/tensorflow_utils.mm
@@ -23,18 +23,6 @@
#include <sstream>
#include <string>
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/message_lite.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/public/session.h"
-
namespace {
// Helper class used to load protobufs efficiently.
@@ -228,4 +216,4 @@ tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type,
}
t.close();
return tensorflow::Status::OK();
-} \ No newline at end of file
+}
diff --git a/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj
new file mode 100644
index 0000000000..ee9fe57c79
--- /dev/null
+++ b/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj
@@ -0,0 +1,412 @@
+// !$*UTF8*$!
+{
+ archiveVersion = 1;
+ classes = {
+ };
+ objectVersion = 46;
+ objects = {
+
+/* Begin PBXBuildFile section */
+ 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */; };
+ 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; };
+ 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; };
+ 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; };
+ 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; };
+ 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; };
+ 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; };
+ 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; };
+ 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; };
+ 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; };
+ 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */; };
+ 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; };
+ 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */; };
+/* End PBXBuildFile section */
+
+/* Begin PBXFileReference section */
+ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; };
+ 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
+ 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = "<group>"; };
+ 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = "<group>"; };
+ 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = "<group>"; };
+ 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
+ 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = "<group>"; };
+ 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
+ 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; };
+ 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; };
+ 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = "<group>"; };
+ 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = "<group>"; };
+ 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = "<group>"; };
+ 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = "<group>"; };
+ 1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = tensorflow_utils.h; sourceTree = "<group>"; };
+ 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = tensorflow_utils.mm; sourceTree = "<group>"; };
+ 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
+ 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; };
+ 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.debug.xcconfig"; sourceTree = "<group>"; };
+ 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.release.xcconfig"; sourceTree = "<group>"; };
+ 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = "<group>"; };
+ 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = "<group>"; };
+ 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = "<group>"; };
+/* End PBXFileReference section */
+
+/* Begin PBXFrameworksBuildPhase section */
+ 1C564C0A1ED3A92E00087306 /* Frameworks */ = {
+ isa = PBXFrameworksBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */,
+ 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */,
+ 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXFrameworksBuildPhase section */
+
+/* Begin PBXGroup section */
+ 24D7686C331131624F4454A0 /* Frameworks */ = {
+ isa = PBXGroup;
+ children = (
+ 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */,
+ 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */,
+ 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */,
+ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */,
+ 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */,
+ 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */,
+ );
+ name = Frameworks;
+ sourceTree = "<group>";
+ };
+ 3E9FC355632FB928EA23BEED /* Pods */ = {
+ isa = PBXGroup;
+ children = (
+ 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */,
+ 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */,
+ );
+ name = Pods;
+ sourceTree = "<group>";
+ };
+ 591157921CF4011C00C31E3A = {
+ isa = PBXGroup;
+ children = (
+ 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */,
+ 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */,
+ 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */,
+ 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */,
+ 1CDB2D4D1ED3AA35007929E9 /* Info.plist */,
+ 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */,
+ 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */,
+ 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */,
+ 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */,
+ 1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */,
+ 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */,
+ 59A3CFF31CF4E68100C4259F /* data */,
+ 5911579C1CF4011C00C31E3A /* Products */,
+ 3E9FC355632FB928EA23BEED /* Pods */,
+ 24D7686C331131624F4454A0 /* Frameworks */,
+ );
+ sourceTree = "<group>";
+ };
+ 5911579C1CF4011C00C31E3A /* Products */ = {
+ isa = PBXGroup;
+ children = (
+ 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */,
+ );
+ name = Products;
+ sourceTree = "<group>";
+ };
+ 59A3CFF31CF4E68100C4259F /* data */ = {
+ isa = PBXGroup;
+ children = (
+ 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */,
+ 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */,
+ 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */,
+ );
+ path = data;
+ sourceTree = "<group>";
+ };
+/* End PBXGroup section */
+
+/* Begin PBXNativeTarget section */
+ 1C564C0C1ED3A92E00087306 /* tf_camera_example */ = {
+ isa = PBXNativeTarget;
+ buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */;
+ buildPhases = (
+ 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */,
+ 1C564C091ED3A92E00087306 /* Sources */,
+ 1C564C0A1ED3A92E00087306 /* Frameworks */,
+ 1C564C0B1ED3A92E00087306 /* Resources */,
+ 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */,
+ 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */,
+ );
+ buildRules = (
+ );
+ dependencies = (
+ );
+ name = tf_camera_example;
+ productName = tf_camera_example;
+ productReference = 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */;
+ productType = "com.apple.product-type.application";
+ };
+/* End PBXNativeTarget section */
+
+/* Begin PBXProject section */
+ 591157931CF4011C00C31E3A /* Project object */ = {
+ isa = PBXProject;
+ attributes = {
+ LastSwiftUpdateCheck = 0830;
+ LastUpgradeCheck = 0830;
+ ORGANIZATIONNAME = Google;
+ TargetAttributes = {
+ 1C564C0C1ED3A92E00087306 = {
+ CreatedOnToolsVersion = 8.3.2;
+ DevelopmentTeam = 5DRPWFQSHP;
+ ProvisioningStyle = Automatic;
+ };
+ };
+ };
+ buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */;
+ compatibilityVersion = "Xcode 3.2";
+ developmentRegion = English;
+ hasScannedForEncodings = 0;
+ knownRegions = (
+ en,
+ Base,
+ );
+ mainGroup = 591157921CF4011C00C31E3A;
+ productRefGroup = 5911579C1CF4011C00C31E3A /* Products */;
+ projectDirPath = "";
+ projectRoot = "";
+ targets = (
+ 1C564C0C1ED3A92E00087306 /* tf_camera_example */,
+ );
+ };
+/* End PBXProject section */
+
+/* Begin PBXResourcesBuildPhase section */
+ 1C564C0B1ED3A92E00087306 /* Resources */ = {
+ isa = PBXResourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */,
+ 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */,
+ 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */,
+ 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */,
+ 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXResourcesBuildPhase section */
+
+/* Begin PBXShellScriptBuildPhase section */
+ 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Embed Pods Frameworks";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-frameworks.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+ 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Copy Pods Resources";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-resources.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+ 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Check Pods Manifest.lock";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n";
+ showEnvVarsInLog = 0;
+ };
+/* End PBXShellScriptBuildPhase section */
+
+/* Begin PBXSourcesBuildPhase section */
+ 1C564C091ED3A92E00087306 /* Sources */ = {
+ isa = PBXSourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */,
+ 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */,
+ 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */,
+ 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */,
+ 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXSourcesBuildPhase section */
+
+/* Begin XCBuildConfiguration section */
+ 1C564C361ED3A92E00087306 /* Debug */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ DEVELOPMENT_TEAM = 5DRPWFQSHP;
+ INFOPLIST_FILE = Info.plist;
+ IPHONEOS_DEPLOYMENT_TARGET = 10.3;
+ LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example";
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
+ SWIFT_OPTIMIZATION_LEVEL = "-Onone";
+ SWIFT_VERSION = 3.0;
+ };
+ name = Debug;
+ };
+ 1C564C371ED3A92E00087306 /* Release */ = {
+ isa = XCBuildConfiguration;
+ baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ DEVELOPMENT_TEAM = 5DRPWFQSHP;
+ INFOPLIST_FILE = Info.plist;
+ IPHONEOS_DEPLOYMENT_TARGET = 10.3;
+ LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example";
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule";
+ SWIFT_VERSION = 3.0;
+ };
+ name = Release;
+ };
+ 591157B01CF4011D00C31E3A /* Debug */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = dwarf;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ ENABLE_TESTABILITY = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu99;
+ GCC_DYNAMIC_NO_PIC = NO;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_OPTIMIZATION_LEVEL = 0;
+ GCC_PREPROCESSOR_DEFINITIONS = (
+ "DEBUG=1",
+ "$(inherited)",
+ );
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
+ MTL_ENABLE_DEBUG_INFO = YES;
+ ONLY_ACTIVE_ARCH = YES;
+ SDKROOT = iphoneos;
+ TARGETED_DEVICE_FAMILY = "1,2";
+ };
+ name = Debug;
+ };
+ 591157B11CF4011D00C31E3A /* Release */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
+ ENABLE_NS_ASSERTIONS = NO;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu99;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
+ MTL_ENABLE_DEBUG_INFO = NO;
+ SDKROOT = iphoneos;
+ TARGETED_DEVICE_FAMILY = "1,2";
+ VALIDATE_PRODUCT = YES;
+ };
+ name = Release;
+ };
+/* End XCBuildConfiguration section */
+
+/* Begin XCConfigurationList section */
+ 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 1C564C361ED3A92E00087306 /* Debug */,
+ 1C564C371ED3A92E00087306 /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+ 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 591157B01CF4011D00C31E3A /* Debug */,
+ 591157B11CF4011D00C31E3A /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+/* End XCConfigurationList section */
+ };
+ rootObject = 591157931CF4011C00C31E3A /* Project object */;
+}
diff --git a/tensorflow/contrib/ios_examples/simple/AppDelegate.h b/tensorflow/examples/ios/simple/AppDelegate.h
index 75b1f1da38..75b1f1da38 100644
--- a/tensorflow/contrib/ios_examples/simple/AppDelegate.h
+++ b/tensorflow/examples/ios/simple/AppDelegate.h
diff --git a/tensorflow/contrib/ios_examples/simple/AppDelegate.mm b/tensorflow/examples/ios/simple/AppDelegate.mm
index 1e808eb976..1e808eb976 100644
--- a/tensorflow/contrib/ios_examples/simple/AppDelegate.mm
+++ b/tensorflow/examples/ios/simple/AppDelegate.mm
diff --git a/tensorflow/examples/ios/simple/Podfile b/tensorflow/examples/ios/simple/Podfile
new file mode 100644
index 0000000000..1740ad6457
--- /dev/null
+++ b/tensorflow/examples/ios/simple/Podfile
@@ -0,0 +1,5 @@
+platform :ios, '8.0'
+inhibit_all_warnings!
+
+target 'tf_simple_example'
+ pod 'TensorFlow-experimental'
diff --git a/tensorflow/contrib/ios_examples/simple/RunModel-Info.plist b/tensorflow/examples/ios/simple/RunModel-Info.plist
index ca80e68091..d0a8742456 100644
--- a/tensorflow/contrib/ios_examples/simple/RunModel-Info.plist
+++ b/tensorflow/examples/ios/simple/RunModel-Info.plist
@@ -5,11 +5,11 @@
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
- <string>tf_ios_makefile_example</string>
+ <string>tf_simple_example</string>
<key>CFBundleExecutable</key>
- <string>tf_ios_makefile_example</string>
+ <string>tf_simple_example</string>
<key>CFBundleIdentifier</key>
- <string>Google.RunModel</string>
+ <string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.h b/tensorflow/examples/ios/simple/RunModelViewController.h
index 4e1a83ccf5..4e1a83ccf5 100644
--- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.h
+++ b/tensorflow/examples/ios/simple/RunModelViewController.h
diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm b/tensorflow/examples/ios/simple/RunModelViewController.mm
index 5c121962d9..c8ccb5c77b 100644
--- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm
+++ b/tensorflow/examples/ios/simple/RunModelViewController.mm
@@ -21,17 +21,7 @@
#include <sstream>
#include <string>
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/message_lite.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "ios_image_load.h"
@@ -50,7 +40,7 @@ class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
return -1;
}
ifs_.read(static_cast<char*>(buffer), size);
- return ifs_.gcount();
+ return (int)ifs_.gcount();
}
private:
@@ -83,7 +73,7 @@ static void GetTopN(
std::vector<std::pair<float, int> >,
std::greater<std::pair<float, int> > > top_result_pq;
- const int count = prediction.size();
+ const long count = prediction.size();
for (int i = 0; i < count; ++i) {
const float value = prediction(i);
@@ -121,7 +111,7 @@ bool PortableReadFileToProto(const std::string& file_name,
// eventually remove this and quit loud when a large protobuf is passed in.
::google::protobuf::io::CodedInputStream coded_stream(&stream);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
- // respectively.
+ // respectively.
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
return proto->ParseFromCodedStream(&coded_stream);
}
@@ -192,7 +182,7 @@ NSString* RunInferenceOnImage() {
1, wanted_height, wanted_width, wanted_channels}));
auto image_tensor_mapped = image_tensor.tensor<float, 4>();
tensorflow::uint8* in = image_data.data();
- tensorflow::uint8* in_end = (in + (image_height * image_width * image_channels));
+ // tensorflow::uint8* in_end = (in + (image_height * image_width * image_channels));
float* out = image_tensor_mapped.data();
for (int y = 0; y < wanted_height; ++y) {
const int in_y = (y * image_height) / wanted_height;
@@ -209,7 +199,7 @@ NSString* RunInferenceOnImage() {
}
NSString* result = [network_path stringByAppendingString: @" - loaded!"];
- result = [NSString stringWithFormat: @"%@ - %d, %s - %dx%d", result,
+ result = [NSString stringWithFormat: @"%@ - %lu, %s - %dx%d", result,
label_strings.size(), label_strings[0].c_str(), image_width, image_height];
std::string input_layer = "input";
diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.xib b/tensorflow/examples/ios/simple/RunModelViewController.xib
index 93f334b985..93f334b985 100644
--- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.xib
+++ b/tensorflow/examples/ios/simple/RunModelViewController.xib
diff --git a/tensorflow/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/examples/ios/simple/data/grace_hopper.jpg
new file mode 100644
index 0000000000..d2a427810f
--- /dev/null
+++ b/tensorflow/examples/ios/simple/data/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/ios_examples/simple/ios_image_load.h b/tensorflow/examples/ios/simple/ios_image_load.h
index 0e0b771118..0e0b771118 100644
--- a/tensorflow/contrib/ios_examples/simple/ios_image_load.h
+++ b/tensorflow/examples/ios/simple/ios_image_load.h
diff --git a/tensorflow/contrib/ios_examples/simple/ios_image_load.mm b/tensorflow/examples/ios/simple/ios_image_load.mm
index 64d1ea21cf..64d1ea21cf 100644
--- a/tensorflow/contrib/ios_examples/simple/ios_image_load.mm
+++ b/tensorflow/examples/ios/simple/ios_image_load.mm
diff --git a/tensorflow/contrib/ios_examples/simple/main.mm b/tensorflow/examples/ios/simple/main.mm
index d70550a730..d70550a730 100644
--- a/tensorflow/contrib/ios_examples/simple/main.mm
+++ b/tensorflow/examples/ios/simple/main.mm
diff --git a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj
index 94a0037e4f..55c06e28fb 100644
--- a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj
+++ b/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj
@@ -7,9 +7,9 @@
objects = {
/* Begin PBXBuildFile section */
- 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; };
- 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; };
- 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */; };
+ 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; };
+ 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; };
+ 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */; };
59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; };
59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; };
59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; };
@@ -18,22 +18,17 @@
59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; };
59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; };
59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; };
- 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */; };
- 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0171CF4E86100C4259F /* UIKit.framework */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
- 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = "<group>"; };
- 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = "<group>"; };
- 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_ios_makefile_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
- 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
+ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; };
+ 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
+ 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
+ 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = "<group>"; };
- 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = "<group>"; };
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = "<group>"; };
- 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_2012_challenge_label_map_proto.pbtxt; sourceTree = "<group>"; };
59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = "<group>"; };
- 59A3CFF81CF4E68100C4259F /* LICENSE */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = LICENSE; sourceTree = "<group>"; };
59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = "<group>"; };
59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = "<group>"; };
59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = "<group>"; };
@@ -42,9 +37,9 @@
59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = "<group>"; };
59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = "<group>"; };
59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = "<group>"; };
- 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
- 59A3D0151CF4E83D00C4259F /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; };
- 59A3D0171CF4E86100C4259F /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
+ 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; };
+ 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.release.xcconfig"; sourceTree = "<group>"; };
+ 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.debug.xcconfig"; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@@ -52,26 +47,38 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
- 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */,
- 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */,
- 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */,
- 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */,
- 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */,
+ 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */,
+ 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */,
+ 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
+ 24D7686C331131624F4454A0 /* Frameworks */ = {
+ isa = PBXGroup;
+ children = (
+ 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */,
+ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */,
+ 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */,
+ 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */,
+ );
+ name = Frameworks;
+ sourceTree = "<group>";
+ };
+ 3E9FC355632FB928EA23BEED /* Pods */ = {
+ isa = PBXGroup;
+ children = (
+ 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */,
+ 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */,
+ );
+ name = Pods;
+ sourceTree = "<group>";
+ };
591157921CF4011C00C31E3A = {
isa = PBXGroup;
children = (
- 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */,
- 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */,
- 590E7D871D02091F00DF5523 /* libprotobuf.a */,
- 59A3D0171CF4E86100C4259F /* UIKit.framework */,
- 59A3D0151CF4E83D00C4259F /* Foundation.framework */,
- 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */,
59A3CFF11CF4E68100C4259F /* AppDelegate.h */,
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */,
59A3CFF31CF4E68100C4259F /* data */,
@@ -83,13 +90,15 @@
59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */,
59A3D0001CF4E68100C4259F /* RunModelViewController.xib */,
5911579C1CF4011C00C31E3A /* Products */,
+ 3E9FC355632FB928EA23BEED /* Pods */,
+ 24D7686C331131624F4454A0 /* Frameworks */,
);
sourceTree = "<group>";
};
5911579C1CF4011C00C31E3A /* Products */ = {
isa = PBXGroup;
children = (
- 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */,
+ 5911579B1CF4011C00C31E3A /* tf_simple_example.app */,
);
name = Products;
sourceTree = "<group>";
@@ -97,11 +106,8 @@
59A3CFF31CF4E68100C4259F /* data */ = {
isa = PBXGroup;
children = (
- 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */,
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */,
- 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */,
59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */,
- 59A3CFF81CF4E68100C4259F /* LICENSE */,
59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */,
);
path = data;
@@ -110,21 +116,24 @@
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
- 5911579A1CF4011C00C31E3A /* tf_ios_makefile_example */ = {
+ 5911579A1CF4011C00C31E3A /* tf_simple_example */ = {
isa = PBXNativeTarget;
- buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_ios_makefile_example" */;
+ buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */;
buildPhases = (
+ 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */,
591157971CF4011C00C31E3A /* Sources */,
591157981CF4011C00C31E3A /* Frameworks */,
591157991CF4011C00C31E3A /* Resources */,
+ 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */,
+ 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */,
);
buildRules = (
);
dependencies = (
);
- name = tf_ios_makefile_example;
+ name = tf_simple_example;
productName = tf_ios_makefile_example;
- productReference = 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */;
+ productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
@@ -133,7 +142,7 @@
591157931CF4011C00C31E3A /* Project object */ = {
isa = PBXProject;
attributes = {
- LastUpgradeCheck = 0720;
+ LastUpgradeCheck = 0830;
ORGANIZATIONNAME = Google;
TargetAttributes = {
5911579A1CF4011C00C31E3A = {
@@ -142,7 +151,7 @@
};
};
};
- buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_ios_makefile_example" */;
+ buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */;
compatibilityVersion = "Xcode 3.2";
developmentRegion = English;
hasScannedForEncodings = 0;
@@ -155,7 +164,7 @@
projectDirPath = "";
projectRoot = "";
targets = (
- 5911579A1CF4011C00C31E3A /* tf_ios_makefile_example */,
+ 5911579A1CF4011C00C31E3A /* tf_simple_example */,
);
};
/* End PBXProject section */
@@ -174,6 +183,54 @@
};
/* End PBXResourcesBuildPhase section */
+/* Begin PBXShellScriptBuildPhase section */
+ 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Embed Pods Frameworks";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-frameworks.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+ 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Check Pods Manifest.lock";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n";
+ showEnvVarsInLog = 0;
+ };
+ 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */ = {
+ isa = PBXShellScriptBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ );
+ inputPaths = (
+ );
+ name = "[CP] Copy Pods Resources";
+ outputPaths = (
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ shellPath = /bin/sh;
+ shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-resources.sh\"\n";
+ showEnvVarsInLog = 0;
+ };
+/* End PBXShellScriptBuildPhase section */
+
/* Begin PBXSourcesBuildPhase section */
591157971CF4011C00C31E3A /* Sources */ = {
isa = PBXSourcesBuildPhase;
@@ -202,8 +259,10 @@
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
@@ -225,7 +284,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
@@ -246,8 +305,10 @@
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
@@ -263,7 +324,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
- IPHONEOS_DEPLOYMENT_TARGET = 9.2;
+ IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = NO;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
@@ -273,38 +334,21 @@
};
591157B31CF4011D00C31E3A /* Debug */ = {
isa = XCBuildConfiguration;
+ baseConfigurationReference = 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */;
buildSettings = {
CLANG_DEBUG_INFORMATION_LEVEL = default;
CODE_SIGN_IDENTITY = "iPhone Developer";
ENABLE_BITCODE = NO;
GCC_ENABLE_CPP_EXCEPTIONS = YES;
GCC_ENABLE_CPP_RTTI = YES;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../../..",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- );
+ HEADER_SEARCH_PATHS = "$(inherited)";
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.2;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- "$(SRCROOT)/../../makefile/gen/lib",
- );
+ LIBRARY_SEARCH_PATHS = "";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- "-Xlinker",
- "-S",
- "-Xlinker",
- "-x",
- "-Xlinker",
- "-dead_strip",
- );
- PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
+ OTHER_LDFLAGS = "$(inherited)";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example";
PRODUCT_NAME = "$(TARGET_NAME)";
SEPARATE_STRIP = NO;
};
@@ -312,39 +356,22 @@
};
591157B41CF4011D00C31E3A /* Release */ = {
isa = XCBuildConfiguration;
+ baseConfigurationReference = 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */;
buildSettings = {
CLANG_DEBUG_INFORMATION_LEVEL = default;
CODE_SIGN_IDENTITY = "iPhone Developer";
ENABLE_BITCODE = NO;
GCC_ENABLE_CPP_EXCEPTIONS = YES;
GCC_ENABLE_CPP_RTTI = YES;
- HEADER_SEARCH_PATHS = (
- "$(SRCROOT)/../../../..",
- "$(SRCROOT)/../../makefile/downloads/protobuf/src/",
- "$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen",
- );
+ HEADER_SEARCH_PATHS = "$(inherited)";
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.2;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
- LIBRARY_SEARCH_PATHS = (
- "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
- "$(SRCROOT)/../../makefile/gen/lib",
- );
+ LIBRARY_SEARCH_PATHS = "";
ONLY_ACTIVE_ARCH = YES;
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
- OTHER_LDFLAGS = (
- "-force_load",
- "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
- "-Xlinker",
- "-S",
- "-Xlinker",
- "-x",
- "-Xlinker",
- "-dead_strip",
- );
- PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
+ OTHER_LDFLAGS = "$(inherited)";
+ PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example";
PRODUCT_NAME = "$(TARGET_NAME)";
SEPARATE_STRIP = NO;
};
@@ -353,7 +380,7 @@
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
- 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_ios_makefile_example" */ = {
+ 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
591157B01CF4011D00C31E3A /* Debug */,
@@ -362,7 +389,7 @@
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
- 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_ios_makefile_example" */ = {
+ 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
591157B31CF4011D00C31E3A /* Debug */,
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index d9083f994c..9f048d3ea0 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3797,7 +3797,7 @@ func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value t
//
// TensorArray gradient calls use an accumulator TensorArray object. If
// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow throuth the same accumulator TensorArray.
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
// This double counts and generally breaks the TensorArray gradient flow.
//
// The solution is to identify which gradient call this particular
@@ -13463,6 +13463,11 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
// to compute `input` is odd, it should be provided since it cannot be inferred
// properly.
//
+// Along each axis `IRFFT2D` is computed on, if `fft_length` (or
+// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
// Arguments:
// input: A complex64 tensor.
// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
@@ -16691,6 +16696,10 @@ func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
// compute `input` is odd, it should be provided since it cannot be inferred
// properly.
//
+// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
+// than the corresponding dimension of `input`, the dimension is cropped. If it is
+// larger, the dimension is padded with zeros.
+//
// Arguments:
// input: A complex64 tensor.
// fft_length: An int32 tensor of shape [1]. The FFT length.
@@ -16874,6 +16883,10 @@ func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *
// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
// followed by the `fft_length / 2` positive-frequency terms.
//
+// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
// Arguments:
// input: A float32 tensor.
// fft_length: An int32 tensor of shape [1]. The FFT length.
@@ -17169,6 +17182,10 @@ func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output,
// of `output`: the zero-frequency term, followed by the `fft_length / 2`
// positive-frequency terms.
//
+// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
// Arguments:
// input: A float32 tensor.
// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
@@ -17514,6 +17531,11 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output) (indices tf
// to compute `input` is odd, it should be provided since it cannot be inferred
// properly.
//
+// Along each axis `IRFFT3D` is computed on, if `fft_length` (or
+// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
// Arguments:
// input: A complex64 tensor.
// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension.
@@ -18936,6 +18958,10 @@ func Erfc(scope *Scope, x tf.Output) (y tf.Output) {
// of `output`: the zero-frequency term, followed by the `fft_length / 2`
// positive-frequency terms.
//
+// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
// Arguments:
// input: A float32 tensor.
// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 6b746342e0..8febd61c7e 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -460,6 +460,7 @@ py_library(
deps = [
":estimator",
":export_export",
+ ":linear",
":metric_keys",
":numpy_io",
":pandas_io",
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index 16b4be7b24..dd89f780e6 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -322,6 +322,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
n_classes, batch_size):
@@ -363,12 +366,13 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
- n_classes = 2
+ n_classes = 3
input_dimension = 2
batch_size = 10
- data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
- y_data = np.reshape(data[:batch_size], (batch_size, 1))
+ y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
@@ -401,9 +405,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
input_dimension = 1
n_classes = 2
batch_size = 10
- data = np.linspace(0., 2., batch_size, dtype=np.float32)
+ data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
x = pd.DataFrame({'x': data})
- y = pd.Series(data)
+ y = pd.Series(self._as_label(data))
train_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
@@ -431,25 +435,28 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
- n_classes = 2
+ n_classes = 3
batch_size = 10
- data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = np.linspace(0., n_classes-1., batch_size * input_dimension,
+ dtype=np.float32)
data = data.reshape(batch_size, input_dimension)
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
- 'x': feature_pb2.Feature(
- float_list=feature_pb2.FloatList(value=datum)),
- 'y': feature_pb2.Feature(
- float_list=feature_pb2.FloatList(value=datum[:1])),
+ 'x':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=datum)),
+ 'y':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=self._as_label(datum[:1]))),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
- 'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py
index a374ddf115..f9b585698b 100644
--- a/tensorflow/python/estimator/canned/dnn_test.py
+++ b/tensorflow/python/estimator/canned/dnn_test.py
@@ -305,12 +305,18 @@ class DNNClassifierPredictTest(test.TestCase):
# logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597
# probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597]
# class_ids = argmax(probabilities) = [0]
- self.assertAllClose({
- prediction_keys.PredictionKeys.LOGITS: [-2.08],
- prediction_keys.PredictionKeys.LOGISTIC: [0.11105597],
- prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597],
- prediction_keys.PredictionKeys.CLASS_IDS: [0],
- }, next(dnn_classifier.predict(input_fn=input_fn)))
+ predictions = next(dnn_classifier.predict(input_fn=input_fn))
+ self.assertAllClose([-2.08],
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose([0.11105597],
+ predictions[prediction_keys.PredictionKeys.LOGISTIC])
+ self.assertAllClose(
+ [0.88894403,
+ 0.11105597], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose([0],
+ predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+ self.assertAllEqual([b'0'],
+ predictions[prediction_keys.PredictionKeys.CLASSES])
def test_multi_dim(self):
"""Asserts predictions for multi-dimensional input and logits."""
@@ -542,6 +548,9 @@ class DNNClassifierIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
n_classes, batch_size):
@@ -579,12 +588,13 @@ class DNNClassifierIntegrationTest(test.TestCase):
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
- n_classes = 2
+ n_classes = 3
input_dimension = 2
batch_size = 10
- data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
- y_data = np.reshape(data[:batch_size], (batch_size, 1))
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
@@ -615,11 +625,11 @@ class DNNClassifierIntegrationTest(test.TestCase):
if not HAS_PANDAS:
return
input_dimension = 1
- n_classes = 2
+ n_classes = 3
batch_size = 10
- data = np.linspace(0., 2., batch_size, dtype=np.float32)
+ data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
x = pd.DataFrame({'x': data})
- y = pd.Series(data)
+ y = pd.Series(self._as_label(data))
train_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
@@ -647,25 +657,28 @@ class DNNClassifierIntegrationTest(test.TestCase):
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
- n_classes = 2
+ n_classes = 3
batch_size = 10
- data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
data = data.reshape(batch_size, input_dimension)
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
- 'x': feature_pb2.Feature(
- float_list=feature_pb2.FloatList(value=datum)),
- 'y': feature_pb2.Feature(
- float_list=feature_pb2.FloatList(value=datum[:1])),
+ 'x':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=datum)),
+ 'y':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=self._as_label(datum[:1]))),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
- 'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 631ddfc5df..8da1e5104c 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -302,7 +302,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
Raises:
ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
"""
- if label_vocabulary is not None and not isinstance(label_vocabulary, list):
+ if label_vocabulary is not None and not isinstance(label_vocabulary,
+ (list, tuple)):
raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
type(label_vocabulary)))
@@ -356,14 +357,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
label_ids = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self._label_vocabulary),
name='class_id_lookup').lookup(labels)
- assert_less = check_ops.assert_less(
- label_ids,
- ops.convert_to_tensor(self._n_classes, dtype=label_ids.dtype),
- message='Label IDs must < n_classes')
- assert_greater = check_ops.assert_non_negative(
- label_ids, message='Label Ids must >= 0')
- with ops.control_dependencies((assert_less, assert_greater)):
- return array_ops.identity(label_ids)
+ return _assert_range(label_ids, self._n_classes)
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
@@ -459,7 +453,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
- weight_feature_key=None, thresholds=None):
+ weight_feature_key=None, thresholds=None, label_vocabulary=None):
"""Creates a `Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -475,6 +469,11 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
generated for each threshold value. This threshold is applied to the
logistic values to determine the binary classification (i.e., above the
threshold is `true`, below is `false`.
+ label_vocabulary: A list of strings represents possible label values. If it
+ is not given, that means labels are already encoded within [0, 1]. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. Also there will be errors if vocabulary is not
+ provided and labels are string.
Returns:
An instance of `Head` for binary classification.
@@ -483,50 +482,81 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
+ if label_vocabulary is not None and not isinstance(label_vocabulary,
+ (list, tuple)):
+ raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
+ type(label_vocabulary)))
+
for threshold in thresholds:
if (threshold <= 0.0) or (threshold >= 1.0):
raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
- weight_feature_key=weight_feature_key, thresholds=thresholds)
+ weight_feature_key=weight_feature_key,
+ thresholds=thresholds,
+ label_vocabulary=label_vocabulary)
class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
"""See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
- def __init__(self, weight_feature_key=None, thresholds=None):
+ def __init__(self,
+ weight_feature_key=None,
+ thresholds=None,
+ label_vocabulary=None):
self._weight_feature_key = weight_feature_key
self._thresholds = thresholds
+ self._label_vocabulary = label_vocabulary
@property
def logits_dimension(self):
return 1
- def _eval_metric_ops(
- self, labels, logits, logistic, scores, classes, unweighted_loss,
- weights=None):
- with ops.name_scope(
- None, 'metrics',
- (labels, logits, logistic, scores, classes, unweighted_loss, weights)):
+ def _eval_metric_ops(self,
+ labels,
+ logits,
+ logistic,
+ scores,
+ class_ids,
+ unweighted_loss,
+ weights=None):
+ with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores,
+ class_ids, unweighted_loss, weights)):
keys = metric_keys.MetricKeys
labels_mean = _indicator_labels_mean(
labels=labels, weights=weights, name=keys.LABEL_MEAN)
metric_ops = {
# Estimator already adds a metric for loss.
- keys.LOSS_MEAN: metrics_lib.mean(
- unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
- keys.ACCURACY: metrics_lib.accuracy(
- labels=labels, predictions=classes, weights=weights,
- name=keys.ACCURACY),
- keys.PREDICTION_MEAN: _predictions_mean(
- predictions=logistic, weights=weights, name=keys.PREDICTION_MEAN),
- keys.LABEL_MEAN: labels_mean,
- keys.ACCURACY_BASELINE: _accuracy_baseline(labels_mean),
- keys.AUC: _auc(
- labels=labels, predictions=logistic, weights=weights,
- name=keys.AUC),
- keys.AUC_PR: _auc(
- labels=labels, predictions=logistic, weights=weights, curve='PR',
- name=keys.AUC_PR)
+ keys.LOSS_MEAN:
+ metrics_lib.mean(
+ unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
+ keys.ACCURACY:
+ metrics_lib.accuracy(
+ labels=labels,
+ predictions=class_ids,
+ weights=weights,
+ name=keys.ACCURACY),
+ keys.PREDICTION_MEAN:
+ _predictions_mean(
+ predictions=logistic,
+ weights=weights,
+ name=keys.PREDICTION_MEAN),
+ keys.LABEL_MEAN:
+ labels_mean,
+ keys.ACCURACY_BASELINE:
+ _accuracy_baseline(labels_mean),
+ keys.AUC:
+ _auc(
+ labels=labels,
+ predictions=logistic,
+ weights=weights,
+ name=keys.AUC),
+ keys.AUC_PR:
+ _auc(
+ labels=labels,
+ predictions=logistic,
+ weights=weights,
+ curve='PR',
+ name=keys.AUC_PR)
}
for threshold in self._thresholds:
accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
@@ -559,27 +589,39 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
two_class_logits = array_ops.concat(
(array_ops.zeros_like(logits), logits), 1, name='two_class_logits')
scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES)
- classes = array_ops.reshape(
+ class_ids = array_ops.reshape(
math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes')
+ if self._label_vocabulary:
+ table = lookup_ops.index_to_string_table_from_tensor(
+ vocabulary_list=self._label_vocabulary, name='class_string_lookup')
+ classes = table.lookup(class_ids)
+ else:
+ classes = string_ops.as_string(class_ids, name='str_classes')
predictions = {
pred_keys.LOGITS: logits,
pred_keys.LOGISTIC: logistic,
pred_keys.PROBABILITIES: scores,
- pred_keys.CLASS_IDS: classes
+ pred_keys.CLASS_IDS: class_ids,
+ pred_keys.CLASSES: classes,
}
if mode == model_fn.ModeKeys.PREDICT:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
- export_outputs={'': export_output.ClassificationOutput(
- scores=scores,
- # `ClassificationOutput` requires string classes.
- # TODO(ptucker): Support label_keys.
- classes=string_ops.as_string(classes, name='str_classes'))})
+ export_outputs={
+ '':
+ export_output.ClassificationOutput(
+ scores=scores, classes=classes)
+ })
# Eval.
- labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
- self.logits_dimension)
+ labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
+ if self._label_vocabulary is not None:
+ labels = lookup_ops.index_table_from_tensor(
+ vocabulary_list=tuple(self._label_vocabulary),
+ name='class_id_lookup').lookup(labels)
+ labels = math_ops.to_float(labels)
+ labels = _assert_range(labels, 2)
unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits, name='loss')
weights = (
@@ -598,7 +640,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
logits=logits,
logistic=logistic,
scores=scores,
- classes=classes,
+ class_ids=class_ids,
unweighted_loss=unweighted_loss,
weights=weights))
@@ -721,3 +763,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
predictions=predictions,
loss=training_loss,
train_op=train_op_fn(training_loss))
+
+
+def _assert_range(labels, n_classes):
+ assert_less = check_ops.assert_less(
+ labels,
+ ops.convert_to_tensor(n_classes, dtype=labels.dtype),
+ message='Label IDs must < n_classes')
+ assert_greater = check_ops.assert_non_negative(
+ labels, message='Label IDs must >= 0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ return array_ops.identity(labels)
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 0efafac87a..e3d9258466 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -206,7 +206,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
})
with self.test_session():
- with self.assertRaisesOpError('Label Ids must >= 0'):
+ with self.assertRaisesOpError('Label IDs must >= 0'):
spec.loss.eval({
labels_placeholder: labels_2x1_with_negative_id,
logits_placeholder: logits_2x3
@@ -743,8 +743,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertEqual(1, head.logits_dimension)
# Both logits and labels should be shape (batch_size, 1).
- values_2x1 = np.array(((43.,), (44.,),))
- values_3x1 = np.array(((45.,), (46.,), (47.,),))
+ values_2x1 = np.array(((0.,), (1.,),))
+ values_3x1 = np.array(((0.,), (1.,), (0.,),))
# Static shape.
with self.assertRaisesRegexp(
@@ -788,28 +788,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertEqual(1, head.logits_dimension)
# Create estimator spec.
- logits = np.array(((45,), (-41,),), dtype=np.int32)
+ logits = [[45.], [-41.]]
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- expected_predictions = {
- prediction_keys.PredictionKeys.LOGITS:
- logits.astype(np.float32),
- prediction_keys.PredictionKeys.LOGISTIC:
- _sigmoid(logits).astype(np.float32),
- prediction_keys.PredictionKeys.PROBABILITIES:
- np.array(((0., 1.), (1., 0.),), dtype=np.float32),
- prediction_keys.PredictionKeys.CLASS_IDS:
- np.array(((1,), (0,)), dtype=np.int64),
- }
-
# Assert spec contains expected tensors.
- self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
- self.assertEqual(
- {k: v.dtype for k, v in six.iteritems(expected_predictions)},
- {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNone(spec.train_op)
@@ -821,7 +806,37 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
- self.assertAllClose(expected_predictions, sess.run(spec.predictions))
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(logits,
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(
+ _sigmoid(np.array(logits)),
+ predictions[prediction_keys.PredictionKeys.LOGISTIC])
+ self.assertAllClose(
+ [[0., 1.],
+ [1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose([[1], [0]],
+ predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+ self.assertAllEqual([[b'1'], [b'0']],
+ predictions[prediction_keys.PredictionKeys.CLASSES])
+
+ def test_predict_with_vocabulary_list(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ label_vocabulary=['aang', 'iroh'])
+
+ logits = [[1.], [0.]]
+ expected_classes = [[b'iroh'], [b'aang']]
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertAllEqual(
+ expected_classes,
+ sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES]))
def test_eval(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@@ -834,17 +849,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=np.array(((1,), (1,),), dtype=np.int32))
- expected_predictions = {
- prediction_keys.PredictionKeys.LOGITS:
- logits.astype(np.float32),
- prediction_keys.PredictionKeys.LOGISTIC:
- _sigmoid(logits).astype(np.float32),
- prediction_keys.PredictionKeys.PROBABILITIES:
- np.array(((0., 1.), (1., 0.),), dtype=np.float32),
- # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
- prediction_keys.PredictionKeys.CLASS_IDS:
- np.array(((1,), (0,)), dtype=np.int64),
- }
keys = metric_keys.MetricKeys
expected_metrics = {
# loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41
@@ -859,10 +863,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
# Assert spec contains expected tensors.
- self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
- self.assertEqual(
- {k: v.dtype for k, v in six.iteritems(expected_predictions)},
- {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
@@ -875,15 +875,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
- predictions, loss, metrics = sess.run((
- spec.predictions, spec.loss, update_ops))
- self.assertAllClose(expected_predictions, predictions)
+ loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(41., loss)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops})
+ def test_eval_with_vocabulary_list(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ label_vocabulary=['aang', 'iroh'])
+
+ # Create estimator spec.
+ logits = np.array(((45,), (-41,),), dtype=np.float32)
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.float32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=[[b'iroh'], [b'iroh']])
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ sess.run(update_ops)
+ self.assertAllClose(1. / 2,
+ value_ops[metric_keys.MetricKeys.ACCURACY].eval())
+
def test_eval_with_thresholds(self):
thresholds = [0.25, 0.5, 0.75]
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
@@ -942,23 +961,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=np.array(((1,), (1,),), dtype=np.float64),
train_op_fn=_train_op_fn)
- expected_predictions = {
- prediction_keys.PredictionKeys.LOGITS:
- logits.astype(np.float32),
- prediction_keys.PredictionKeys.LOGISTIC:
- _sigmoid(logits).astype(np.float32),
- prediction_keys.PredictionKeys.PROBABILITIES:
- np.array(((0., 1.), (1., 0.),), dtype=np.float32),
- # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
- prediction_keys.PredictionKeys.CLASS_IDS:
- np.array(((1,), (0,)), dtype=np.int64),
- }
-
# Assert spec contains expected tensors.
- self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
- self.assertEqual(
- {k: v.dtype for k, v in six.iteritems(expected_predictions)},
- {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNotNone(spec.train_op)
@@ -969,9 +972,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
- predictions, loss, train_result, summary_str = sess.run((
- spec.predictions, spec.loss, spec.train_op, spec.scaffold.summary_op))
- self.assertAllClose(expected_predictions, predictions)
+ loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
self.assertAllClose(expected_loss, loss)
self.assertEqual(expected_train_result, train_result)
_assert_simple_summaries(self, {
@@ -995,28 +997,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- expected_predictions = {
- prediction_keys.PredictionKeys.LOGITS:
- logits.astype(np.float32),
- prediction_keys.PredictionKeys.LOGISTIC:
- _sigmoid(logits).astype(np.float32),
- prediction_keys.PredictionKeys.PROBABILITIES:
- np.array(((0., 1.), (1., 0.), (0., 1.)), dtype=np.float32),
- # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
- prediction_keys.PredictionKeys.CLASS_IDS:
- np.array(((1,), (0,), (1,)), dtype=np.int64),
- }
-
- # Assert spec contains expected tensors.
- self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
- self.assertEqual(
- {k: v.dtype for k, v in six.iteritems(expected_predictions)},
- {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
-
# Assert predictions, loss, and metrics.
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
- self.assertAllClose(expected_predictions, sess.run(spec.predictions))
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(
+ logits.astype(np.float32),
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(
+ _sigmoid(logits).astype(np.float32),
+ predictions[prediction_keys.PredictionKeys.LOGISTIC])
+ self.assertAllClose(
+ [[0., 1.], [1., 0.],
+ [0., 1.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose([[1], [0], [1]],
+ predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+ self.assertAllEqual([[b'1'], [b'0'], [b'1']],
+ predictions[prediction_keys.PredictionKeys.CLASSES])
def test_weighted_multi_example_eval(self):
"""3 examples, 1 batch."""
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index f4ba382954..1db3dfbf3a 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -18,31 +18,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-import shutil
-import tempfile
-
-import numpy as np
-
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
-from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as feature_column_lib
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import optimizer
def _linear_regressor_fn(*args, **kwargs):
return linear.LinearRegressor(*args, **kwargs)
+def _linear_classifier_fn(*args, **kwargs):
+ return linear.LinearClassifier(*args, **kwargs)
+
+
+# Tests for Linear Regressor.
+
+
class LinearRegressorPartitionerTest(
linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
@@ -88,348 +79,43 @@ class LinearRegressorTrainingTest(
self, _linear_regressor_fn)
-class _BaseLinearClassiferTrainingTest(object):
-
- def __init__(self, n_classes):
- self._n_classes = n_classes
- self._logits_dimensions = (
- self._n_classes if self._n_classes > 2 else 1)
-
- def setUp(self):
- self._model_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- if self._model_dir:
- writer_cache.FileWriterCache.clear()
- shutil.rmtree(self._model_dir)
-
- def _mock_optimizer(self, expected_loss=None):
- expected_var_names = [
- '%s/part_0:0' % linear_testing_utils.AGE_WEIGHT_NAME,
- '%s/part_0:0' % linear_testing_utils.BIAS_NAME
- ]
-
- def _minimize(loss, global_step):
- trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertItemsEqual(
- expected_var_names,
- [var.name for var in trainable_vars])
-
- # Verify loss. We can't check the value directly, so we add an assert op.
- self.assertEquals(0, loss.shape.ndims)
- if expected_loss is None:
- return state_ops.assign_add(global_step, 1).op
- assert_loss = linear_testing_utils.assert_close(
- math_ops.to_float(expected_loss, name='expected'),
- loss,
- name='assert_loss')
- with ops.control_dependencies((assert_loss,)):
- return state_ops.assign_add(global_step, 1).op
-
- mock_optimizer = test.mock.NonCallableMock(
- spec=optimizer.Optimizer,
- wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
- mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
-
- # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
- # So, return mock_optimizer itself for deepcopy.
- mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
- return mock_optimizer
-
- def _assert_checkpoint(
- self, expected_global_step, expected_age_weight=None, expected_bias=None):
- logits_dimension = self._logits_dimensions
-
- shapes = {
- name: shape for (name, shape) in
- checkpoint_utils.list_variables(self._model_dir)
- }
-
- self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
- self.assertEqual(
- expected_global_step,
- checkpoint_utils.load_variable(
- self._model_dir, ops.GraphKeys.GLOBAL_STEP))
-
- self.assertEqual([1, logits_dimension],
- shapes[linear_testing_utils.AGE_WEIGHT_NAME])
- if expected_age_weight is not None:
- self.assertAllEqual(expected_age_weight,
- checkpoint_utils.load_variable(
- self._model_dir,
- linear_testing_utils.AGE_WEIGHT_NAME))
-
- self.assertEqual([logits_dimension], shapes[linear_testing_utils.BIAS_NAME])
- if expected_bias is not None:
- self.assertAllEqual(expected_bias,
- checkpoint_utils.load_variable(
- self._model_dir, linear_testing_utils.BIAS_NAME))
-
- def testFromScratchWithDefaultOptimizer(self):
- n_classes = self._n_classes
- label = 0
- age = 17
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- model_dir=self._model_dir)
-
- # Train for a few steps, and validate final checkpoint.
- num_steps = 10
- est.train(
- input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
- self._assert_checkpoint(num_steps)
-
- def testTrainWithTwoDimsLabel(self):
- n_classes = self._n_classes
- batch_size = 20
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- model_dir=self._model_dir)
- data_rank_1 = np.array([0, 1])
- data_rank_2 = np.array([[0], [1]])
- self.assertEqual((2,), data_rank_1.shape)
- self.assertEqual((2, 1), data_rank_2.shape)
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'age': data_rank_1},
- y=data_rank_2,
- batch_size=batch_size,
- num_epochs=None,
- shuffle=True)
- est.train(train_input_fn, steps=200)
- self._assert_checkpoint(200)
-
- def testTrainWithOneDimLabel(self):
- n_classes = self._n_classes
- batch_size = 20
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- model_dir=self._model_dir)
- data_rank_1 = np.array([0, 1])
- self.assertEqual((2,), data_rank_1.shape)
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'age': data_rank_1},
- y=data_rank_1,
- batch_size=batch_size,
- num_epochs=None,
- shuffle=True)
- est.train(train_input_fn, steps=200)
- self._assert_checkpoint(200)
-
- def testTrainWithTwoDimsWeight(self):
- n_classes = self._n_classes
- batch_size = 20
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- weight_feature_key='w',
- n_classes=n_classes,
- model_dir=self._model_dir)
- data_rank_1 = np.array([0, 1])
- data_rank_2 = np.array([[0], [1]])
- self.assertEqual((2,), data_rank_1.shape)
- self.assertEqual((2, 1), data_rank_2.shape)
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1,
- batch_size=batch_size, num_epochs=None,
- shuffle=True)
- est.train(train_input_fn, steps=200)
- self._assert_checkpoint(200)
-
- def testTrainWithOneDimWeight(self):
- n_classes = self._n_classes
- batch_size = 20
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- weight_feature_key='w',
- n_classes=n_classes,
- model_dir=self._model_dir)
- data_rank_1 = np.array([0, 1])
- self.assertEqual((2,), data_rank_1.shape)
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1,
- batch_size=batch_size, num_epochs=None,
- shuffle=True)
- est.train(train_input_fn, steps=200)
- self._assert_checkpoint(200)
-
- def testFromScratch(self):
- n_classes = self._n_classes
- label = 1
- age = 17
- # For binary classifer:
- # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are
- # all zero initially) and label = 1 so,
- # loss = 1 * -log ( sigmoid(logits) ) = 0.69315
- # For multi class classifer:
- # loss = cross_entropy(logits, label) where logits are all 0s (weights are
- # all zero initially) and label = 1 so,
- # loss = 1 * -log ( 1.0 / n_classes )
- # For this particular test case, as logits are same, the formular
- # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.
- mock_optimizer = self._mock_optimizer(
- expected_loss=-1 * math.log(1.0/n_classes))
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- optimizer=mock_optimizer,
- model_dir=self._model_dir)
- self.assertEqual(0, mock_optimizer.minimize.call_count)
-
- # Train for a few steps, and validate optimizer and final checkpoint.
- num_steps = 10
- est.train(
- input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
- self.assertEqual(1, mock_optimizer.minimize.call_count)
- self._assert_checkpoint(
- expected_global_step=num_steps,
- expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes],
- expected_bias=[0.] if n_classes == 2 else [.0] * n_classes)
-
- def testFromCheckpoint(self):
- # Create initial checkpoint.
- n_classes = self._n_classes
- label = 1
- age = 17
- # For binary case, the expected weight has shape (1,1). For multi class
- # case, the shape is (1, n_classes). In order to test the weights, set
- # weights as 2.0 * range(n_classes).
- age_weight = [[2.0]] if n_classes == 2 else (
- np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
- (1, n_classes)))
- bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
- initial_global_step = 100
- with ops.Graph().as_default():
- variables.Variable(age_weight, name=linear_testing_utils.AGE_WEIGHT_NAME)
- variables.Variable(bias, name=linear_testing_utils.BIAS_NAME)
- variables.Variable(
- initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
- dtype=dtypes.int64)
- linear_testing_utils.save_variables_to_ckpt(self._model_dir)
-
- # For binary classifer:
- # logits = age * age_weight + bias = 17 * 2. - 35. = -1.
- # loss = sigmoid_cross_entropy(logits, label)
- # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133
- # For multi class classifer:
- # loss = cross_entropy(logits, label)
- # where logits = 17 * age_weight + bias and label = 1
- # so, loss = 1 * -log ( soft_max(logits)[1] )
- if n_classes == 2:
- expected_loss = 1.3133
- else:
- logits = age_weight * age + bias
- logits_exp = np.exp(logits)
- softmax = logits_exp / logits_exp.sum()
- expected_loss = -1 * math.log(softmax[0, label])
-
- mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- optimizer=mock_optimizer,
- model_dir=self._model_dir)
- self.assertEqual(0, mock_optimizer.minimize.call_count)
-
- # Train for a few steps, and validate optimizer and final checkpoint.
- num_steps = 10
- est.train(
- input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
- self.assertEqual(1, mock_optimizer.minimize.call_count)
- self._assert_checkpoint(
- expected_global_step=initial_global_step + num_steps,
- expected_age_weight=age_weight,
- expected_bias=bias)
-
- def testFromCheckpointMultiBatch(self):
- # Create initial checkpoint.
- n_classes = self._n_classes
- label = [1, 0]
- age = [17, 18.5]
- # For binary case, the expected weight has shape (1,1). For multi class
- # case, the shape is (1, n_classes). In order to test the weights, set
- # weights as 2.0 * range(n_classes).
- age_weight = [[2.0]] if n_classes == 2 else (
- np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
- (1, n_classes)))
- bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
- initial_global_step = 100
- with ops.Graph().as_default():
- variables.Variable(age_weight, name=linear_testing_utils.AGE_WEIGHT_NAME)
- variables.Variable(bias, name=linear_testing_utils.BIAS_NAME)
- variables.Variable(
- initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
- dtype=dtypes.int64)
- linear_testing_utils.save_variables_to_ckpt(self._model_dir)
-
- # For binary classifer:
- # logits = age * age_weight + bias
- # logits[0] = 17 * 2. - 35. = -1.
- # logits[1] = 18.5 * 2. - 35. = 2.
- # loss = sigmoid_cross_entropy(logits, label)
- # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133
- # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269
- # For multi class classifer:
- # loss = cross_entropy(logits, label)
- # where logits = [17, 18.5] * age_weight + bias and label = [1, 0]
- # so, loss = 1 * -log ( soft_max(logits)[label] )
- if n_classes == 2:
- expected_loss = (1.3133 + 2.1269)
- else:
- logits = age_weight * np.reshape(age, (2, 1)) + bias
- logits_exp = np.exp(logits)
- softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
- softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
- expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
- expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
- expected_loss = expected_loss_0 + expected_loss_1
-
- mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
-
- est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
- n_classes=n_classes,
- optimizer=mock_optimizer,
- model_dir=self._model_dir)
- self.assertEqual(0, mock_optimizer.minimize.call_count)
-
- # Train for a few steps, and validate optimizer and final checkpoint.
- num_steps = 10
- est.train(
- input_fn=lambda: ({'age': (age)}, (label)),
- steps=num_steps)
- self.assertEqual(1, mock_optimizer.minimize.call_count)
- self._assert_checkpoint(
- expected_global_step=initial_global_step + num_steps,
- expected_age_weight=age_weight,
- expected_bias=bias)
+# Tests for Linear Classifer.
class LinearClassiferWithBinaryClassesTrainingTest(
- _BaseLinearClassiferTrainingTest, test.TestCase):
+ linear_testing_utils.BaseLinearClassiferTrainingTest, test.TestCase):
- def __init__(self, methodName='runTest'):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- _BaseLinearClassiferTrainingTest.__init__(self, n_classes=2)
+ linear_testing_utils.BaseLinearClassiferTrainingTest.__init__(
+ self, n_classes=2)
class LinearClassiferWithMultiClassesTrainingTest(
- _BaseLinearClassiferTrainingTest, test.TestCase):
+ linear_testing_utils.BaseLinearClassiferTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassiferTrainingTest.__init__(
+ self, n_classes=4)
+
+
+class LinearClassiferWithBinaryClassesEvaluationTest(
+ linear_testing_utils.BaseLinearClassiferEvaluationTest, test.TestCase):
- def __init__(self, methodName='runTest'):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassiferEvaluationTest.__init__(
+ self, n_classes=2, linear_classifer_fn=_linear_classifier_fn)
+
+
+class LinearClassiferWithMultiClassesEvaluationTest(
+ linear_testing_utils.BaseLinearClassiferEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- _BaseLinearClassiferTrainingTest.__init__(self, n_classes=4)
+ linear_testing_utils.BaseLinearClassiferEvaluationTest.__init__(
+ self, n_classes=4, linear_classifer_fn=_linear_classifier_fn)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 841dc7bdae..bed9556cf6 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
import os
import shutil
import tempfile
@@ -30,6 +31,7 @@ from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
@@ -113,6 +115,14 @@ def queue_parsed_features(feature_map):
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+def sorted_key_dict(unsorted_dict):
+ return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}
+
+
+def sigmoid(x):
+ return 1 / (1 + np.exp(-1.0 * x))
+
+
class CheckPartitionerVarHook(session_run_hook.SessionRunHook):
"""A `SessionRunHook` to check a paritioned variable."""
@@ -862,3 +872,545 @@ class BaseLinearRegressorTrainingTest(object):
expected_global_step=initial_global_step + num_steps,
expected_age_weight=age_weight,
expected_bias=bias)
+
+
+class BaseLinearClassiferTrainingTest(object):
+
+ def __init__(self, n_classes):
+ self._n_classes = n_classes
+ self._logits_dimensions = (
+ self._n_classes if self._n_classes > 2 else 1)
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def _mock_optimizer(self, expected_loss=None):
+ expected_var_names = [
+ '%s/part_0:0' % AGE_WEIGHT_NAME,
+ '%s/part_0:0' % BIAS_NAME
+ ]
+
+ def _minimize(loss, global_step):
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(
+ expected_var_names,
+ [var.name for var in trainable_vars])
+
+ # Verify loss. We can't check the value directly, so we add an assert op.
+ self.assertEquals(0, loss.shape.ndims)
+ if expected_loss is None:
+ return state_ops.assign_add(global_step, 1).op
+ assert_loss = assert_close(
+ math_ops.to_float(expected_loss, name='expected'),
+ loss,
+ name='assert_loss')
+ with ops.control_dependencies((assert_loss,)):
+ return state_ops.assign_add(global_step, 1).op
+
+ mock_optimizer = test.mock.NonCallableMock(
+ spec=optimizer.Optimizer,
+ wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+ mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+ # So, return mock_optimizer itself for deepcopy.
+ mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+ return mock_optimizer
+
+ def _assert_checkpoint(
+ self, expected_global_step, expected_age_weight=None, expected_bias=None):
+ logits_dimension = self._logits_dimensions
+
+ shapes = {
+ name: shape for (name, shape) in
+ checkpoint_utils.list_variables(self._model_dir)
+ }
+
+ self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+ self.assertEqual(
+ expected_global_step,
+ checkpoint_utils.load_variable(
+ self._model_dir, ops.GraphKeys.GLOBAL_STEP))
+
+ self.assertEqual([1, logits_dimension],
+ shapes[AGE_WEIGHT_NAME])
+ if expected_age_weight is not None:
+ self.assertAllEqual(expected_age_weight,
+ checkpoint_utils.load_variable(
+ self._model_dir,
+ AGE_WEIGHT_NAME))
+
+ self.assertEqual([logits_dimension], shapes[BIAS_NAME])
+ if expected_bias is not None:
+ self.assertAllEqual(expected_bias,
+ checkpoint_utils.load_variable(
+ self._model_dir, BIAS_NAME))
+
+ def testFromScratchWithDefaultOptimizer(self):
+ n_classes = self._n_classes
+ label = 0
+ age = 17
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self._assert_checkpoint(num_steps)
+
+ def testTrainWithTwoDimsLabel(self):
+ n_classes = self._n_classes
+ batch_size = 20
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ data_rank_2 = np.array([[0], [1]])
+ self.assertEqual((2,), data_rank_1.shape)
+ self.assertEqual((2, 1), data_rank_2.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1},
+ y=data_rank_2,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(200)
+
+ def testTrainWithOneDimLabel(self):
+ n_classes = self._n_classes
+ batch_size = 20
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ self.assertEqual((2,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1},
+ y=data_rank_1,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(200)
+
+ def testTrainWithTwoDimsWeight(self):
+ n_classes = self._n_classes
+ batch_size = 20
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ weight_feature_key='w',
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ data_rank_2 = np.array([[0], [1]])
+ self.assertEqual((2,), data_rank_1.shape)
+ self.assertEqual((2, 1), data_rank_2.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1,
+ batch_size=batch_size, num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(200)
+
+ def testTrainWithOneDimWeight(self):
+ n_classes = self._n_classes
+ batch_size = 20
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ weight_feature_key='w',
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ data_rank_1 = np.array([0, 1])
+ self.assertEqual((2,), data_rank_1.shape)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1,
+ batch_size=batch_size, num_epochs=None,
+ shuffle=True)
+ est.train(train_input_fn, steps=200)
+ self._assert_checkpoint(200)
+
+ def testFromScratch(self):
+ n_classes = self._n_classes
+ label = 1
+ age = 17
+ # For binary classifer:
+ # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are
+ # all zero initially) and label = 1 so,
+ # loss = 1 * -log ( sigmoid(logits) ) = 0.69315
+ # For multi class classifer:
+ # loss = cross_entropy(logits, label) where logits are all 0s (weights are
+ # all zero initially) and label = 1 so,
+ # loss = 1 * -log ( 1.0 / n_classes )
+ # For this particular test case, as logits are same, the formular
+ # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.
+ mock_optimizer = self._mock_optimizer(
+ expected_loss=-1 * math.log(1.0/n_classes))
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ expected_global_step=num_steps,
+ expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes],
+ expected_bias=[0.] if n_classes == 2 else [.0] * n_classes)
+
+ def testFromCheckpoint(self):
+ # Create initial checkpoint.
+ n_classes = self._n_classes
+ label = 1
+ age = 17
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ age_weight = [[2.0]] if n_classes == 2 else (
+ np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
+ (1, n_classes)))
+ bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(age_weight, name=AGE_WEIGHT_NAME)
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # For binary classifer:
+ # logits = age * age_weight + bias = 17 * 2. - 35. = -1.
+ # loss = sigmoid_cross_entropy(logits, label)
+ # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133
+ # For multi class classifer:
+ # loss = cross_entropy(logits, label)
+ # where logits = 17 * age_weight + bias and label = 1
+ # so, loss = 1 * -log ( soft_max(logits)[1] )
+ if n_classes == 2:
+ expected_loss = 1.3133
+ else:
+ logits = age_weight * age + bias
+ logits_exp = np.exp(logits)
+ softmax = logits_exp / logits_exp.sum()
+ expected_loss = -1 * math.log(softmax[0, label])
+
+ mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ expected_global_step=initial_global_step + num_steps,
+ expected_age_weight=age_weight,
+ expected_bias=bias)
+
+ def testFromCheckpointMultiBatch(self):
+ # Create initial checkpoint.
+ n_classes = self._n_classes
+ label = [1, 0]
+ age = [17, 18.5]
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ age_weight = [[2.0]] if n_classes == 2 else (
+ np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
+ (1, n_classes)))
+ bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(age_weight, name=AGE_WEIGHT_NAME)
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # For binary classifer:
+ # logits = age * age_weight + bias
+ # logits[0] = 17 * 2. - 35. = -1.
+ # logits[1] = 18.5 * 2. - 35. = 2.
+ # loss = sigmoid_cross_entropy(logits, label)
+ # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133
+ # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269
+ # For multi class classifer:
+ # loss = cross_entropy(logits, label)
+ # where logits = [17, 18.5] * age_weight + bias and label = [1, 0]
+ # so, loss = 1 * -log ( soft_max(logits)[label] )
+ if n_classes == 2:
+ expected_loss = (1.3133 + 2.1269)
+ else:
+ logits = age_weight * np.reshape(age, (2, 1)) + bias
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ expected_loss = expected_loss_0 + expected_loss_1
+
+ mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+ est = linear.LinearClassifier(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ est.train(
+ input_fn=lambda: ({'age': (age)}, (label)),
+ steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ expected_global_step=initial_global_step + num_steps,
+ expected_age_weight=age_weight,
+ expected_bias=bias)
+
+
+class BaseLinearClassiferEvaluationTest(object):
+
+ def __init__(self, n_classes, linear_classifer_fn):
+ self._linear_classifer_fn = linear_classifer_fn
+ self._n_classes = n_classes
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
+
+ def test_evaluation_for_simple_data(self):
+ n_classes = self._n_classes
+ label = 1
+ age = 1.
+
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ age_weight = [[-11.0]] if n_classes == 2 else (
+ np.reshape(-11.0 * np.array(list(range(n_classes)), dtype=np.float32),
+ (1, n_classes)))
+ bias = [-30.0] if n_classes == 2 else [-30.0] * n_classes
+
+ with ops.Graph().as_default():
+ variables.Variable(age_weight, name=AGE_WEIGHT_NAME)
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = self._linear_classifer_fn(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1)
+
+ if n_classes == 2:
+ # Binary classes: loss = sum(corss_entropy(41)) = 41.
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: 41.,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: 41.,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.,
+ metric_keys.MetricKeys.LABEL_MEAN: 1.,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
+ metric_keys.MetricKeys.AUC: 0.,
+ metric_keys.MetricKeys.AUC_PR: 1.,
+ }
+ else:
+ # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )
+ logits = age_weight * age + bias
+ logits_exp = np.exp(logits)
+ softmax = logits_exp / logits_exp.sum()
+ expected_loss = -1 * math.log(softmax[0, label])
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
+
+ def test_evaluation_batch(self):
+ """Tests evaluation for batch_size==2."""
+ n_classes = self._n_classes
+ label = [1, 0]
+ age = [17., 18.]
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ age_weight = [[2.0]] if n_classes == 2 else (
+ np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
+ (1, n_classes)))
+ bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(age_weight, name=AGE_WEIGHT_NAME)
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = self._linear_classifer_fn(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': (age)}, (label)), steps=1)
+
+ if n_classes == 2:
+ # Logits are (-1., 1.) labels are (1, 0).
+ # Loss is
+ # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+ # loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133
+ expected_loss = 1.3133 * 2
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.5,
+ metric_keys.MetricKeys.LABEL_MEAN: 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ metric_keys.MetricKeys.AUC: 0.,
+ metric_keys.MetricKeys.AUC_PR: 0.25,
+ }
+ else:
+ # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )
+ logits = age_weight * np.reshape(age, (2, 1)) + bias
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ expected_loss = expected_loss_0 + expected_loss_1
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
+
+ def test_evaluation_weights(self):
+ """Tests evaluation with weights."""
+
+ n_classes = self._n_classes
+ label = [1, 0]
+ age = [17., 18.]
+ weights = [1., 2.]
+ # For binary case, the expected weight has shape (1,1). For multi class
+ # case, the shape is (1, n_classes). In order to test the weights, set
+ # weights as 2.0 * range(n_classes).
+ age_weight = [[2.0]] if n_classes == 2 else (
+ np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
+ (1, n_classes)))
+ bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable(age_weight, name=AGE_WEIGHT_NAME)
+ variables.Variable(bias, name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ est = self._linear_classifer_fn(
+ feature_columns=(feature_column_lib.numeric_column('age'),),
+ n_classes=n_classes,
+ weight_feature_key='w',
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(
+ input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1)
+
+ if n_classes == 2:
+ # Logits are (-1., 1.) labels are (1, 0).
+ # Loss is
+ # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+ # loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133
+ # weights = [1., 2.]
+ expected_loss = 1.3133 * (1. + 2.)
+ loss_mean = expected_loss / (1.0 + 2.0)
+ label_mean = np.average(label, weights=weights)
+ logits = [-1, 1]
+ logistics = sigmoid(np.array(logits))
+ predictions_mean = np.average(logistics, weights=weights)
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,
+ metric_keys.MetricKeys.LABEL_MEAN: label_mean,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: (
+ max(label_mean, 1-label_mean)),
+ metric_keys.MetricKeys.AUC: 0.,
+ metric_keys.MetricKeys.AUC_PR: 0.1668,
+ }
+ else:
+ # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )
+ logits = age_weight * np.reshape(age, (2, 1)) + bias
+ logits_exp = np.exp(logits)
+ softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+ softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+ expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+ expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+ loss_mean = np.average([expected_loss_0, expected_loss_1],
+ weights=weights)
+ expected_loss = loss_mean * np.sum(weights)
+
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ ops.GraphKeys.GLOBAL_STEP: 100,
+ metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+ metric_keys.MetricKeys.ACCURACY: 0.,
+ }
+
+ self.assertAllClose(sorted_key_dict(expected_metrics),
+ sorted_key_dict(eval_metrics), rtol=1e-3)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index c22413001d..293aa75253 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -589,6 +589,7 @@ class Estimator(object):
estimator_spec = self._call_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+ all_hooks.extend(hooks)
all_hooks.extend([
training.NanTensorHook(estimator_spec.loss),
training.LoggingTensorHook(
@@ -598,7 +599,6 @@ class Estimator(object):
},
every_n_iter=100)
])
- all_hooks.extend(hooks)
all_hooks.extend(estimator_spec.training_hooks)
if not (estimator_spec.scaffold.saver or
@@ -725,7 +725,9 @@ def _get_replica_device_setter(config):
"""
ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
- 'MutableHashTableOfTensors', 'MutableDenseHashTable'
+ 'MutableHashTableV2', 'MutableHashTableOfTensors',
+ 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
+ 'MutableDenseHashTableV2'
]
if config.task_type:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index c8aab5dac8..b86afece43 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@@ -1286,7 +1287,7 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
# Note that the SavedModel builder replaced the Saver with a new one
- self.assertTrue('save_1/LookupTableImport' in graph_ops)
+ self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
# Clean up.
gfile.DeleteRecursively(tmpdir)
@@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase):
est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
+class EstimatorHookOrderingTest(test.TestCase):
+
+ def testCustomHooksAreCalledBeforeNanTensorHook(self):
+
+ def nan_making_model_fn(mode, features, labels):
+ """A graph that generates NaN's for testing."""
+ del features, labels
+
+ global_step = variables.Variable(
+ 0, dtype=dtypes.int64, name='global_step')
+ inc_global_step = state_ops.assign_add(global_step, 1)
+ nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
+ loss = control_flow_ops.cond(
+ inc_global_step > 1, lambda: nan_const, lambda: 1.0)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=global_step.read_value(),
+ loss=loss,
+ train_op=inc_global_step)
+
+ def empty_input_fn():
+ return dict(), None
+
+ class AfterRunCountingHook(session_run_hook.SessionRunHook):
+ """Hooks that counts the number of times after_run() is called."""
+
+ def __init__(self):
+ self.after_run_count = 0
+
+ def after_run(self, run_context, run_values):
+ del run_context, run_values
+ self.after_run_count += 1
+
+ test_hook = AfterRunCountingHook()
+ est = estimator.Estimator(model_fn=nan_making_model_fn)
+ with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError):
+ est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook])
+ self.assertEqual(2, test_hook.after_run_count)
+
+
class EstimatorIntegrationTest(test.TestCase):
def test_complete_flow_with_a_simple_linear_model(self):
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index a4ce0340fc..546e7a296d 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -22,8 +22,10 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
@@ -295,6 +297,39 @@ class RFFTOpsTest(BaseFFTOpsTest):
self._CompareBackward(c2r.astype(np.complex64), rank, (size,) * rank,
use_placeholder=True)
+ def testFftLength(self):
+ if test.is_gpu_available(cuda_only=True):
+ for rank in VALID_FFT_RANKS:
+ for dims in xrange(rank, rank + 3):
+ for size in (5, 6):
+ inner_dim = size // 2 + 1
+ r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+ (size,) * dims)
+ c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+ 10).reshape((size,) * (dims - 1) + (inner_dim,))
+
+ # Test truncation (FFT size < dimensions).
+ fft_length = (size - 2,) * rank
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+ # Confirm it works with unknown shapes as well.
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+ use_placeholder=True)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+ use_placeholder=True)
+
+ # Test padding (FFT size > dimensions).
+ fft_length = (size + 2,) * rank
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length)
+
+ # Confirm it works with unknown shapes as well.
+ self._CompareForward(r2c.astype(np.float32), rank, fft_length,
+ use_placeholder=True)
+ self._CompareBackward(c2r.astype(np.complex64), rank, fft_length,
+ use_placeholder=True)
+
def testRandom(self):
np.random.seed(12345)
@@ -324,10 +359,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
for dims in xrange(0, rank):
x = np.zeros((1,) * dims).astype(np.complex64)
with self.assertRaisesWithPredicateMatch(
- ValueError, "Shape must be .*rank {}.*".format(rank)):
+ ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfFFT(x, rank)
with self.assertRaisesWithPredicateMatch(
- ValueError, "Shape must be .*rank {}.*".format(rank)):
+ ValueError, "Shape .* must have rank at least {}".format(rank)):
self._tfIFFT(x, rank)
for dims in xrange(rank, rank + 2):
x = np.zeros((1,) * rank)
@@ -335,10 +370,10 @@ class RFFTOpsTest(BaseFFTOpsTest):
# Test non-rank-1 fft_length produces an error.
fft_length = np.zeros((1, 1)).astype(np.int32)
with self.assertRaisesWithPredicateMatch(ValueError,
- "Shape must be .*rank 1"):
+ "Shape .* must have rank 1"):
self._tfFFT(x, rank, fft_length)
with self.assertRaisesWithPredicateMatch(ValueError,
- "Shape must be .*rank 1"):
+ "Shape .* must have rank 1"):
self._tfIFFT(x, rank, fft_length)
# Test wrong fft_length length.
@@ -350,24 +385,46 @@ class RFFTOpsTest(BaseFFTOpsTest):
ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
self._tfIFFT(x, rank, fft_length)
+ # Test that calling the kernel directly without padding to fft_length
+ # produces an error.
+ rffts_for_rank = {1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
+ 2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
+ 3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]}
+ rfft_fn, irfft_fn = rffts_for_rank[rank]
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "Input dimension .* must have length of at least 6 but got: 5"):
+ x = np.zeros((5,) * rank).astype(np.float32)
+ fft_length = [6] * rank
+ with self.test_session():
+ rfft_fn(x, fft_length).eval()
+ # TODO(rjryan): Remove when CPU-based IRFFT is supported.
+ if test.is_gpu_available(cuda_only=True):
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "Input dimension .* must have length of at least .* but got: 3"):
+ x = np.zeros((3,) * rank).astype(np.complex64)
+ fft_length = [6] * rank
+ with self.test_session():
+ irfft_fn(x, fft_length).eval()
+
def testGrad_Simple(self):
- for rank in VALID_FFT_RANKS:
- # rfft3d/irfft3d do not have gradients yet.
- if rank == 3:
- continue
- for dims in xrange(rank, rank + 2):
- for size in (
- 5,
- 6,):
- re = np.ones(shape=(size,) * dims, dtype=np.float32)
- im = -np.ones(shape=(size,) * dims, dtype=np.float32)
- self._checkGradReal(self._tfFFTForRank(rank), re, use_gpu=True)
- self._checkGradComplex(
- self._tfIFFTForRank(rank),
- re,
- im,
- result_is_complex=False,
- use_gpu=True)
+ if test.is_gpu_available(cuda_only=True):
+ for rank in VALID_FFT_RANKS:
+ # rfft3d/irfft3d do not have gradients yet.
+ if rank == 3:
+ continue
+ for dims in xrange(rank, rank + 2):
+ for size in (5, 6):
+ re = np.ones(shape=(size,) * dims, dtype=np.float32)
+ im = -np.ones(shape=(size,) * dims, dtype=np.float32)
+ self._checkGradReal(self._tfFFTForRank(rank), re, use_gpu=True)
+ self._checkGradComplex(
+ self._tfIFFTForRank(rank),
+ re,
+ im,
+ result_is_complex=False,
+ use_gpu=True)
def testGrad_Random(self):
np.random.seed(54321)
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index f026e5ac45..fdf1b134b9 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -159,14 +159,12 @@ class _Conv(base.Layer):
if self.bias is not None:
if self.data_format == 'channels_first':
- # bias_add only supports NHWC.
- # TODO(fchollet): remove this when `bias_add` is feature-complete.
if self.rank == 1:
+ # nn.bias_add does not accept a 1D input tensor.
bias = array_ops.reshape(self.bias, (1, self.filters, 1))
outputs += bias
if self.rank == 2:
- bias = array_ops.reshape(self.bias, (1, self.filters, 1, 1))
- outputs += bias
+ outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
if self.rank == 3:
# As of Mar 2017, direct addition is significantly slower than
# bias_add when computing gradients. To use bias_add, we collapse Z
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index ca72734707..3c3c18b1c9 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn,
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
- flat_emit_size = [emit.get_shape() for emit in flat_emit_structure]
+ flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
+ array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 95a2806330..47ff7018f2 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -33,6 +33,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import math_ops as _math_ops
@@ -70,6 +71,52 @@ def _infer_fft_length_for_irfft(input_tensor, fft_rank):
return _ops.convert_to_tensor(fft_length, _dtypes.int32)
+def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
+ """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims."""
+ fft_shape = _tensor_util.constant_value_as_shape(fft_length)
+
+ # Edge case: skip padding empty tensors.
+ if (input_tensor.shape.ndims is not None and
+ any(dim.value == 0 for dim in input_tensor.shape)):
+ return input_tensor
+
+ # If we know the shapes ahead of time, we can either skip or pre-compute the
+ # appropriate paddings. Otherwise, fall back to computing paddings in
+ # TensorFlow.
+ if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
+ # Slice the last FFT-rank dimensions from input_tensor's shape.
+ input_fft_shape = input_tensor.shape[-fft_shape.ndims:]
+
+ if input_fft_shape.is_fully_defined():
+ # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+ if is_reverse:
+ fft_shape = fft_shape[:-1].concatenate(fft_shape[-1].value // 2 + 1)
+
+ paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
+ for fft_dim, input_dim in zip(fft_shape, input_fft_shape)]
+ if any(pad > 0 for _, pad in paddings):
+ outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
+ fft_shape.ndims), 0)
+ return _array_ops.pad(input_tensor, outer_paddings + paddings)
+ return input_tensor
+
+ # If we can't determine the paddings ahead of time, then we have to pad. If
+ # the paddings end up as zero, tf.pad has a special-case that does no work.
+ input_rank = _array_ops.rank(input_tensor)
+ input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
+ outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
+ outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
+ # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
+ if is_reverse:
+ fft_length = _array_ops.concat([fft_length[:-1],
+ fft_length[-1:] // 2 + 1], 0)
+ fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
+ paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
+ paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
+ axis=1)
+ return _array_ops.pad(input_tensor, paddings)
+
+
def _rfft_wrapper(fft_fn, fft_rank, default_name):
"""Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
@@ -77,10 +124,12 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
+ input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+ input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
return fft_fn(input_tensor, fft_length, name)
_rfft.__doc__ = fft_fn.__doc__
return _rfft
@@ -93,10 +142,13 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
+ input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
+ input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
+ is_reverse=True)
return ifft_fn(input_tensor, fft_length, name)
_irfft.__doc__ = ifft_fn.__doc__
return _irfft
diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py
index 6a73565f82..bcabb41304 100644
--- a/tensorflow/python/training/saver_test_utils.py
+++ b/tensorflow/python/training/saver_test_utils.py
@@ -34,7 +34,7 @@ class CheckpointedOp(object):
# pylint: disable=protected-access
def __init__(self, name, table_ref=None):
if table_ref is None:
- self.table_ref = gen_lookup_ops._mutable_hash_table(
+ self.table_ref = gen_lookup_ops._mutable_hash_table_v2(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else:
self.table_ref = table_ref
@@ -52,10 +52,10 @@ class CheckpointedOp(object):
return self._saveable
def insert(self, keys, values):
- return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values)
+ return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values)
def lookup(self, keys, default):
- return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default)
+ return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default)
def keys(self):
return self._export()[0]
@@ -64,8 +64,8 @@ class CheckpointedOp(object):
return self._export()[1]
def _export(self):
- return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string,
- dtypes.float32)
+ return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string,
+ dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp."""
@@ -81,6 +81,6 @@ class CheckpointedOp(object):
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
- return gen_lookup_ops._lookup_table_import(
+ return gen_lookup_ops._lookup_table_import_v2(
self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access
diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD
index caaf1769c0..23581badb6 100644
--- a/tensorflow/tensorboard/BUILD
+++ b/tensorflow/tensorboard/BUILD
@@ -30,6 +30,7 @@ filegroup(
srcs = [
"TAG",
"//tensorflow/tensorboard/components:index.html",
+ "//tensorflow/tensorboard/components:trace_viewer_index.html",
],
)
diff --git a/tensorflow/tensorboard/components/BUILD b/tensorflow/tensorboard/components/BUILD
index 6a0052b793..e287b2c918 100644
--- a/tensorflow/tensorboard/components/BUILD
+++ b/tensorflow/tensorboard/components/BUILD
@@ -22,6 +22,24 @@ tensorboard_html_binary(
deps = [":tensorboard"],
)
+ts_web_library(
+ name = "trace_viewer",
+ srcs = [
+ "trace_viewer.html",
+ ],
+ path = "/",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_trace_viewer",
+ ],
+)
+
+tensorboard_html_binary(
+ name = "trace_viewer_index",
+ input_path = "/trace_viewer.html",
+ output_path = "/trace_viewer_index.html",
+ deps = [":trace_viewer"],
+)
+
filegroup(
name = "all_files",
srcs = glob(["**"]),
diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
index 00a30686f6..926c476731 100644
--- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
+++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
@@ -36,7 +36,6 @@ limitations under the License.
<link rel="import" href="../tf-backend/tf-backend.html">
<link rel="import" href="../tf-storage/tf-storage.html">
<link rel="import" href="../vz-projector/vz-projector-dashboard.html">
-<link rel="import" href="../vz-projector/bundle.html">
<!--
tf-tensorboard is the frontend entry point for TensorBoard.
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/BUILD
new file mode 100644
index 0000000000..943229fd8b
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_trace_viewer/BUILD
@@ -0,0 +1,30 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:web.bzl", "ts_web_library")
+
+licenses(["notice"]) # Apache 2.0
+
+ts_web_library(
+ name = "tf_trace_viewer",
+ srcs = [
+ "tf-trace-viewer.html",
+ "@org_chromium_catapult_vulcanized_trace_viewer//:trace_viewer_full.html",
+ ],
+ path = "/tf-trace-viewer",
+)
+
+ts_web_library(
+ name = "demo",
+ srcs = ["demo.html"],
+ path = "/tf-trace-viewer",
+ deps = [
+ ":tf_trace_viewer",
+ "//tensorflow/tensorboard/components/tf_trace_viewer/data",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD
new file mode 100644
index 0000000000..f72035d43a
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "web_library")
+
+licenses(["notice"]) # Apache 2.0
+
+web_library(
+ name = "data",
+ srcs = glob(["*.json"]),
+ path = "/tf-trace-viewer/data/plugin/profile",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json b/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json
new file mode 100644
index 0000000000..e1d57394e3
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json
@@ -0,0 +1,105 @@
+{
+ "traceEvents": [
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C",
+ "name": "counter", "args": {"value": 10}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B",
+ "name": "A long name that doesnt fit but is exceedingly informative",
+ "args": {"name_false": false, "value_true": true}},
+ {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p",
+ "name": "ProcessWideEvent1", "args": {}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 827, "ph": "B",
+ "name": "Asub with a name that wont fit", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 828, "ph": "E",
+ "name": "Asub", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 829, "ph": "B",
+ "name": "Asub", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 15, "ts": 820, "ph": "X",
+ "name": "Long X type", "args": {}, "sf": 7, "esf": 8},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "E",
+ "name": "Asub", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X",
+ "name": "X1", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X",
+ "name": "X same ts and dur as X1", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "C",
+ "name": "counter", "args": {"value": 1}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 833, "ph": "E",
+ "name": "", "args": {}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 835, "ph": "I",
+ "name": "ThreadLevelI1", "args": {}},
+
+ {"cat": "PERF", "ts": 880, "ph": "I", "s": "g", "name": "GlobalEvent1",
+ "args": {}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 837, "ph": "I",
+ "name": "ThreadLevelI2", "args": {}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 839, "ph": "C",
+ "name": "counter", "args": {"value": 5}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 840, "ph": "B",
+ "name": "A not as long a name", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "E",
+ "name": "A not as long a name", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "C",
+ "name": "counter", "args": {"value": 1}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "C",
+ "name": "counter", "args": {"value": 10}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 850, "ph": "B",
+ "name": "B", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "E",
+ "name": "B", "args": {}},
+
+ {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 827, "ph": "B",
+ "name": "A", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 835, "ph": "I",
+ "name": "ThreadLevelImmediate Three", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 845, "ph": "I",
+ "name": "ThreadLevelImmediate4", "args": {}},
+ {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 854, "ph": "E",
+ "name": "A", "args": {}},
+
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B",
+ "name": "B/E over X", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 10, "ts": 860, "ph": "X",
+ "name": "X", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B",
+ "name": "B/E under X", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E",
+ "name": "B/E under X", "args": {}},
+ {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E",
+ "name": "B/E over X", "args": {}},
+
+ {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 870, "ph": "P",
+ "name": "SampleA", "args": {}},
+ {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 875, "ph": "P",
+ "name": "SampleB", "args": {}},
+ {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 878, "ph": "P",
+ "name": "SampleC", "args": {}, "sf": 8},
+
+ {"cat": "__metadata", "pid": 22630, "tid": 22630, "ts": 0, "ph": "M",
+ "name": "thread_name", "args": {"name": "threadA"}},
+ {"cat": "__metadata", "pid": 22630, "tid": 22631, "ts": 0, "ph": "M",
+ "name": "thread_name", "args": {"name": "threadB"}},
+ {"cat": "__metadata", "pid": 22630, "tid": 22632, "ts": 0, "ph": "M",
+ "name": "thread_name", "args": {"name": "threadC"}}
+ ],
+ "stackFrames": {
+ "1": {
+ "category": "m1",
+ "name": "main"
+ },
+ "7": {
+ "category": "m2",
+ "name": "frame7",
+ "parent": "1"
+ },
+ "8": {
+ "category": "m2",
+ "name": "frame8",
+ "parent": "1"
+ }
+ }
+}
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/demo.html b/tensorflow/tensorboard/components/tf_trace_viewer/demo.html
new file mode 100644
index 0000000000..dd0029e967
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_trace_viewer/demo.html
@@ -0,0 +1,30 @@
+<!doctype html>
+<!--
+@license
+Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="tf-trace-viewer.html">
+<title>Trace Viewer Demo</title>
+<style>
+ #container{
+ height: 800px;
+ border: 2px solid grey;
+ }
+</style>
+<div id="container">
+ <tf-trace-viewer trace-data-url="data/plugin/profile/trace.json">
+ </tf-trace-viewer>
+</div>
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html b/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html
new file mode 100644
index 0000000000..a7b0b2cd73
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html
@@ -0,0 +1,127 @@
+<!--
+@license
+Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="trace_viewer_full.html">
+
+<!--
+tf-trace-viewer is the frontend entry point for Trace Viewer on TensorBoard.
+
+The server serves the trace viewer app at a separate endpoint. TensorBoard
+dashboard would integrate trace viewer app using iframe.
+-->
+<script>
+ "use strict";
+
+ Polymer({
+ is: "tf-trace-viewer",
+ properties: {
+ // The URL of trace data. Provided by caller via URL parameter.
+ traceDataUrl: {
+ type: String,
+ value: null,
+ },
+ _traceData: {
+ type: Object,
+ observer: "_traceDataChanged"
+ },
+ _traceViewer: Object,
+ _traceContainer: Object,
+ _traceModel: Object,
+ },
+ ready: function() {
+ // Initiate the trace viewer app.
+ this._traceContainer = document.createElement("track-view-container");
+ this._traceContainer.id = "track_view_container";
+
+ this._traceViewer = document.createElement("tr-ui-timeline-view");
+ this._traceViewer.track_view_container = this._traceContainer;
+ this._traceViewer.appendChild(this._traceContainer);
+
+ this._traceViewer.id = 'trace-viewer';
+ this._traceViewer.globalMode = true;
+
+ Polymer.dom(this.root).appendChild(this._traceViewer);
+
+ // Retrieve the URL of trace data.
+ var queryString = window.location.href.split("?")[1];
+ if (queryString) {
+ var parts = queryString.split('&')
+ for (var i=0; i<parts.length; i++) {
+ var components = parts[i].split('=');
+ if (components[0] == "trace_data_url") {
+ this.traceDataUrl = decodeURIComponent(components[1]);
+ break;
+ }
+ }
+ }
+
+ this._loadTrace();
+ },
+ _loadTrace : function() {
+ if (!this.traceDataUrl) {
+ this._displayOverlay("Trace data URL is not provided.", "Trace Viewer");
+ return null;
+ }
+ // Send HTTP request to get the trace data.
+ var req = new XMLHttpRequest();
+ var is_binary = / [.] gz$ /.test(this.traceDataUrl) ||
+ / [.] zip$ /.test(this.traceDataUrl);
+ req.overrideMimeType('text/plain; charset=x-user-defined');
+ req.open('GET', this.traceDataUrl, true);
+ if (is_binary) {
+ req.responseType = 'arraybuffer';
+ }
+
+ req.onreadystatechange = function(event) {
+ if (req.readyState !== 4) {
+ return;
+ }
+ window.setTimeout(function() {
+ if (req.status === 200) {
+ this.set("_traceData", is_binary ? req.response : req.responseText);
+ } else {
+ this._displayOverlay(req.status, "Failed to fetch data");
+ }
+ }.bind(this), 0);
+ }.bind(this);
+ req.send(null);
+ },
+ _traceDataChanged: function(data) {
+ if (!data) {
+ this._displayOverlay("Trace Viewer", "No trace to display...");
+ return;
+ }
+ // Feed the trace data into the trace viewer app.
+ this._traceModel = new tr.Model();
+ var i = new tr.importer.Import(this._traceModel);
+ var p = i.importTracesWithProgressDialog([data]);
+ p.then(() => {
+ this._traceViewer.model = this._traceModel;
+ this._traceViewer.viewTitle = "Trace View";
+ }).catch((err) => {
+ this._displayOverlay(
+ 'Import error', tr.b.normalizeException(err).message);
+ });
+ },
+ _displayOverlay: function(title, content) {
+ var overlay = new tr.ui.b.Overlay();
+ overlay.textContent = content;
+ overlay.title = title;
+ overlay.visible = true;
+ },
+ });
+</script>
diff --git a/tensorflow/tensorboard/components/trace_viewer.html b/tensorflow/tensorboard/components/trace_viewer.html
new file mode 100644
index 0000000000..c9bcdc9e20
--- /dev/null
+++ b/tensorflow/tensorboard/components/trace_viewer.html
@@ -0,0 +1,28 @@
+<!doctype html>
+<!--
+@license
+Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<meta charset="utf-8">
+<title>Trace Viewer</title>
+<html>
+<head>
+ <link rel="import" href="tf-trace-viewer/tf-trace-viewer.html" jscomp-nocompile="true">
+ <body>
+ <tf-trace-viewer></tf-trace-viewer>
+ </body>
+</head>
+</html>
diff --git a/tensorflow/tensorboard/components/vz_projector/bundle.html b/tensorflow/tensorboard/components/vz_projector/bundle.html
index de87763673..f5a25230a0 100644
--- a/tensorflow/tensorboard/components/vz_projector/bundle.html
+++ b/tensorflow/tensorboard/components/vz_projector/bundle.html
@@ -46,11 +46,3 @@ limitations under the License.
<script src="scatterPlotVisualizer.js"></script>
<script src="projectorScatterPlotAdapter.js"></script>
<script src="vz-projector-util.js"></script>
-<script src="vz-projector-bookmark-panel.js"></script>
-<script src="vz-projector-data-panel.js"></script>
-<script src="vz-projector-input.js"></script>
-<script src="vz-projector-inspector-panel.js"></script>
-<script src="vz-projector-legend.js"></script>
-<script src="vz-projector-metadata-card.js"></script>
-<script src="vz-projector-projections-panel.js"></script>
-<script src="vz-projector.js"></script>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html
index 34aca77dde..e19f0364c4 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html
@@ -83,8 +83,9 @@ vz-projector {
event-logging="[[eventLogging]]">
</vz-projector>
</div>
+</template>
<!-- Google analytics -->
-<script>
+<script jscomp-nocompile>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
@@ -92,7 +93,6 @@ vz-projector {
ga('create', 'UA-46457317-5', 'auto');
</script>
-</template>
<script>
Polymer({
is: 'vz-projector-app',
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html
index c37d8d9571..f3f3f59a94 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html
@@ -21,6 +21,7 @@ limitations under the License.
<link rel="import" href="../paper-icon-button/paper-icon-button.html">
<link rel="import" href="../paper-tooltip/paper-tooltip.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-bookmark-panel">
<template>
@@ -202,4 +203,5 @@ paper-textarea {
</div>
</template>
+<script src="vz-projector-bookmark-panel.js"></script>
</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html
index 607d446789..d8dfd6e978 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html
@@ -27,8 +27,10 @@ limitations under the License.
<link rel="import" href="../paper-dialog/paper-dialog.html">
<link rel="import" href="../paper-dialog-scrollable/paper-dialog-scrollable.html">
<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="../tf-imports/d3.html">
<link rel="import" href="vz-projector-legend.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-data-panel">
<template>
@@ -396,4 +398,5 @@ paper-dropdown-menu paper-item {
</div>
<!-- Closing global template -->
</template>
+<script src="vz-projector-data-panel.js"></script>
</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html
index e77694426e..0d7bf7cdda 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html
@@ -20,6 +20,7 @@ limitations under the License.
<link rel="import" href="../paper-button/paper-button.html">
<link rel="import" href="../paper-tooltip/paper-tooltip.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-input">
<template>
@@ -61,4 +62,5 @@ limitations under the License.
<!-- Closing global template -->
</template>
-</dom-module> \ No newline at end of file
+<script src="vz-projector-input.js"></script>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html
index cb3e8c6479..1b81094776 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html
@@ -17,9 +17,9 @@ limitations under the License.
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../paper-slider/paper-slider.html">
-
-<link rel="import" href="styles.html">
<link rel="import" href="vz-projector-input.html">
+<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-inspector-panel">
<style include="vz-projector-styles"></style>
@@ -237,4 +237,5 @@ limitations under the License.
</div>
<!-- Closing global template -->
</template>
+<script src="vz-projector-inspector-panel.js"></script>
</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html
index 3fc5f4db15..4b98d8bded 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html
@@ -17,6 +17,7 @@ limitations under the License.
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id='vz-projector-legend'>
<template>
@@ -73,4 +74,5 @@ limitations under the License.
</template>
<!-- Closing global template -->
</template>
-</dom-module> \ No newline at end of file
+<script src="vz-projector-legend.js"></script>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
index ebdcd72c77..4231a61ff3 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
@@ -18,6 +18,7 @@ limitations under the License.
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../iron-collapse/iron-collapse.html">
<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-metadata-card">
<template>
@@ -94,4 +95,5 @@ limitations under the License.
</div>
</template>
</template>
+<script src="vz-projector-metadata-card.js"></script>
</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html
index cddcb2b7d0..b82f3f520b 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html
@@ -30,6 +30,7 @@ limitations under the License.
<link rel="import" href="../paper-button/paper-button.html">
<link rel="import" href="../paper-slider/paper-slider.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector-projections-panel">
<template>
@@ -311,4 +312,5 @@ limitations under the License.
</div>
</div>
</template>
+<script src="vz-projector-projections-panel.js"></script>
</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html
index d4be2f26a5..438ea9f4e9 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.html
@@ -32,6 +32,7 @@ limitations under the License.
<link rel="import" href="../paper-styles/typography.html">
<link rel="import" href="../paper-spinner/paper-spinner-lite.html">
<link rel="import" href="../paper-dialog-scrollable/paper-dialog-scrollable.html">
+<link rel="import" href="../tf-imports/threejs.html">
<link rel="import" href="vz-projector-bookmark-panel.html">
<link rel="import" href="vz-projector-data-panel.html">
@@ -40,6 +41,7 @@ limitations under the License.
<link rel="import" href="vz-projector-metadata-card.html">
<link rel="import" href="vz-projector-projections-panel.html">
<link rel="import" href="styles.html">
+<link rel="import" href="bundle.html">
<dom-module id="vz-projector">
<template>
@@ -340,4 +342,5 @@ limitations under the License.
<paper-toast id="toast" always-on-top></paper-toast>
</template> <!-- global template -->
+<script src="vz-projector.js"></script>
</dom-module>
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b14f11cd7c..888390764a 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -496,11 +496,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
temp_workaround_http_archive(
name = "llvm",
urls = [
- "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b.tar.gz",
+ "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz",
],
- sha256 = "42c57d798a037d9dea692ce1da8ff4d24966ab5a40494015b374341e43411a37",
- strip_prefix = "llvm-c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b",
+ sha256 = "72e34e2411a06d4200a2688ee83832805fbef23a12ea481f31c2b8866fde007a",
+ strip_prefix = "llvm-e156d99231a7735d06a97b5b83de70bf4ce4f034",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
repository = tf_repo_name,
)
@@ -2704,3 +2704,18 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
path = "/test-fixture",
exclude = ["test/**"],
)
+
+ filegroup_external(
+ name = "org_chromium_catapult_vulcanized_trace_viewer",
+ licenses = ["notice"], # BSD-3-Clause
+ sha256_urls = {
+ "f0df289ba9d03d857ad1c2f5918861376b1510b71588ffc60eff5c7a7bfedb09": [
+ "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE",
+ "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE",
+ ],
+ "9e99e79439ea5a1471bd4dd325bd6733e133bcb3da4df4b878ed6d2aec7c8d86": [
+ "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html",
+ "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html"
+ ],
+ },
+ )
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 2b52a991c4..32266997a7 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -24,23 +24,20 @@ llvm_host_triple = "x86_64-unknown-linux_gnu"
llvm_targets = [
"AArch64",
+ # Uncomment to enable the AMDGPU backend.
+ # TODO(phawkins): use a configure-time test.
+ # "AMDGPU",
"ARM",
"NVPTX",
"PowerPC",
"X86",
]
-llvm_target_asm_parsers = [
- "AArch64",
- "ARM",
- "NVPTX",
- "PowerPC",
- "X86",
-]
+llvm_target_asm_parsers = llvm_targets
-llvm_target_asm_printers = llvm_target_asm_parsers
+llvm_target_asm_printers = llvm_targets
-llvm_target_disassemblers = llvm_target_asm_parsers
+llvm_target_disassemblers = llvm_targets
# TODO(phawkins): the set of CMake variables was hardcoded for expediency.
# However, we should really detect many of these via configure-time tests.
@@ -353,6 +350,26 @@ llvm_target_list = [
],
},
{
+ "name": "AMDGPU",
+ "lower_name": "amdgpu",
+ "short_name": "AMDGPU",
+ "tbl_outs": [
+ ("-gen-register-bank", "lib/Target/AMDGPU/AMDGPUGenRegisterBank.inc"),
+ ("-gen-register-info", "lib/Target/AMDGPU/AMDGPUGenRegisterInfo.inc"),
+ ("-gen-instr-info", "lib/Target/AMDGPU/AMDGPUGenInstrInfo.inc"),
+ ("-gen-dag-isel", "lib/Target/AMDGPU/AMDGPUGenDAGISel.inc"),
+ ("-gen-callingconv", "lib/Target/AMDGPU/AMDGPUGenCallingConv.inc"),
+ ("-gen-subtarget", "lib/Target/AMDGPU/AMDGPUGenSubtargetInfo.inc"),
+ ("-gen-tgt-intrinsic", "lib/Target/AMDGPU/AMDGPUGenIntrinsics.inc"),
+ ("-gen-emitter", "lib/Target/AMDGPU/AMDGPUGenMCCodeEmitter.inc"),
+ ("-gen-dfa-packetizer", "lib/Target/AMDGPU/AMDGPUGenDFAPacketizer.inc"),
+ ("-gen-asm-writer", "lib/Target/AMDGPU/AMDGPUGenAsmWriter.inc"),
+ ("-gen-asm-matcher", "lib/Target/AMDGPU/AMDGPUGenAsmMatcher.inc"),
+ ("-gen-disassembler", "lib/Target/AMDGPU/AMDGPUGenDisassemblerTables.inc"),
+ ("-gen-pseudo-lowering", "lib/Target/AMDGPU/AMDGPUGenMCPseudoLowering.inc"),
+ ],
+ },
+ {
"name": "ARM",
"lower_name": "arm",
"short_name": "ARM",
@@ -436,7 +453,6 @@ llvm_target_list = [
"include/llvm/IR/Intrinsics*.td",
"include/llvm/TableGen/*.td",
"include/llvm/Target/*.td",
- "include/llvm/Target/GlobalISel/*.td",
]),
)
for target in llvm_target_list
@@ -648,6 +664,7 @@ cc_library(
"include/llvm/Analysis/*.inc",
]),
deps = [
+ ":binary_format",
":config",
":core",
":object",
@@ -657,6 +674,184 @@ cc_library(
)
cc_library(
+ name = "amdgpu_desc",
+ srcs = glob([
+ "lib/Target/AMDGPU/MCTargetDesc/*.c",
+ "lib/Target/AMDGPU/MCTargetDesc/*.cpp",
+ "lib/Target/AMDGPU/MCTargetDesc/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/MCTargetDesc/*.h",
+ "include/llvm/Target/AMDGPU/MCTargetDesc/*.def",
+ "include/llvm/Target/AMDGPU/MCTargetDesc/*.inc",
+ "lib/Target/AMDGPU/MCTargetDesc/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_asm_printer",
+ ":amdgpu_info",
+ ":amdgpu_utils",
+ ":config",
+ ":core",
+ ":mc",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_disassembler",
+ srcs = glob([
+ "lib/Target/AMDGPU/Disassembler/*.c",
+ "lib/Target/AMDGPU/Disassembler/*.cpp",
+ "lib/Target/AMDGPU/Disassembler/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/Disassembler/*.h",
+ "include/llvm/Target/AMDGPU/Disassembler/*.def",
+ "include/llvm/Target/AMDGPU/Disassembler/*.inc",
+ "lib/Target/AMDGPU/Disassembler/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_desc",
+ ":amdgpu_info",
+ ":amdgpu_utils",
+ ":config",
+ ":mc",
+ ":mc_disassembler",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_info",
+ srcs = glob([
+ "lib/Target/AMDGPU/TargetInfo/*.c",
+ "lib/Target/AMDGPU/TargetInfo/*.cpp",
+ "lib/Target/AMDGPU/TargetInfo/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/TargetInfo/*.h",
+ "include/llvm/Target/AMDGPU/TargetInfo/*.def",
+ "include/llvm/Target/AMDGPU/TargetInfo/*.inc",
+ "lib/Target/AMDGPU/TargetInfo/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_target_gen",
+ ":config",
+ ":core",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_utils",
+ srcs = glob([
+ "lib/Target/AMDGPU/Utils/*.c",
+ "lib/Target/AMDGPU/Utils/*.cpp",
+ "lib/Target/AMDGPU/Utils/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/Utils/*.h",
+ "include/llvm/Target/AMDGPU/Utils/*.def",
+ "include/llvm/Target/AMDGPU/Utils/*.inc",
+ "lib/Target/AMDGPU/Utils/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_target_gen",
+ ":config",
+ ":core",
+ ":mc",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_asm_parser",
+ srcs = glob([
+ "lib/Target/AMDGPU/AsmParser/*.c",
+ "lib/Target/AMDGPU/AsmParser/*.cpp",
+ "lib/Target/AMDGPU/AsmParser/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/AsmParser/*.h",
+ "include/llvm/Target/AMDGPU/AsmParser/*.def",
+ "include/llvm/Target/AMDGPU/AsmParser/*.inc",
+ "lib/Target/AMDGPU/AsmParser/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_desc",
+ ":amdgpu_info",
+ ":amdgpu_utils",
+ ":config",
+ ":mc",
+ ":mc_parser",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_asm_printer",
+ srcs = glob([
+ "lib/Target/AMDGPU/InstPrinter/*.c",
+ "lib/Target/AMDGPU/InstPrinter/*.cpp",
+ "lib/Target/AMDGPU/InstPrinter/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/InstPrinter/*.h",
+ "include/llvm/Target/AMDGPU/InstPrinter/*.def",
+ "include/llvm/Target/AMDGPU/InstPrinter/*.inc",
+ "lib/Target/AMDGPU/InstPrinter/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_utils",
+ ":config",
+ ":mc",
+ ":support",
+ ],
+)
+
+cc_library(
+ name = "amdgpu_code_gen",
+ srcs = glob([
+ "lib/Target/AMDGPU/*.c",
+ "lib/Target/AMDGPU/*.cpp",
+ "lib/Target/AMDGPU/*.inc",
+ ]),
+ hdrs = glob([
+ "include/llvm/Target/AMDGPU/*.h",
+ "include/llvm/Target/AMDGPU/*.def",
+ "include/llvm/Target/AMDGPU/*.inc",
+ "lib/Target/AMDGPU/*.h",
+ ]),
+ copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ deps = [
+ ":amdgpu_asm_printer",
+ ":amdgpu_desc",
+ ":amdgpu_info",
+ ":amdgpu_utils",
+ ":analysis",
+ ":asm_printer",
+ ":code_gen",
+ ":config",
+ ":core",
+ ":global_i_sel",
+ ":ipo",
+ ":mc",
+ ":scalar",
+ ":selection_dag",
+ ":support",
+ ":target",
+ ":transform_utils",
+ ":vectorize",
+ ],
+)
+
+cc_library(
name = "arm_asm_parser",
srcs = glob([
"lib/Target/ARM/AsmParser/*.c",
@@ -824,6 +1019,7 @@ cc_library(
"include/llvm/AsmParser/*.inc",
]),
deps = [
+ ":binary_format",
":config",
":core",
":support",
@@ -842,9 +1038,11 @@ cc_library(
"include/llvm/CodeGen/AsmPrinter/*.h",
"include/llvm/CodeGen/AsmPrinter/*.def",
"include/llvm/CodeGen/AsmPrinter/*.inc",
+ "lib/CodeGen/AsmPrinter/*.def",
]),
deps = [
":analysis",
+ ":binary_format",
":code_gen",
":config",
":core",
@@ -858,6 +1056,25 @@ cc_library(
)
cc_library(
+ name = "binary_format",
+ srcs = glob([
+ "lib/BinaryFormat/*.c",
+ "lib/BinaryFormat/*.cpp",
+ "lib/BinaryFormat/*.inc",
+ "lib/BinaryFormat/*.h",
+ ]),
+ hdrs = glob([
+ "include/llvm/BinaryFormat/*.h",
+ "include/llvm/BinaryFormat/*.def",
+ "include/llvm/BinaryFormat/*.inc",
+ ]),
+ deps = [
+ ":config",
+ ":support",
+ ],
+)
+
+cc_library(
name = "bit_reader",
srcs = glob([
"lib/Bitcode/Reader/*.c",
@@ -956,6 +1173,7 @@ cc_library(
deps = [
":attributes_compat_gen",
":attributes_gen",
+ ":binary_format",
":config",
":intrinsics_gen",
":support",
@@ -1376,6 +1594,7 @@ cc_library(
"include/llvm/Object/*.inc",
]),
deps = [
+ ":binary_format",
":bit_reader",
":config",
":core",