From fdb8e29354ce93afa8c2335a6287a59eb37d42fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 10:59:18 -0700 Subject: Update iOS examples to use CocoaPods, and moved to tensorflow/examples/ios PiperOrigin-RevId: 158289285 --- tensorflow/contrib/ios_examples/.gitignore | 4 - tensorflow/contrib/ios_examples/README.md | 140 ----- .../contrib/ios_examples/benchmark/AppDelegate.h | 21 - .../contrib/ios_examples/benchmark/AppDelegate.mm | 44 -- .../ios_examples/benchmark/Benchmark-Info.plist | 47 -- .../benchmark/BenchmarkViewController.h | 24 - .../benchmark/BenchmarkViewController.mm | 302 ---------- .../benchmark/BenchmarkViewController.xib | 47 -- .../benchmark/benchmark.xcodeproj/project.pbxproj | 367 ------------ .../ios_examples/benchmark/data/grace_hopper.jpg | Bin 73746 -> 0 bytes .../ios_examples/benchmark/ios_image_load.h | 27 - .../ios_examples/benchmark/ios_image_load.mm | 87 --- tensorflow/contrib/ios_examples/benchmark/main.mm | 22 - .../ios_examples/camera/CameraExampleAppDelegate.h | 21 - .../ios_examples/camera/CameraExampleAppDelegate.m | 44 -- .../camera/CameraExampleViewController.h | 46 -- .../camera/CameraExampleViewController.mm | 596 -------------------- tensorflow/contrib/ios_examples/camera/Info.plist | 44 -- .../camera_example.xcodeproj/project.pbxproj | 431 -------------- .../en.lproj/MainStoryboard_iPhone.storyboard | 46 -- .../contrib/ios_examples/camera/ios_image_load.h | 27 - .../contrib/ios_examples/camera/ios_image_load.mm | 87 --- tensorflow/contrib/ios_examples/camera/main.mm | 27 - .../contrib/ios_examples/camera/tensorflow_utils.h | 52 -- .../ios_examples/camera/tensorflow_utils.mm | 231 -------- .../contrib/ios_examples/simple/AppDelegate.h | 21 - .../contrib/ios_examples/simple/AppDelegate.mm | 44 -- .../ios_examples/simple/RunModel-Info.plist | 47 -- .../ios_examples/simple/RunModelViewController.h | 24 - .../ios_examples/simple/RunModelViewController.mm | 263 --------- .../ios_examples/simple/RunModelViewController.xib | 46 -- .../ios_examples/simple/data/grace_hopper.jpg | Bin 73746 -> 0 bytes .../contrib/ios_examples/simple/ios_image_load.h | 27 - .../contrib/ios_examples/simple/ios_image_load.mm | 87 --- tensorflow/contrib/ios_examples/simple/main.mm | 22 - .../project.pbxproj | 377 ------------- tensorflow/examples/ios/.gitignore | 4 + tensorflow/examples/ios/README.md | 194 +++++++ tensorflow/examples/ios/benchmark/AppDelegate.h | 21 + tensorflow/examples/ios/benchmark/AppDelegate.mm | 44 ++ .../examples/ios/benchmark/Benchmark-Info.plist | 47 ++ .../ios/benchmark/BenchmarkViewController.h | 24 + .../ios/benchmark/BenchmarkViewController.mm | 302 ++++++++++ .../ios/benchmark/BenchmarkViewController.xib | 47 ++ tensorflow/examples/ios/benchmark/Podfile | 5 + .../examples/ios/benchmark/data/grace_hopper.jpg | Bin 0 -> 73746 bytes tensorflow/examples/ios/benchmark/ios_image_load.h | 27 + .../examples/ios/benchmark/ios_image_load.mm | 87 +++ tensorflow/examples/ios/benchmark/main.mm | 22 + .../tf_benchmark_example.xcodeproj/project.pbxproj | 388 +++++++++++++ .../examples/ios/camera/CameraExampleAppDelegate.h | 21 + .../examples/ios/camera/CameraExampleAppDelegate.m | 44 ++ .../ios/camera/CameraExampleViewController.h | 47 ++ .../ios/camera/CameraExampleViewController.mm | 621 +++++++++++++++++++++ tensorflow/examples/ios/camera/Info.plist | 44 ++ .../ios/camera/MainStoryboard_iPhone.storyboard | 46 ++ tensorflow/examples/ios/camera/Podfile | 5 + .../examples/ios/camera/data/grace_hopper.jpg | Bin 0 -> 73746 bytes tensorflow/examples/ios/camera/ios_image_load.h | 27 + tensorflow/examples/ios/camera/ios_image_load.mm | 87 +++ tensorflow/examples/ios/camera/main.mm | 27 + tensorflow/examples/ios/camera/tensorflow_utils.h | 52 ++ tensorflow/examples/ios/camera/tensorflow_utils.mm | 219 ++++++++ .../tf_camera_example.xcodeproj/project.pbxproj | 412 ++++++++++++++ tensorflow/examples/ios/simple/AppDelegate.h | 21 + tensorflow/examples/ios/simple/AppDelegate.mm | 44 ++ tensorflow/examples/ios/simple/Podfile | 5 + tensorflow/examples/ios/simple/RunModel-Info.plist | 47 ++ .../examples/ios/simple/RunModelViewController.h | 24 + .../examples/ios/simple/RunModelViewController.mm | 253 +++++++++ .../examples/ios/simple/RunModelViewController.xib | 46 ++ .../examples/ios/simple/data/grace_hopper.jpg | Bin 0 -> 73746 bytes tensorflow/examples/ios/simple/ios_image_load.h | 27 + tensorflow/examples/ios/simple/ios_image_load.mm | 87 +++ tensorflow/examples/ios/simple/main.mm | 22 + .../tf_simple_example.xcodeproj/project.pbxproj | 404 ++++++++++++++ 76 files changed, 3844 insertions(+), 3742 deletions(-) delete mode 100644 tensorflow/contrib/ios_examples/.gitignore delete mode 100644 tensorflow/contrib/ios_examples/README.md delete mode 100644 tensorflow/contrib/ios_examples/benchmark/AppDelegate.h delete mode 100644 tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm delete mode 100644 tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist delete mode 100644 tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h delete mode 100644 tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm delete mode 100644 tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib delete mode 100644 tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj delete mode 100644 tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg delete mode 100644 tensorflow/contrib/ios_examples/benchmark/ios_image_load.h delete mode 100644 tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm delete mode 100644 tensorflow/contrib/ios_examples/benchmark/main.mm delete mode 100644 tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h delete mode 100644 tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m delete mode 100644 tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h delete mode 100644 tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm delete mode 100644 tensorflow/contrib/ios_examples/camera/Info.plist delete mode 100644 tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj delete mode 100644 tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard delete mode 100644 tensorflow/contrib/ios_examples/camera/ios_image_load.h delete mode 100644 tensorflow/contrib/ios_examples/camera/ios_image_load.mm delete mode 100644 tensorflow/contrib/ios_examples/camera/main.mm delete mode 100644 tensorflow/contrib/ios_examples/camera/tensorflow_utils.h delete mode 100644 tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm delete mode 100644 tensorflow/contrib/ios_examples/simple/AppDelegate.h delete mode 100644 tensorflow/contrib/ios_examples/simple/AppDelegate.mm delete mode 100644 tensorflow/contrib/ios_examples/simple/RunModel-Info.plist delete mode 100644 tensorflow/contrib/ios_examples/simple/RunModelViewController.h delete mode 100644 tensorflow/contrib/ios_examples/simple/RunModelViewController.mm delete mode 100644 tensorflow/contrib/ios_examples/simple/RunModelViewController.xib delete mode 100644 tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg delete mode 100644 tensorflow/contrib/ios_examples/simple/ios_image_load.h delete mode 100644 tensorflow/contrib/ios_examples/simple/ios_image_load.mm delete mode 100644 tensorflow/contrib/ios_examples/simple/main.mm delete mode 100644 tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj create mode 100644 tensorflow/examples/ios/.gitignore create mode 100644 tensorflow/examples/ios/README.md create mode 100644 tensorflow/examples/ios/benchmark/AppDelegate.h create mode 100644 tensorflow/examples/ios/benchmark/AppDelegate.mm create mode 100644 tensorflow/examples/ios/benchmark/Benchmark-Info.plist create mode 100644 tensorflow/examples/ios/benchmark/BenchmarkViewController.h create mode 100644 tensorflow/examples/ios/benchmark/BenchmarkViewController.mm create mode 100644 tensorflow/examples/ios/benchmark/BenchmarkViewController.xib create mode 100644 tensorflow/examples/ios/benchmark/Podfile create mode 100644 tensorflow/examples/ios/benchmark/data/grace_hopper.jpg create mode 100644 tensorflow/examples/ios/benchmark/ios_image_load.h create mode 100644 tensorflow/examples/ios/benchmark/ios_image_load.mm create mode 100644 tensorflow/examples/ios/benchmark/main.mm create mode 100644 tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj create mode 100644 tensorflow/examples/ios/camera/CameraExampleAppDelegate.h create mode 100644 tensorflow/examples/ios/camera/CameraExampleAppDelegate.m create mode 100644 tensorflow/examples/ios/camera/CameraExampleViewController.h create mode 100644 tensorflow/examples/ios/camera/CameraExampleViewController.mm create mode 100644 tensorflow/examples/ios/camera/Info.plist create mode 100644 tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard create mode 100644 tensorflow/examples/ios/camera/Podfile create mode 100644 tensorflow/examples/ios/camera/data/grace_hopper.jpg create mode 100644 tensorflow/examples/ios/camera/ios_image_load.h create mode 100644 tensorflow/examples/ios/camera/ios_image_load.mm create mode 100644 tensorflow/examples/ios/camera/main.mm create mode 100644 tensorflow/examples/ios/camera/tensorflow_utils.h create mode 100644 tensorflow/examples/ios/camera/tensorflow_utils.mm create mode 100644 tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj create mode 100644 tensorflow/examples/ios/simple/AppDelegate.h create mode 100644 tensorflow/examples/ios/simple/AppDelegate.mm create mode 100644 tensorflow/examples/ios/simple/Podfile create mode 100644 tensorflow/examples/ios/simple/RunModel-Info.plist create mode 100644 tensorflow/examples/ios/simple/RunModelViewController.h create mode 100644 tensorflow/examples/ios/simple/RunModelViewController.mm create mode 100644 tensorflow/examples/ios/simple/RunModelViewController.xib create mode 100644 tensorflow/examples/ios/simple/data/grace_hopper.jpg create mode 100644 tensorflow/examples/ios/simple/ios_image_load.h create mode 100644 tensorflow/examples/ios/simple/ios_image_load.mm create mode 100644 tensorflow/examples/ios/simple/main.mm create mode 100644 tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj diff --git a/tensorflow/contrib/ios_examples/.gitignore b/tensorflow/contrib/ios_examples/.gitignore deleted file mode 100644 index e572b3012c..0000000000 --- a/tensorflow/contrib/ios_examples/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -project.xcworkspace -xcuserdata -imagenet_comp_graph_label_strings.txt -tensorflow_inception_graph.pb diff --git a/tensorflow/contrib/ios_examples/README.md b/tensorflow/contrib/ios_examples/README.md deleted file mode 100644 index 6bac33c0ec..0000000000 --- a/tensorflow/contrib/ios_examples/README.md +++ /dev/null @@ -1,140 +0,0 @@ -# TensorFlow iOS Examples - -This folder contains examples of how to build applications for iOS devices using TensorFlow. - -## Building the Examples - - - You'll need Xcode 7.3 or later, with the command-line tools installed. - - - Follow the instructions at - [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile) - under "iOS" to compile a static library containing the core TensorFlow code. - - - From the root of the Tensorflow folder, download - [Inception v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip), - and extract the label and graph files into the data folders inside both the - simple and camera examples: - -```bash -mkdir -p ~/graphs -curl -o ~/graphs/inception5h.zip \ - https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ - && unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/benchmark/data/ -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/camera/data/ -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/ -``` - - - Load the Xcode project inside the `simple` subfolder, and press Command-R to - build and run it on the simulator or your connected device. - - - You should see a single-screen app with a "Run Model" button. Tap that, and - you should see some debug output appear below indicating that the example - Grace Hopper image has been analyzed, with a military uniform recognized. - - - Once you have success there, make sure you have a real device connected and - open up the Xcode project in the `camera` subfolder. Once you build and run - that, you should get a live camera view that you can point at objects to get - real-time recognition results. - -## Troubleshooting - -If you're hitting problems, here's a checklist of common things to investigate: - - - Make sure that you've run the `build_all_ios.sh` script. - This will run `download_dependencies.sh`,`compile_ios_protobuf.sh` and `compile_ios_tensorflow.sh`. - (check each one if they have run successful.) - - - Check that you have version 7.3 of Xcode. - - - If there's a complaint about no Sessions registered, that means that the C++ - global constructors that TensorFlow relies on for registration haven't been - linked in properly. You'll have to make sure your project uses force_load, as - described below. - -## Creating your Own App - -You'll need to update various settings in your app to link against -TensorFlow. You can view them in the example projects, but here's a full -rundown: - - - The `compile_ios_tensorflow.sh` script builds a universal static library in - `tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a`. You'll need to add - this to your linking build stage, and in Search Paths add - `tensorflow/contrib/makefile/gen/lib` to the Library Search Paths setting. - - - You'll also need to add `libprotobuf.a` and `libprotobuf-lite.a` from - `tensorflow/contrib/makefile/gen/protobuf_ios/lib` to your _Build Stages_ and - _Library Search Paths_. - - - The _Header Search_ paths needs to contain: - - the root folder of tensorflow, - - `tensorflow/contrib/makefile/downloads/protobuf/src` - - `tensorflow/contrib/makefile/downloads`, - - `tensorflow/contrib/makefile/downloads/eigen`, and - - `tensorflow/contrib/makefile/gen/proto`. - - - In the Linking section, you need to add `-force_load` followed by the path to - the TensorFlow static library in the _Other Linker_ Flags section. This ensures - that the global C++ objects that are used to register important classes - inside the library are not stripped out. To the linker, they can appear - unused because no other code references the variables, but in fact their - constructors have the important side effect of registering the class. - - - You'll need to include the Accelerate framework in the "Link Binary with - Libraries" build phase of your project. - - - C++11 support (or later) should be enabled by setting `C++ Language Dialect` to - `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`. - - - The library doesn't currently support bitcode, so you'll need to disable that - in your project settings. - - - Remove any use of the `-all_load` flag in your project. The protocol buffers - libraries (full and lite versions) contain duplicate symbols, and the `-all_load` - flag will cause these duplicates to become link errors. If you were using - `-all_load` to avoid issues with Objective-C categories in static libraries, - you may be able to replace it with the `-ObjC` flag. - -## Reducing the binary size - -TensorFlow is a comparatively large library for a mobile device, so it will -increase the size of your app. Currently on iOS we see around a 11 MB binary -footprint per CPU architecture, though we're actively working on reducing that. -It can be tricky to set up the right configuration in your own app to keep the -size minimized, so if you do run into this issue we recommend you start by -looking at the simple example to examine its size. Here's how you do that: - - - Open the Xcode project in tensorflow/contrib/ios_examples/simple. - - - Make sure you've followed the steps above to get the data files. - - - Choose "Generic iOS Device" as the build configuration. - - - Select Product->Build. - - - Once the build's complete, open the Report Navigator and select the logs. - - - Near the bottom, you'll see a line saying "Touch tf_ios_makefile_example.app". - - - Expand that line using the icon on the right, and copy the first argument to - the Touch command. - - - Go to the terminal, type `ls -lah ` and then paste the path you copied. - - - For example it might look like `ls -lah /Users/petewarden/Library/Developer/Xcode/DerivedData/tf_ios_makefile_example-etdbksqytcnzeyfgdwiihzkqpxwr/Build/Products/Debug-iphoneos/tf_ios_makefile_example.app` - - - Running this command will show the size of the executable as the - `tf_ios_makefile_example` line. - -Right now you'll see a size of around 23 MB, since it's including two -architectures (armv7 and arm64). As a first step, you should make sure the size -increase you see in your own app is similar, and if it's larger, look at the -"Other Linker Flags" used in the Simple Xcode project settings to strip the -executable. - -After that, you can manually look at modifying the list of kernels -included in tensorflow/contrib/makefile/tf_op_files.txt to reduce the number of -implementations to the ones you're actually using in your own model. We're -hoping to automate this step in the future, but for now manually removing them -is the best approach. diff --git a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.h b/tensorflow/contrib/ios_examples/benchmark/AppDelegate.h deleted file mode 100644 index 94046d9728..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface AppDelegate : UIResponder - -@property(strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm b/tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm deleted file mode 100644 index 23ffba0f7b..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/AppDelegate.mm +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "AppDelegate.h" - -#import "BenchmarkViewController.h" - -@implementation AppDelegate - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - - UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[BenchmarkViewController alloc] init]]]; - bar.selectedIndex = 0; - self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; - self.window.rootViewController = bar; - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application {} - -- (void)applicationDidEnterBackground:(UIApplication *)application {} - -- (void)applicationWillEnterForeground:(UIApplication *)application {} - -- (void)applicationDidBecomeActive:(UIApplication *)application {} - -- (void)applicationWillTerminate:(UIApplication *)application {} - -@end diff --git a/tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist b/tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist deleted file mode 100644 index 8d17162b87..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/Benchmark-Info.plist +++ /dev/null @@ -1,47 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - TF Benchmark - CFBundleExecutable - benchmark - CFBundleIdentifier - Google.Benchmark - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ios-app - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - UILaunchStoryboardName - BenchmarkViewController - UIRequiredDeviceCapabilities - - armv7 - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - UIInterfaceOrientationPortraitUpsideDown - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - - diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h b/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h deleted file mode 100644 index c9cbc49280..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface BenchmarkViewController : UIViewController - -- (IBAction)getUrl:(id)sender; - -@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property(weak, nonatomic) IBOutlet UITextField *urlTextField; - -@end diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm b/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm deleted file mode 100644 index 4421c88651..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.mm +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "BenchmarkViewController.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" -#include "google/protobuf/message_lite.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/stat_summarizer.h" - -#include "ios_image_load.h" - -NSString* RunInferenceOnImage(); - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -@interface BenchmarkViewController () -@end - -@implementation BenchmarkViewController { -} - -- (IBAction)getUrl:(id)sender { - NSString* inference_result = RunInferenceOnImage(); - self.urlContentTextView.text = inference_result; -} - -@end - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -static void GetTopN( - const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector>* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, std::vector>, - std::greater>> - top_result_pq; - - const int count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - // TODO(jiayq): the following coded stream is for debugging purposes to allow - // one to parse arbitrarily large messages for MessageLite. One most likely - // doesn't want to put protobufs larger than 64MB on Android, so we should - // eventually remove this and quit loud when a large protobuf is passed in. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = - [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; - } - return file_path; -} - -// A utility function to get the current time in seconds, for simple profiling. -double time() { - timeval t; - gettimeofday(&t, nullptr); - return t.tv_sec + 1e-6 * t.tv_usec; -} - -// Runs the session with profiling enabled, and prints out details of the time -// that each node in the graph takes to the debug log. -tensorflow::Status BenchmarkInference( - tensorflow::Session* session, - const std::vector> inputs, - const std::vector& output_layer_names, - std::vector* output_layers, - tensorflow::StatSummarizer* stat_summarizer, double* average_time) { - tensorflow::Status run_status; - const int iterations_count = 20; - double total_time = 0.0; - tensorflow::RunOptions run_options; - run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); - tensorflow::RunMetadata run_metadata; - for (int iteration = 0; iteration < (iterations_count + 1); ++iteration) { - const double start_time = time(); - run_status = session->Run(run_options, inputs, output_layer_names, {}, - output_layers, &run_metadata); - const double end_time = time(); - if (iteration != 0) { - total_time += end_time - start_time; - } - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed: " << run_status; - tensorflow::LogAllRegisteredKernels(); - return run_status; - } - } - assert(run_metadata.has_step_stats()); - const tensorflow::StepStats& step_stats = run_metadata.step_stats(); - stat_summarizer->ProcessStepStats(step_stats); - stat_summarizer->PrintStepStats(); - - *average_time = total_time / iterations_count; - NSLog(@"Took %f seconds", average_time); - - return tensorflow::Status::OK(); -} - -NSString* RunInferenceOnImage() { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - std::string status_string = session_status.ToString(); - return [NSString - stringWithFormat:@"Session create failed - %s", status_string.c_str()]; - } - std::unique_ptr session(session_pointer); - LOG(INFO) << "Session created."; - - tensorflow::GraphDef tensorflow_graph; - LOG(INFO) << "Graph created."; - - NSString* network_path = - FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); - PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); - - LOG(INFO) << "Creating session."; - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << s; - return @""; - } - - // Read the label list - NSString* labels_path = - FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); - std::vector label_strings; - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while (t) { - std::getline(t, line); - label_strings.push_back(line); - } - t.close(); - - // Read the Grace Hopper image. - NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); - int image_width; - int image_height; - int image_channels; - std::vector image_data = LoadImageFromFile( - [image_path UTF8String], &image_width, &image_height, &image_channels); - const int wanted_width = 224; - const int wanted_height = 224; - const int wanted_channels = 3; - const float input_mean = 117.0f; - const float input_std = 1.0f; - assert(image_channels >= wanted_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape( - {1, wanted_height, wanted_width, wanted_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8* in = image_data.data(); - float* out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); - float* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); - float* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - tensorflow::string input_layer = "input"; - tensorflow::string output_layer = "output"; - std::vector outputs; - tensorflow::StatSummarizer stat_summarizer(tensorflow_graph); - double average_time = 0.0; - BenchmarkInference(session.get(), {{input_layer, image_tensor}}, - {output_layer}, &outputs, &stat_summarizer, &average_time); - NSString* result = - [NSString stringWithFormat:@"Average time: %.4f seconds \n\n", average_time]; - - tensorflow::Tensor* output = &outputs[0]; - const int kNumResults = 5; - const float kThreshold = 0.1f; - std::vector> top_results; - GetTopN(output->flat(), kNumResults, kThreshold, &top_results); - - std::stringstream ss; - ss.precision(3); - for (const auto& result : top_results) { - const float confidence = result.first; - const int index = result.second; - - ss << index << " " << confidence << " "; - - // Write out the result as a string - if (index < label_strings.size()) { - // just for safety: theoretically, the output is under 1000 unless there - // is some numerical issues leading to a wrong prediction. - ss << label_strings[index]; - } else { - ss << "Prediction: " << index; - } - - ss << "\n"; - } - - LOG(INFO) << "Predictions: " << ss.str(); - - tensorflow::string predictions = ss.str(); - result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; - - return result; -} diff --git a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib b/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib deleted file mode 100644 index 56c3708062..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/BenchmarkViewController.xib +++ /dev/null @@ -1,47 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj deleted file mode 100644 index 5cd173b416..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj/project.pbxproj +++ /dev/null @@ -1,367 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; }; - 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; }; - 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */; }; - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; - 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; - 59A3D00B1CF4E68100C4259F /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */; }; - 59A3D00C1CF4E68100C4259F /* BenchmarkViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */; }; - 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */; }; - 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0171CF4E86100C4259F /* UIKit.framework */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = ""; }; - 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = ""; }; - 5911579B1CF4011C00C31E3A /* benchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = benchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; - 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = ""; }; - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; - 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_2012_challenge_label_map_proto.pbtxt; sourceTree = ""; }; - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 59A3CFF81CF4E68100C4259F /* LICENSE */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = LICENSE; sourceTree = ""; }; - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; - 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; - 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "Benchmark-Info.plist"; sourceTree = ""; }; - 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BenchmarkViewController.h; sourceTree = ""; }; - 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; - 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = BenchmarkViewController.xib; sourceTree = ""; }; - 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 59A3D0151CF4E83D00C4259F /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; }; - 59A3D0171CF4E86100C4259F /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 591157981CF4011C00C31E3A /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 5993C7701D5D4E7F0048CE6A /* Accelerate.framework in Frameworks */, - 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */, - 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */, - 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */, - 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 591157921CF4011C00C31E3A = { - isa = PBXGroup; - children = ( - 5993C76F1D5D4E7F0048CE6A /* Accelerate.framework */, - 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */, - 590E7D871D02091F00DF5523 /* libprotobuf.a */, - 59A3D0171CF4E86100C4259F /* UIKit.framework */, - 59A3D0151CF4E83D00C4259F /* Foundation.framework */, - 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */, - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, - 59A3CFF31CF4E68100C4259F /* data */, - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, - 59A3CFFC1CF4E68100C4259F /* main.mm */, - 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */, - 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */, - 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */, - 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */, - 5911579C1CF4011C00C31E3A /* Products */, - ); - sourceTree = ""; - }; - 5911579C1CF4011C00C31E3A /* Products */ = { - isa = PBXGroup; - children = ( - 5911579B1CF4011C00C31E3A /* benchmark.app */, - ); - name = Products; - sourceTree = ""; - }; - 59A3CFF31CF4E68100C4259F /* data */ = { - isa = PBXGroup; - children = ( - 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */, - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, - 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */, - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, - 59A3CFF81CF4E68100C4259F /* LICENSE */, - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 5911579A1CF4011C00C31E3A /* benchmark */ = { - isa = PBXNativeTarget; - buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "benchmark" */; - buildPhases = ( - 591157971CF4011C00C31E3A /* Sources */, - 591157981CF4011C00C31E3A /* Frameworks */, - 591157991CF4011C00C31E3A /* Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = benchmark; - productName = benchmark; - productReference = 5911579B1CF4011C00C31E3A /* benchmark.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 591157931CF4011C00C31E3A /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 0720; - ORGANIZATIONNAME = Google; - TargetAttributes = { - 5911579A1CF4011C00C31E3A = { - CreatedOnToolsVersion = 7.2; - DevelopmentTeam = 85Z3VXS37U; - }; - }; - }; - buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "benchmark" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 591157921CF4011C00C31E3A; - productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 5911579A1CF4011C00C31E3A /* benchmark */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 591157991CF4011C00C31E3A /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 59A3D00C1CF4E68100C4259F /* BenchmarkViewController.xib in Resources */, - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */, - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */, - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 591157971CF4011C00C31E3A /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 59A3D0091CF4E68100C4259F /* main.mm in Sources */, - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, - 59A3D00B1CF4E68100C4259F /* BenchmarkViewController.mm in Sources */, - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - 591157B01CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - MTL_ENABLE_DEBUG_INFO = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - 591157B11CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - MTL_ENABLE_DEBUG_INFO = NO; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; - 591157B31CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../../..", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - ); - INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - "$(SRCROOT)/../../makefile/gen/lib", - ); - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - "-Xlinker", - "-S", - "-Xlinker", - "-x", - "-Xlinker", - "-dead_strip", - ); - PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; - PRODUCT_NAME = "$(TARGET_NAME)"; - }; - name = Debug; - }; - 591157B41CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../../..", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - ); - INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - "$(SRCROOT)/../../makefile/gen/lib", - ); - ONLY_ACTIVE_ARCH = YES; - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - "-Xlinker", - "-S", - "-Xlinker", - "-x", - "-Xlinker", - "-dead_strip", - ); - PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; - PRODUCT_NAME = "$(TARGET_NAME)"; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "benchmark" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B01CF4011D00C31E3A /* Debug */, - 591157B11CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "benchmark" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B31CF4011D00C31E3A /* Debug */, - 591157B41CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 591157931CF4011C00C31E3A /* Project object */; -} diff --git a/tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg b/tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg deleted file mode 100644 index d2a427810f..0000000000 Binary files a/tensorflow/contrib/ios_examples/benchmark/data/grace_hopper.jpg and /dev/null differ diff --git a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.h b/tensorflow/contrib/ios_examples/benchmark/ios_image_load.h deleted file mode 100644 index 78eaded8d7..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm b/tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/contrib/ios_examples/benchmark/main.mm b/tensorflow/contrib/ios_examples/benchmark/main.mm deleted file mode 100644 index d70550a730..0000000000 --- a/tensorflow/contrib/ios_examples/benchmark/main.mm +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -int main(int argc, char * argv[]) { - @autoreleasepool { - NSString *delegateClassName = @"AppDelegate"; - return UIApplicationMain(argc, argv, nil, delegateClassName); - } -} diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h b/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h deleted file mode 100644 index 0039d5e7ca..0000000000 --- a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface CameraExampleAppDelegate : UIResponder - -@property(strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m b/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m deleted file mode 100644 index d134c2b591..0000000000 --- a/tensorflow/contrib/ios_examples/camera/CameraExampleAppDelegate.m +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "CameraExampleAppDelegate.h" - -@implementation CameraExampleAppDelegate - -@synthesize window = _window; - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application { - [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; -} - -- (void)applicationDidEnterBackground:(UIApplication *)application { -} - -- (void)applicationWillEnterForeground:(UIApplication *)application { -} - -- (void)applicationDidBecomeActive:(UIApplication *)application { - [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; -} - -- (void)applicationWillTerminate:(UIApplication *)application { -} - -@end diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h b/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h deleted file mode 100644 index df744428a8..0000000000 --- a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/memmapped_file_system.h" - -@interface CameraExampleViewController - : UIViewController { - IBOutlet UIView *previewView; - IBOutlet UISegmentedControl *camerasControl; - AVCaptureVideoPreviewLayer *previewLayer; - AVCaptureVideoDataOutput *videoDataOutput; - dispatch_queue_t videoDataOutputQueue; - AVCaptureStillImageOutput *stillImageOutput; - UIView *flashView; - BOOL isUsingFrontFacingCamera; - AVSpeechSynthesizer *synth; - NSMutableDictionary *oldPredictionValues; - NSMutableArray *labelLayers; - AVCaptureSession *session; - std::unique_ptr tf_session; - std::unique_ptr tf_memmapped_env; - std::vector labels; -} -@property(strong, nonatomic) CATextLayer *predictionTextLayer; - -- (IBAction)takePicture:(id)sender; -- (IBAction)switchCameras:(id)sender; - -@end diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm b/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm deleted file mode 100644 index 27df3d3d71..0000000000 --- a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm +++ /dev/null @@ -1,596 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import -#import -#import -#import "CameraExampleViewController.h" - -#include - -#include "tensorflow_utils.h" - -// If you have your own model, modify this to the file name, and make sure -// you've added the file to your app resources too. -static NSString* model_file_name = @"tensorflow_inception_graph"; -static NSString* model_file_type = @"pb"; -// This controls whether we'll be loading a plain GraphDef proto, or a -// file created by the convert_graphdef_memmapped_format utility that wraps a -// GraphDef and parameter file that can be mapped into memory from file to -// reduce overall memory usage. -const bool model_uses_memory_mapping = false; -// If you have your own model, point this to the labels file. -static NSString* labels_file_name = @"imagenet_comp_graph_label_strings"; -static NSString* labels_file_type = @"txt"; -// These dimensions need to match those the model was trained with. -const int wanted_input_width = 224; -const int wanted_input_height = 224; -const int wanted_input_channels = 3; -const float input_mean = 117.0f; -const float input_std = 1.0f; -const std::string input_layer_name = "input"; -const std::string output_layer_name = "softmax1"; - -static void *AVCaptureStillImageIsCapturingStillImageContext = - &AVCaptureStillImageIsCapturingStillImageContext; - -@interface CameraExampleViewController (InternalMethods) -- (void)setupAVCapture; -- (void)teardownAVCapture; -@end - -@implementation CameraExampleViewController - -- (void)setupAVCapture { - NSError *error = nil; - - session = [AVCaptureSession new]; - if ([[UIDevice currentDevice] userInterfaceIdiom] == - UIUserInterfaceIdiomPhone) - [session setSessionPreset:AVCaptureSessionPreset640x480]; - else - [session setSessionPreset:AVCaptureSessionPresetPhoto]; - - AVCaptureDevice *device = - [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; - AVCaptureDeviceInput *deviceInput = - [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); - - isUsingFrontFacingCamera = NO; - if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; - - stillImageOutput = [AVCaptureStillImageOutput new]; - [stillImageOutput - addObserver:self - forKeyPath:@"capturingStillImage" - options:NSKeyValueObservingOptionNew - context:(void *)(AVCaptureStillImageIsCapturingStillImageContext)]; - if ([session canAddOutput:stillImageOutput]) - [session addOutput:stillImageOutput]; - - videoDataOutput = [AVCaptureVideoDataOutput new]; - - NSDictionary *rgbOutputSettings = [NSDictionary - dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] - forKey:(id)kCVPixelBufferPixelFormatTypeKey]; - [videoDataOutput setVideoSettings:rgbOutputSettings]; - [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; - videoDataOutputQueue = - dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); - [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; - - if ([session canAddOutput:videoDataOutput]) - [session addOutput:videoDataOutput]; - [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; - - previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; - [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; - [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; - CALayer *rootLayer = [previewView layer]; - [rootLayer setMasksToBounds:YES]; - [previewLayer setFrame:[rootLayer bounds]]; - [rootLayer addSublayer:previewLayer]; - [session startRunning]; - - if (error) { - NSString *title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; - UIAlertController *alertController = - [UIAlertController alertControllerWithTitle:title - message:[error localizedDescription] - preferredStyle:UIAlertControllerStyleAlert]; - UIAlertAction *dismiss = - [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; - [alertController addAction:dismiss]; - [self presentViewController:alertController animated:YES completion:nil]; - [self teardownAVCapture]; - } -} - -- (void)teardownAVCapture { - [stillImageOutput removeObserver:self forKeyPath:@"isCapturingStillImage"]; - [previewLayer removeFromSuperlayer]; -} - -- (void)observeValueForKeyPath:(NSString *)keyPath - ofObject:(id)object - change:(NSDictionary *)change - context:(void *)context { - if (context == AVCaptureStillImageIsCapturingStillImageContext) { - BOOL isCapturingStillImage = - [[change objectForKey:NSKeyValueChangeNewKey] boolValue]; - - if (isCapturingStillImage) { - // do flash bulb like animation - flashView = [[UIView alloc] initWithFrame:[previewView frame]]; - [flashView setBackgroundColor:[UIColor whiteColor]]; - [flashView setAlpha:0.f]; - [[[self view] window] addSubview:flashView]; - - [UIView animateWithDuration:.4f - animations:^{ - [flashView setAlpha:1.f]; - }]; - } else { - [UIView animateWithDuration:.4f - animations:^{ - [flashView setAlpha:0.f]; - } - completion:^(BOOL finished) { - [flashView removeFromSuperview]; - flashView = nil; - }]; - } - } -} - -- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: - (UIDeviceOrientation)deviceOrientation { - AVCaptureVideoOrientation result = - (AVCaptureVideoOrientation)(deviceOrientation); - if (deviceOrientation == UIDeviceOrientationLandscapeLeft) - result = AVCaptureVideoOrientationLandscapeRight; - else if (deviceOrientation == UIDeviceOrientationLandscapeRight) - result = AVCaptureVideoOrientationLandscapeLeft; - return result; -} - -- (IBAction)takePicture:(id)sender { - if ([session isRunning]) { - [session stopRunning]; - [sender setTitle:@"Continue" forState:UIControlStateNormal]; - - flashView = [[UIView alloc] initWithFrame:[previewView frame]]; - [flashView setBackgroundColor:[UIColor whiteColor]]; - [flashView setAlpha:0.f]; - [[[self view] window] addSubview:flashView]; - - [UIView animateWithDuration:.2f - animations:^{ - [flashView setAlpha:1.f]; - } - completion:^(BOOL finished) { - [UIView animateWithDuration:.2f - animations:^{ - [flashView setAlpha:0.f]; - } - completion:^(BOOL finished) { - [flashView removeFromSuperview]; - flashView = nil; - }]; - }]; - - } else { - [session startRunning]; - [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; - } -} - -+ (CGRect)videoPreviewBoxForGravity:(NSString *)gravity - frameSize:(CGSize)frameSize - apertureSize:(CGSize)apertureSize { - CGFloat apertureRatio = apertureSize.height / apertureSize.width; - CGFloat viewRatio = frameSize.width / frameSize.height; - - CGSize size = CGSizeZero; - if ([gravity isEqualToString:AVLayerVideoGravityResizeAspectFill]) { - if (viewRatio > apertureRatio) { - size.width = frameSize.width; - size.height = - apertureSize.width * (frameSize.width / apertureSize.height); - } else { - size.width = - apertureSize.height * (frameSize.height / apertureSize.width); - size.height = frameSize.height; - } - } else if ([gravity isEqualToString:AVLayerVideoGravityResizeAspect]) { - if (viewRatio > apertureRatio) { - size.width = - apertureSize.height * (frameSize.height / apertureSize.width); - size.height = frameSize.height; - } else { - size.width = frameSize.width; - size.height = - apertureSize.width * (frameSize.width / apertureSize.height); - } - } else if ([gravity isEqualToString:AVLayerVideoGravityResize]) { - size.width = frameSize.width; - size.height = frameSize.height; - } - - CGRect videoBox; - videoBox.size = size; - if (size.width < frameSize.width) - videoBox.origin.x = (frameSize.width - size.width) / 2; - else - videoBox.origin.x = (size.width - frameSize.width) / 2; - - if (size.height < frameSize.height) - videoBox.origin.y = (frameSize.height - size.height) / 2; - else - videoBox.origin.y = (size.height - frameSize.height) / 2; - - return videoBox; -} - -- (void)captureOutput:(AVCaptureOutput *)captureOutput -didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer - fromConnection:(AVCaptureConnection *)connection { - CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); - CFRetain(pixelBuffer); - [self runCNNOnFrame:pixelBuffer]; - CFRelease(pixelBuffer); -} - -- (void)runCNNOnFrame:(CVPixelBufferRef)pixelBuffer { - assert(pixelBuffer != NULL); - - OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - int doReverseChannels; - if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { - doReverseChannels = 1; - } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { - doReverseChannels = 0; - } else { - assert(false); // Unknown source format - } - - const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); - const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); - const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer); - - CVPixelBufferLockFlags unlockFlags = kNilOptions; - CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); - - unsigned char *sourceBaseAddr = - (unsigned char *)(CVPixelBufferGetBaseAddress(pixelBuffer)); - int image_height; - unsigned char *sourceStartAddr; - if (fullHeight <= image_width) { - image_height = fullHeight; - sourceStartAddr = sourceBaseAddr; - } else { - image_height = image_width; - const int marginY = ((fullHeight - image_width) / 2); - sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); - } - const int image_channels = 4; - - assert(image_channels >= wanted_input_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape( - {1, wanted_input_height, wanted_input_width, wanted_input_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8 *in = sourceStartAddr; - float *out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_input_height; ++y) { - float *out_row = out + (y * wanted_input_width * wanted_input_channels); - for (int x = 0; x < wanted_input_width; ++x) { - const int in_x = (y * image_width) / wanted_input_width; - const int in_y = (x * image_height) / wanted_input_height; - tensorflow::uint8 *in_pixel = - in + (in_y * image_width * image_channels) + (in_x * image_channels); - float *out_pixel = out_row + (x * wanted_input_channels); - for (int c = 0; c < wanted_input_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - - CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); - - if (tf_session.get()) { - std::vector outputs; - tensorflow::Status run_status = tf_session->Run( - {{input_layer_name, image_tensor}}, {output_layer_name}, {}, &outputs); - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed:" << run_status; - } else { - tensorflow::Tensor *output = &outputs[0]; - auto predictions = output->flat(); - - NSMutableDictionary *newValues = [NSMutableDictionary dictionary]; - for (int index = 0; index < predictions.size(); ++index) { - const float predictionValue = predictions(index); - if (predictionValue > 0.05f) { - std::string label = labels[index]; - NSString *labelObject = [NSString stringWithUTF8String:label.c_str()]; - NSNumber *valueObject = [NSNumber numberWithFloat:predictionValue]; - [newValues setObject:valueObject forKey:labelObject]; - } - } - dispatch_async(dispatch_get_main_queue(), ^(void) { - [self setPredictionValues:newValues]; - }); - } - } - CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); -} - -- (void)dealloc { - [self teardownAVCapture]; -} - -// use front/back camera -- (IBAction)switchCameras:(id)sender { - AVCaptureDevicePosition desiredPosition; - if (isUsingFrontFacingCamera) - desiredPosition = AVCaptureDevicePositionBack; - else - desiredPosition = AVCaptureDevicePositionFront; - - for (AVCaptureDevice *d in - [AVCaptureDevice devicesWithMediaType:AVMediaTypeVideo]) { - if ([d position] == desiredPosition) { - [[previewLayer session] beginConfiguration]; - AVCaptureDeviceInput *input = - [AVCaptureDeviceInput deviceInputWithDevice:d error:nil]; - for (AVCaptureInput *oldInput in [[previewLayer session] inputs]) { - [[previewLayer session] removeInput:oldInput]; - } - [[previewLayer session] addInput:input]; - [[previewLayer session] commitConfiguration]; - break; - } - } - isUsingFrontFacingCamera = !isUsingFrontFacingCamera; -} - -- (void)viewDidLoad { - [super viewDidLoad]; - synth = [[AVSpeechSynthesizer alloc] init]; - labelLayers = [[NSMutableArray alloc] init]; - oldPredictionValues = [[NSMutableDictionary alloc] init]; - - tensorflow::Status load_status; - if (model_uses_memory_mapping) { - load_status = LoadMemoryMappedModel( - model_file_name, model_file_type, &tf_session, &tf_memmapped_env); - } else { - load_status = LoadModel(model_file_name, model_file_type, &tf_session); - } - if (!load_status.ok()) { - LOG(FATAL) << "Couldn't load model: " << load_status; - } - - tensorflow::Status labels_status = - LoadLabels(labels_file_name, labels_file_type, &labels); - if (!labels_status.ok()) { - LOG(FATAL) << "Couldn't load labels: " << labels_status; - } - [self setupAVCapture]; -} - -- (BOOL)shouldAutorotateToInterfaceOrientation: - (UIInterfaceOrientation)interfaceOrientation { - return (interfaceOrientation == UIInterfaceOrientationPortrait); -} - -- (BOOL)prefersStatusBarHidden { - return YES; -} - -- (void)setPredictionValues:(NSDictionary *)newValues { - const float decayValue = 0.75f; - const float updateValue = 0.25f; - const float minimumThreshold = 0.01f; - - NSMutableDictionary *decayedPredictionValues = - [[NSMutableDictionary alloc] init]; - for (NSString *label in oldPredictionValues) { - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - const float decayedPredictionValue = (oldPredictionValue * decayValue); - if (decayedPredictionValue > minimumThreshold) { - NSNumber *decayedPredictionValueObject = - [NSNumber numberWithFloat:decayedPredictionValue]; - [decayedPredictionValues setObject:decayedPredictionValueObject - forKey:label]; - } - } - oldPredictionValues = decayedPredictionValues; - - for (NSString *label in newValues) { - NSNumber *newPredictionValueObject = [newValues objectForKey:label]; - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - if (!oldPredictionValueObject) { - oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; - } - const float newPredictionValue = [newPredictionValueObject floatValue]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - const float updatedPredictionValue = - (oldPredictionValue + (newPredictionValue * updateValue)); - NSNumber *updatedPredictionValueObject = - [NSNumber numberWithFloat:updatedPredictionValue]; - [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; - } - NSArray *candidateLabels = [NSMutableArray array]; - for (NSString *label in oldPredictionValues) { - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - if (oldPredictionValue > 0.05f) { - NSDictionary *entry = @{ - @"label" : label, - @"value" : oldPredictionValueObject - }; - candidateLabels = [candidateLabels arrayByAddingObject:entry]; - } - } - NSSortDescriptor *sort = - [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; - NSArray *sortedLabels = [candidateLabels - sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; - - const float leftMargin = 10.0f; - const float topMargin = 10.0f; - - const float valueWidth = 48.0f; - const float valueHeight = 26.0f; - - const float labelWidth = 246.0f; - const float labelHeight = 26.0f; - - const float labelMarginX = 5.0f; - const float labelMarginY = 5.0f; - - [self removeAllLabelLayers]; - - int labelCount = 0; - for (NSDictionary *entry in sortedLabels) { - NSString *label = [entry objectForKey:@"label"]; - NSNumber *valueObject = [entry objectForKey:@"value"]; - const float value = [valueObject floatValue]; - - const float originY = - (topMargin + ((labelHeight + labelMarginY) * labelCount)); - - const int valuePercentage = (int)roundf(value * 100.0f); - - const float valueOriginX = leftMargin; - NSString *valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; - - [self addLabelLayerWithText:valueText - originX:valueOriginX - originY:originY - width:valueWidth - height:valueHeight - alignment:kCAAlignmentRight]; - - const float labelOriginX = (leftMargin + valueWidth + labelMarginX); - - [self addLabelLayerWithText:[label capitalizedString] - originX:labelOriginX - originY:originY - width:labelWidth - height:labelHeight - alignment:kCAAlignmentLeft]; - - if ((labelCount == 0) && (value > 0.5f)) { - [self speak:[label capitalizedString]]; - } - - labelCount += 1; - if (labelCount > 4) { - break; - } - } -} - -- (void)removeAllLabelLayers { - for (CATextLayer *layer in labelLayers) { - [layer removeFromSuperlayer]; - } - [labelLayers removeAllObjects]; -} - -- (void)addLabelLayerWithText:(NSString *)text - originX:(float)originX - originY:(float)originY - width:(float)width - height:(float)height - alignment:(NSString *)alignment { - CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; - const float fontSize = 20.0f; - - const float marginSizeX = 5.0f; - const float marginSizeY = 2.0f; - - const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); - - const CGRect textBounds = - CGRectMake((originX + marginSizeX), (originY + marginSizeY), - (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); - - CATextLayer *background = [CATextLayer layer]; - [background setBackgroundColor:[UIColor blackColor].CGColor]; - [background setOpacity:0.5f]; - [background setFrame:backgroundBounds]; - background.cornerRadius = 5.0f; - - [[self.view layer] addSublayer:background]; - [labelLayers addObject:background]; - - CATextLayer *layer = [CATextLayer layer]; - [layer setForegroundColor:[UIColor whiteColor].CGColor]; - [layer setFrame:textBounds]; - [layer setAlignmentMode:alignment]; - [layer setWrapped:YES]; - [layer setFont:font]; - [layer setFontSize:fontSize]; - layer.contentsScale = [[UIScreen mainScreen] scale]; - [layer setString:text]; - - [[self.view layer] addSublayer:layer]; - [labelLayers addObject:layer]; -} - -- (void)setPredictionText:(NSString *)text withDuration:(float)duration { - if (duration > 0.0) { - CABasicAnimation *colorAnimation = - [CABasicAnimation animationWithKeyPath:@"foregroundColor"]; - colorAnimation.duration = duration; - colorAnimation.fillMode = kCAFillModeForwards; - colorAnimation.removedOnCompletion = NO; - colorAnimation.fromValue = (id)[UIColor darkGrayColor].CGColor; - colorAnimation.toValue = (id)[UIColor whiteColor].CGColor; - colorAnimation.timingFunction = - [CAMediaTimingFunction functionWithName:kCAMediaTimingFunctionLinear]; - [self.predictionTextLayer addAnimation:colorAnimation - forKey:@"colorAnimation"]; - } else { - self.predictionTextLayer.foregroundColor = [UIColor whiteColor].CGColor; - } - - [self.predictionTextLayer removeFromSuperlayer]; - [[self.view layer] addSublayer:self.predictionTextLayer]; - [self.predictionTextLayer setString:text]; -} - -- (void)speak:(NSString *)words { - if ([synth isSpeaking]) { - return; - } - AVSpeechUtterance *utterance = - [AVSpeechUtterance speechUtteranceWithString:words]; - utterance.voice = [AVSpeechSynthesisVoice voiceWithLanguage:@"en-US"]; - utterance.rate = 0.75 * AVSpeechUtteranceDefaultSpeechRate; - [synth speakUtterance:utterance]; -} - -@end diff --git a/tensorflow/contrib/ios_examples/camera/Info.plist b/tensorflow/contrib/ios_examples/camera/Info.plist deleted file mode 100644 index 82978ca278..0000000000 --- a/tensorflow/contrib/ios_examples/camera/Info.plist +++ /dev/null @@ -1,44 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - ${PRODUCT_NAME} - CFBundleExecutable - ${EXECUTABLE_NAME} - CFBundleIdentifier - $(PRODUCT_BUNDLE_IDENTIFIER) - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ${PRODUCT_NAME} - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - NSCameraUsageDescription - Capture images to detect object - UIMainStoryboardFile - MainStoryboard_iPhone - UIRequiresFullScreen - - UIStatusBarHidden - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - - - diff --git a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj deleted file mode 100644 index e9d783e49d..0000000000 --- a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj +++ /dev/null @@ -1,431 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 591D3EC51CFF7F130059011C /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EC41CFF7F120059011C /* AVFoundation.framework */; }; - 591D3ECB1CFF7F5F0059011C /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */; }; - 591D3ECD1CFF7F9F0059011C /* AssetsLibrary.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */; }; - 591D3ECF1CFF7FCE0059011C /* ImageIO.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */; }; - 591D3ED21CFF85C30059011C /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 591D3ED11CFF85C30059011C /* ios_image_load.mm */; }; - 591D3ED51CFF85FD0059011C /* tensorflow_utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */; }; - 591D3EDB1CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */; }; - 591D3EDC1CFFA83A0059011C /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */; }; - 591D3EDF1CFFAD230059011C /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */; }; - 591D3EE01CFFAD230059011C /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 591D3EDE1CFFAD230059011C /* libprotobuf.a */; }; - 592FF8B918ECBD7600C164F8 /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 592FF8B818ECBD7600C164F8 /* Foundation.framework */; }; - 592FF8BB18ECBD7600C164F8 /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */; }; - 592FF90218ECC66200C164F8 /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 592FF90118ECC66200C164F8 /* main.mm */; }; - 592FF90D18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */; }; - 592FF92518EE240200C164F8 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */; }; - 592FF92618EE240200C164F8 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */; }; - 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7711D5D4E980048CE6A /* Accelerate.framework */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 591D3EC41CFF7F120059011C /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/AVFoundation.framework; sourceTree = DEVELOPER_DIR; }; - 591D3EC61CFF7F370059011C /* CoreFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreFoundation.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreFoundation.framework; sourceTree = DEVELOPER_DIR; }; - 591D3EC81CFF7F500059011C /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreImage.framework; sourceTree = DEVELOPER_DIR; }; - 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/CoreMedia.framework; sourceTree = DEVELOPER_DIR; }; - 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AssetsLibrary.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/AssetsLibrary.framework; sourceTree = DEVELOPER_DIR; }; - 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = ImageIO.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.2.sdk/System/Library/Frameworks/ImageIO.framework; sourceTree = DEVELOPER_DIR; }; - 591D3ED01CFF85C30059011C /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = SOURCE_ROOT; }; - 591D3ED11CFF85C30059011C /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = SOURCE_ROOT; }; - 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = tensorflow_utils.mm; sourceTree = SOURCE_ROOT; }; - 591D3ED41CFF85FD0059011C /* tensorflow_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = tensorflow_utils.h; sourceTree = SOURCE_ROOT; }; - 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; - 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = ""; }; - 591D3EDE1CFFAD230059011C /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = ""; }; - 592FF8B518ECBD7600C164F8 /* CameraExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = CameraExample.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 592FF8B818ECBD7600C164F8 /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; }; - 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 592FF90118ECC66200C164F8 /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = SOURCE_ROOT; }; - 592FF90318ECCB8300C164F8 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = SOURCE_ROOT; }; - 592FF90B18EDD0DA00C164F8 /* en */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = en; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; - 592FF92118EE240200C164F8 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = SOURCE_ROOT; }; - 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = SOURCE_ROOT; }; - 592FF92318EE240200C164F8 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = SOURCE_ROOT; }; - 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = SOURCE_ROOT; }; - 5993C7711D5D4E980048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS9.3.sdk/System/Library/Frameworks/Accelerate.framework; sourceTree = DEVELOPER_DIR; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 592FF8B218ECBD7600C164F8 /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 5993C7721D5D4E980048CE6A /* Accelerate.framework in Frameworks */, - 591D3EDF1CFFAD230059011C /* libprotobuf-lite.a in Frameworks */, - 591D3EE01CFFAD230059011C /* libprotobuf.a in Frameworks */, - 591D3ECF1CFF7FCE0059011C /* ImageIO.framework in Frameworks */, - 591D3ECD1CFF7F9F0059011C /* AssetsLibrary.framework in Frameworks */, - 591D3ECB1CFF7F5F0059011C /* CoreMedia.framework in Frameworks */, - 591D3EC51CFF7F130059011C /* AVFoundation.framework in Frameworks */, - 592FF8BB18ECBD7600C164F8 /* CoreGraphics.framework in Frameworks */, - 592FF8B918ECBD7600C164F8 /* Foundation.framework in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 591D3ED61CFFA83A0059011C /* data */ = { - isa = PBXGroup; - children = ( - 591D3ED81CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt */, - 591D3ED91CFFA83A0059011C /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = SOURCE_ROOT; - }; - 592FF8AA18ECBD3600C164F8 = { - isa = PBXGroup; - children = ( - 592FF8BE18ECBD7600C164F8 /* CameraExample */, - 592FF8B718ECBD7600C164F8 /* Frameworks */, - 592FF8B618ECBD7600C164F8 /* Products */, - ); - sourceTree = ""; - }; - 592FF8B618ECBD7600C164F8 /* Products */ = { - isa = PBXGroup; - children = ( - 592FF8B518ECBD7600C164F8 /* CameraExample.app */, - ); - name = Products; - sourceTree = ""; - }; - 592FF8B718ECBD7600C164F8 /* Frameworks */ = { - isa = PBXGroup; - children = ( - 5993C7711D5D4E980048CE6A /* Accelerate.framework */, - 591D3EDD1CFFAD230059011C /* libprotobuf-lite.a */, - 591D3EDE1CFFAD230059011C /* libprotobuf.a */, - 591D3ECE1CFF7FCE0059011C /* ImageIO.framework */, - 591D3ECC1CFF7F9F0059011C /* AssetsLibrary.framework */, - 591D3ECA1CFF7F5F0059011C /* CoreMedia.framework */, - 591D3EC81CFF7F500059011C /* CoreImage.framework */, - 591D3EC61CFF7F370059011C /* CoreFoundation.framework */, - 591D3EC41CFF7F120059011C /* AVFoundation.framework */, - 592FF8B818ECBD7600C164F8 /* Foundation.framework */, - 592FF8BA18ECBD7600C164F8 /* CoreGraphics.framework */, - ); - name = Frameworks; - sourceTree = ""; - }; - 592FF8BE18ECBD7600C164F8 /* CameraExample */ = { - isa = PBXGroup; - children = ( - 591D3ED61CFFA83A0059011C /* data */, - 592FF90718EDD0DA00C164F8 /* en.lproj */, - 592FF92118EE240200C164F8 /* CameraExampleAppDelegate.h */, - 592FF92218EE240200C164F8 /* CameraExampleAppDelegate.m */, - 592FF92318EE240200C164F8 /* CameraExampleViewController.h */, - 592FF92418EE240200C164F8 /* CameraExampleViewController.mm */, - 592FF90318ECCB8300C164F8 /* Info.plist */, - 591D3ED01CFF85C30059011C /* ios_image_load.h */, - 591D3ED11CFF85C30059011C /* ios_image_load.mm */, - 592FF90118ECC66200C164F8 /* main.mm */, - 591D3ED31CFF85FD0059011C /* tensorflow_utils.mm */, - 591D3ED41CFF85FD0059011C /* tensorflow_utils.h */, - ); - name = CameraExample; - path = SimpleExample; - sourceTree = ""; - }; - 592FF90718EDD0DA00C164F8 /* en.lproj */ = { - isa = PBXGroup; - children = ( - 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */, - ); - path = en.lproj; - sourceTree = SOURCE_ROOT; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 592FF8B418ECBD7600C164F8 /* CameraExample */ = { - isa = PBXNativeTarget; - buildConfigurationList = 592FF8E318ECBD7600C164F8 /* Build configuration list for PBXNativeTarget "CameraExample" */; - buildPhases = ( - 592FF8B118ECBD7600C164F8 /* Sources */, - 592FF8B218ECBD7600C164F8 /* Frameworks */, - 592FF8B318ECBD7600C164F8 /* Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = CameraExample; - productName = SimpleExample; - productReference = 592FF8B518ECBD7600C164F8 /* CameraExample.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 592FF8AB18ECBD3600C164F8 /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 0720; - }; - buildConfigurationList = 592FF8AE18ECBD3600C164F8 /* Build configuration list for PBXProject "camera_example" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - ); - mainGroup = 592FF8AA18ECBD3600C164F8; - productRefGroup = 592FF8B618ECBD7600C164F8 /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 592FF8B418ECBD7600C164F8 /* CameraExample */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 592FF8B318ECBD7600C164F8 /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 591D3EDC1CFFA83A0059011C /* tensorflow_inception_graph.pb in Resources */, - 592FF90D18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard in Resources */, - 591D3EDB1CFFA83A0059011C /* imagenet_comp_graph_label_strings.txt in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 592FF8B118ECBD7600C164F8 /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 592FF90218ECC66200C164F8 /* main.mm in Sources */, - 591D3ED21CFF85C30059011C /* ios_image_load.mm in Sources */, - 592FF92618EE240200C164F8 /* CameraExampleViewController.mm in Sources */, - 592FF92518EE240200C164F8 /* CameraExampleAppDelegate.m in Sources */, - 591D3ED51CFF85FD0059011C /* tensorflow_utils.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin PBXVariantGroup section */ - 592FF90A18EDD0DA00C164F8 /* MainStoryboard_iPhone.storyboard */ = { - isa = PBXVariantGroup; - children = ( - 592FF90B18EDD0DA00C164F8 /* en */, - ); - name = MainStoryboard_iPhone.storyboard; - sourceTree = ""; - }; -/* End PBXVariantGroup section */ - -/* Begin XCBuildConfiguration section */ - 592FF8AF18ECBD3600C164F8 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - ONLY_ACTIVE_ARCH = YES; - }; - name = Debug; - }; - 592FF8B018ECBD3600C164F8 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - }; - name = Release; - }; - 592FF8DF18ECBD7600C164F8 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "compiler-default"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - ENABLE_BITCODE = NO; - FRAMEWORK_SEARCH_PATHS = "$(inherited)"; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PRECOMPILE_PREFIX_HEADER = YES; - GCC_PREFIX_HEADER = ""; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_SYMBOLS_PRIVATE_EXTERN = NO; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../../..", - ); - INFOPLIST_FILE = "$(SRCROOT)/Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/lib", - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - ); - ONLY_ACTIVE_ARCH = NO; - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - ); - PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALID_ARCHS = "arm64 armv7 armv7s"; - WRAPPER_EXTENSION = app; - }; - name = Debug; - }; - 592FF8E018ECBD7600C164F8 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "compiler-default"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = YES; - ENABLE_BITCODE = NO; - ENABLE_NS_ASSERTIONS = NO; - FRAMEWORK_SEARCH_PATHS = "$(inherited)"; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_PRECOMPILE_PREFIX_HEADER = YES; - GCC_PREFIX_HEADER = ""; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../../..", - ); - INFOPLIST_FILE = "$(SRCROOT)/Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/lib", - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - ); - ONLY_ACTIVE_ARCH = NO; - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - ); - PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - VALID_ARCHS = "arm64 armv7 armv7s"; - WRAPPER_EXTENSION = app; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 592FF8AE18ECBD3600C164F8 /* Build configuration list for PBXProject "camera_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 592FF8AF18ECBD3600C164F8 /* Debug */, - 592FF8B018ECBD3600C164F8 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 592FF8E318ECBD7600C164F8 /* Build configuration list for PBXNativeTarget "CameraExample" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 592FF8DF18ECBD7600C164F8 /* Debug */, - 592FF8E018ECBD7600C164F8 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 592FF8AB18ECBD3600C164F8 /* Project object */; -} diff --git a/tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard b/tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard deleted file mode 100644 index 0f10a22e41..0000000000 --- a/tensorflow/contrib/ios_examples/camera/en.lproj/MainStoryboard_iPhone.storyboard +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/contrib/ios_examples/camera/ios_image_load.h b/tensorflow/contrib/ios_examples/camera/ios_image_load.h deleted file mode 100644 index 87a847e145..0000000000 --- a/tensorflow/contrib/ios_examples/camera/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ -#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/ios_examples/camera/ios_image_load.mm b/tensorflow/contrib/ios_examples/camera/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf..0000000000 --- a/tensorflow/contrib/ios_examples/camera/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/contrib/ios_examples/camera/main.mm b/tensorflow/contrib/ios_examples/camera/main.mm deleted file mode 100644 index 42eff697ef..0000000000 --- a/tensorflow/contrib/ios_examples/camera/main.mm +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -#import "CameraExampleAppDelegate.h" - -int main(int argc, char *argv[]) { - int retVal = 0; - - @autoreleasepool { - retVal = UIApplicationMain( - argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); - } - return retVal; -} diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.h b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.h deleted file mode 100644 index 78bdb82aae..0000000000 --- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ -#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ - -#include -#include - -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/memmapped_file_system.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -// Reads a serialized GraphDef protobuf file from the bundle, typically -// created with the freeze_graph script. Populates the session argument with a -// Session object that has the model loaded. -tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, - std::unique_ptr* session); - -// Loads a model from a file that has been created using the -// convert_graphdef_memmapped_format tool. This bundles together a GraphDef -// proto together with a file that can be memory-mapped, containing the weight -// parameters for the model. This is useful because it reduces the overall -// memory pressure, since the read-only parameter regions can be easily paged -// out and don't count toward memory limits on iOS. -tensorflow::Status LoadMemoryMappedModel( - NSString* file_name, NSString* file_type, - std::unique_ptr* session, - std::unique_ptr* memmapped_env); - -// Takes a text file with a single label on each line, and returns a list. -tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, - std::vector* label_strings); - -// Sorts the results from a model execution, and returns the highest scoring. -void GetTopN(const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results); - -#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ diff --git a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm b/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm deleted file mode 100644 index 43746882ee..0000000000 --- a/tensorflow/contrib/ios_examples/camera/tensorflow_utils.mm +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -#include "tensorflow_utils.h" - -#include -#include -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" -#include "google/protobuf/message_lite.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" - -namespace { - -// Helper class used to load protobufs efficiently. -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -void GetTopN(const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > - top_result_pq; - - const int count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = - [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; - return nullptr; - } - return file_path; -} - -tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, - std::unique_ptr* session) { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; - return session_status; - } - session->reset(session_pointer); - - tensorflow::GraphDef tensorflow_graph; - - NSString* model_path = FilePathForResourceName(file_name, file_type); - if (!model_path) { - LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] - << [file_type UTF8String]; - return tensorflow::errors::NotFound([file_name UTF8String], - [file_type UTF8String]); - } - const bool read_proto_succeeded = - PortableReadFileToProto([model_path UTF8String], &tensorflow_graph); - if (!read_proto_succeeded) { - LOG(ERROR) << "Failed to load model proto from" << [model_path UTF8String]; - return tensorflow::errors::NotFound([model_path UTF8String]); - } - - tensorflow::Status create_status = (*session)->Create(tensorflow_graph); - if (!create_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; - return create_status; - } - - return tensorflow::Status::OK(); -} - -tensorflow::Status LoadMemoryMappedModel( - NSString* file_name, NSString* file_type, - std::unique_ptr* session, - std::unique_ptr* memmapped_env) { - NSString* network_path = FilePathForResourceName(file_name, file_type); - memmapped_env->reset( - new tensorflow::MemmappedEnv(tensorflow::Env::Default())); - tensorflow::Status mmap_status = - (memmapped_env->get())->InitializeFromFile([network_path UTF8String]); - if (!mmap_status.ok()) { - LOG(ERROR) << "MMap failed with " << mmap_status.error_message(); - return mmap_status; - } - - tensorflow::GraphDef tensorflow_graph; - tensorflow::Status load_graph_status = ReadBinaryProto( - memmapped_env->get(), - tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, - &tensorflow_graph); - if (!load_graph_status.ok()) { - LOG(ERROR) << "MMap load graph failed with " - << load_graph_status.error_message(); - return load_graph_status; - } - - tensorflow::SessionOptions options; - // Disable optimizations on this graph so that constant folding doesn't - // increase the memory footprint by creating new constant copies of the weight - // parameters. - options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(::tensorflow::OptimizerOptions::L0); - options.env = memmapped_env->get(); - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; - return session_status; - } - - tensorflow::Status create_status = session_pointer->Create(tensorflow_graph); - if (!create_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; - return create_status; - } - - session->reset(session_pointer); - - return tensorflow::Status::OK(); -} - -tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, - std::vector* label_strings) { - // Read the label list - NSString* labels_path = FilePathForResourceName(file_name, file_type); - if (!labels_path) { - LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] - << [file_type UTF8String]; - return tensorflow::errors::NotFound([file_name UTF8String], - [file_type UTF8String]); - } - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while (t) { - std::getline(t, line); - label_strings->push_back(line); - } - t.close(); - return tensorflow::Status::OK(); -} \ No newline at end of file diff --git a/tensorflow/contrib/ios_examples/simple/AppDelegate.h b/tensorflow/contrib/ios_examples/simple/AppDelegate.h deleted file mode 100644 index 75b1f1da38..0000000000 --- a/tensorflow/contrib/ios_examples/simple/AppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface AppDelegate : UIResponder - -@property (strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/contrib/ios_examples/simple/AppDelegate.mm b/tensorflow/contrib/ios_examples/simple/AppDelegate.mm deleted file mode 100644 index 1e808eb976..0000000000 --- a/tensorflow/contrib/ios_examples/simple/AppDelegate.mm +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "AppDelegate.h" - -#import "RunModelViewController.h" - -@implementation AppDelegate - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - - UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; - bar.selectedIndex = 0; - self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; - self.window.rootViewController = bar; - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application {} - -- (void)applicationDidEnterBackground:(UIApplication *)application {} - -- (void)applicationWillEnterForeground:(UIApplication *)application {} - -- (void)applicationDidBecomeActive:(UIApplication *)application {} - -- (void)applicationWillTerminate:(UIApplication *)application {} - -@end diff --git a/tensorflow/contrib/ios_examples/simple/RunModel-Info.plist b/tensorflow/contrib/ios_examples/simple/RunModel-Info.plist deleted file mode 100644 index ca80e68091..0000000000 --- a/tensorflow/contrib/ios_examples/simple/RunModel-Info.plist +++ /dev/null @@ -1,47 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - tf_ios_makefile_example - CFBundleExecutable - tf_ios_makefile_example - CFBundleIdentifier - Google.RunModel - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ios-app - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - UILaunchStoryboardName - RunModelViewController - UIRequiredDeviceCapabilities - - armv7 - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - UIInterfaceOrientationPortraitUpsideDown - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - - diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.h b/tensorflow/contrib/ios_examples/simple/RunModelViewController.h deleted file mode 100644 index 4e1a83ccf5..0000000000 --- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface RunModelViewController : UIViewController - -- (IBAction)getUrl:(id)sender; - -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; - -@end diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm b/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm deleted file mode 100644 index 5c121962d9..0000000000 --- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "RunModelViewController.h" - -#include -#include -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" -#include "google/protobuf/message_lite.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" - -#include "ios_image_load.h" - -NSString* RunInferenceOnImage(); - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -@interface RunModelViewController () -@end - -@implementation RunModelViewController { -} - -- (IBAction)getUrl:(id)sender { - NSString* inference_result = RunInferenceOnImage(); - self.urlContentTextView.text = inference_result; -} - -@end - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -static void GetTopN( - const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > top_result_pq; - - const int count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - // TODO(jiayq): the following coded stream is for debugging purposes to allow - // one to parse arbitrarily large messages for MessageLite. One most likely - // doesn't want to put protobufs larger than 64MB on Android, so we should - // eventually remove this and quit loud when a large protobuf is passed in. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; - } - return file_path; -} - -NSString* RunInferenceOnImage() { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - std::string status_string = session_status.ToString(); - return [NSString stringWithFormat: @"Session create failed - %s", - status_string.c_str()]; - } - std::unique_ptr session(session_pointer); - LOG(INFO) << "Session created."; - - tensorflow::GraphDef tensorflow_graph; - LOG(INFO) << "Graph created."; - - NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); - PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); - - LOG(INFO) << "Creating session."; - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << s; - return @""; - } - - // Read the label list - NSString* labels_path = FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); - std::vector label_strings; - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while(t){ - std::getline(t, line); - label_strings.push_back(line); - } - t.close(); - - // Read the Grace Hopper image. - NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); - int image_width; - int image_height; - int image_channels; - std::vector image_data = LoadImageFromFile( - [image_path UTF8String], &image_width, &image_height, &image_channels); - const int wanted_width = 224; - const int wanted_height = 224; - const int wanted_channels = 3; - const float input_mean = 117.0f; - const float input_std = 1.0f; - assert(image_channels >= wanted_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape({ - 1, wanted_height, wanted_width, wanted_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8* in = image_data.data(); - tensorflow::uint8* in_end = (in + (image_height * image_width * image_channels)); - float* out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); - float* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); - float* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - - NSString* result = [network_path stringByAppendingString: @" - loaded!"]; - result = [NSString stringWithFormat: @"%@ - %d, %s - %dx%d", result, - label_strings.size(), label_strings[0].c_str(), image_width, image_height]; - - std::string input_layer = "input"; - std::string output_layer = "output"; - std::vector outputs; - tensorflow::Status run_status = session->Run({{input_layer, image_tensor}}, - {output_layer}, {}, &outputs); - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed: " << run_status; - tensorflow::LogAllRegisteredKernels(); - result = @"Error running model"; - return result; - } - tensorflow::string status_string = run_status.ToString(); - result = [NSString stringWithFormat: @"%@ - %s", result, - status_string.c_str()]; - - tensorflow::Tensor* output = &outputs[0]; - const int kNumResults = 5; - const float kThreshold = 0.1f; - std::vector > top_results; - GetTopN(output->flat(), kNumResults, kThreshold, &top_results); - - std::stringstream ss; - ss.precision(3); - for (const auto& result : top_results) { - const float confidence = result.first; - const int index = result.second; - - ss << index << " " << confidence << " "; - - // Write out the result as a string - if (index < label_strings.size()) { - // just for safety: theoretically, the output is under 1000 unless there - // is some numerical issues leading to a wrong prediction. - ss << label_strings[index]; - } else { - ss << "Prediction: " << index; - } - - ss << "\n"; - } - - LOG(INFO) << "Predictions: " << ss.str(); - - tensorflow::string predictions = ss.str(); - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - - return result; -} diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.xib b/tensorflow/contrib/ios_examples/simple/RunModelViewController.xib deleted file mode 100644 index 93f334b985..0000000000 --- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.xib +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg b/tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg deleted file mode 100644 index d2a427810f..0000000000 Binary files a/tensorflow/contrib/ios_examples/simple/data/grace_hopper.jpg and /dev/null differ diff --git a/tensorflow/contrib/ios_examples/simple/ios_image_load.h b/tensorflow/contrib/ios_examples/simple/ios_image_load.h deleted file mode 100644 index 0e0b771118..0000000000 --- a/tensorflow/contrib/ios_examples/simple/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/ios_examples/simple/ios_image_load.mm b/tensorflow/contrib/ios_examples/simple/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf..0000000000 --- a/tensorflow/contrib/ios_examples/simple/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/contrib/ios_examples/simple/main.mm b/tensorflow/contrib/ios_examples/simple/main.mm deleted file mode 100644 index d70550a730..0000000000 --- a/tensorflow/contrib/ios_examples/simple/main.mm +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -int main(int argc, char * argv[]) { - @autoreleasepool { - NSString *delegateClassName = @"AppDelegate"; - return UIApplicationMain(argc, argv, nil, delegateClassName); - } -} diff --git a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj deleted file mode 100644 index 94a0037e4f..0000000000 --- a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj +++ /dev/null @@ -1,377 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */; }; - 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 590E7D871D02091F00DF5523 /* libprotobuf.a */; }; - 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */; }; - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; - 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; - 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; - 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; - 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */; }; - 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 59A3D0171CF4E86100C4259F /* UIKit.framework */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libprotobuf-lite.a"; path = "../../makefile/gen/protobuf_ios/lib/libprotobuf-lite.a"; sourceTree = ""; }; - 590E7D871D02091F00DF5523 /* libprotobuf.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libprotobuf.a; path = ../../makefile/gen/protobuf_ios/lib/libprotobuf.a; sourceTree = ""; }; - 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_ios_makefile_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; - 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = cropped_panda.jpg; sourceTree = ""; }; - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; - 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_2012_challenge_label_map_proto.pbtxt; sourceTree = ""; }; - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 59A3CFF81CF4E68100C4259F /* LICENSE */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = LICENSE; sourceTree = ""; }; - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; - 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; - 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; - 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; - 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; - 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; - 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 59A3D0151CF4E83D00C4259F /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = System/Library/Frameworks/Foundation.framework; sourceTree = SDKROOT; }; - 59A3D0171CF4E86100C4259F /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 591157981CF4011C00C31E3A /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 5993C7741D5D4EAF0048CE6A /* Accelerate.framework in Frameworks */, - 590E7D8A1D0209DD00DF5523 /* libprotobuf.a in Frameworks */, - 590E7D881D02091F00DF5523 /* libprotobuf-lite.a in Frameworks */, - 59A3D0181CF4E86100C4259F /* UIKit.framework in Frameworks */, - 59A3D0141CF4E82500C4259F /* CoreGraphics.framework in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 591157921CF4011C00C31E3A = { - isa = PBXGroup; - children = ( - 5993C7731D5D4EAF0048CE6A /* Accelerate.framework */, - 590E7D861D02091F00DF5523 /* libprotobuf-lite.a */, - 590E7D871D02091F00DF5523 /* libprotobuf.a */, - 59A3D0171CF4E86100C4259F /* UIKit.framework */, - 59A3D0151CF4E83D00C4259F /* Foundation.framework */, - 59A3D0131CF4E82500C4259F /* CoreGraphics.framework */, - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, - 59A3CFF31CF4E68100C4259F /* data */, - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, - 59A3CFFC1CF4E68100C4259F /* main.mm */, - 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, - 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, - 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, - 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, - 5911579C1CF4011C00C31E3A /* Products */, - ); - sourceTree = ""; - }; - 5911579C1CF4011C00C31E3A /* Products */ = { - isa = PBXGroup; - children = ( - 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */, - ); - name = Products; - sourceTree = ""; - }; - 59A3CFF31CF4E68100C4259F /* data */ = { - isa = PBXGroup; - children = ( - 59A3CFF41CF4E68100C4259F /* cropped_panda.jpg */, - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, - 59A3CFF61CF4E68100C4259F /* imagenet_2012_challenge_label_map_proto.pbtxt */, - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, - 59A3CFF81CF4E68100C4259F /* LICENSE */, - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 5911579A1CF4011C00C31E3A /* tf_ios_makefile_example */ = { - isa = PBXNativeTarget; - buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_ios_makefile_example" */; - buildPhases = ( - 591157971CF4011C00C31E3A /* Sources */, - 591157981CF4011C00C31E3A /* Frameworks */, - 591157991CF4011C00C31E3A /* Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = tf_ios_makefile_example; - productName = tf_ios_makefile_example; - productReference = 5911579B1CF4011C00C31E3A /* tf_ios_makefile_example.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 591157931CF4011C00C31E3A /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 0720; - ORGANIZATIONNAME = Google; - TargetAttributes = { - 5911579A1CF4011C00C31E3A = { - CreatedOnToolsVersion = 7.2; - DevelopmentTeam = 85Z3VXS37U; - }; - }; - }; - buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_ios_makefile_example" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 591157921CF4011C00C31E3A; - productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 5911579A1CF4011C00C31E3A /* tf_ios_makefile_example */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 591157991CF4011C00C31E3A /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */, - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */, - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 591157971CF4011C00C31E3A /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 59A3D0091CF4E68100C4259F /* main.mm in Sources */, - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, - 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - 591157B01CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - MTL_ENABLE_DEBUG_INFO = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - 591157B11CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - MTL_ENABLE_DEBUG_INFO = NO; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; - 591157B31CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - CLANG_DEBUG_INFORMATION_LEVEL = default; - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - GCC_ENABLE_CPP_EXCEPTIONS = YES; - GCC_ENABLE_CPP_RTTI = YES; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../../..", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - ); - INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - "$(SRCROOT)/../../makefile/gen/lib", - ); - OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - "-Xlinker", - "-S", - "-Xlinker", - "-x", - "-Xlinker", - "-dead_strip", - ); - PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SEPARATE_STRIP = NO; - }; - name = Debug; - }; - 591157B41CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - CLANG_DEBUG_INFORMATION_LEVEL = default; - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - GCC_ENABLE_CPP_EXCEPTIONS = YES; - GCC_ENABLE_CPP_RTTI = YES; - HEADER_SEARCH_PATHS = ( - "$(SRCROOT)/../../../..", - "$(SRCROOT)/../../makefile/downloads/protobuf/src/", - "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen", - ); - INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", - "$(SRCROOT)/../../makefile/gen/lib", - ); - ONLY_ACTIVE_ARCH = YES; - OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - OTHER_LDFLAGS = ( - "-force_load", - "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", - "-Xlinker", - "-S", - "-Xlinker", - "-x", - "-Xlinker", - "-dead_strip", - ); - PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SEPARATE_STRIP = NO; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_ios_makefile_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B01CF4011D00C31E3A /* Debug */, - 591157B11CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_ios_makefile_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B31CF4011D00C31E3A /* Debug */, - 591157B41CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 591157931CF4011C00C31E3A /* Project object */; -} diff --git a/tensorflow/examples/ios/.gitignore b/tensorflow/examples/ios/.gitignore new file mode 100644 index 0000000000..e572b3012c --- /dev/null +++ b/tensorflow/examples/ios/.gitignore @@ -0,0 +1,4 @@ +project.xcworkspace +xcuserdata +imagenet_comp_graph_label_strings.txt +tensorflow_inception_graph.pb diff --git a/tensorflow/examples/ios/README.md b/tensorflow/examples/ios/README.md new file mode 100644 index 0000000000..9832399d72 --- /dev/null +++ b/tensorflow/examples/ios/README.md @@ -0,0 +1,194 @@ +# TensorFlow iOS Examples + +This folder contains examples of how to build applications for iOS devices using TensorFlow. + +## Running the Samples using CocoaPod + - You'll need Xcode 7.3 or later. + + - There are currently three examples: simple, benchmark, and camera. For now, + you can download the sample code by cloning the main tensorflow repository + (we are planning to make the samples available as a separate repository + later). + + - From the root of the tensorflow folder, download + [Inception v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip), + and extract the label and graph files into the data folders inside both the + simple and camera examples: + +```bash +mkdir -p ~/graphs +curl -o ~/graphs/inception5h.zip \ + https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ + && unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h +cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/benchmark/data/ +cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/camera/data/ +cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/ +``` + + - Change directory to one of the samples, download the TensorFlow-experimental + pod, and open the Xcode workspace. Observe: installing the pod can take a + long time since it is big (~450MB). For example, if you want to run the + simple example, then: +```bash +cd tensorflow/contrib/ios_examples/simple +pod install +open tf_simple_example.xcworkspace # obs, not the .xcodeproj directory +``` + + - Run the simple app in the simulator. You should see a single-screen app with + a "Run Model" button. Tap that, and you should see some debug output appear + below indicating that the example Grace Hopper image in directory data has + been analyzed, with a military uniform recognized. + + - Run the other samples using the same process. The camera example requires a + real device connected. Once you build and run that, you should get a live + camera view that you can point at objects to get real-time recognition + results. + +### Troubleshooting + + - Make sure you use the TensorFlow-experimental pod (and not TensorFlow). + + - The TensorFlow-experimental pod is current about ~450MB. The reason it is + so big is because we are bundling multiple platforms, and the pod includes + all TensorFlow functionality (e.g. operations). This is convenient during + development, but see below section on how you can build your own custom + TensorFlow library to reduce the size. + +### Creating Your own App + + - Create your own app using Xcode then add a file named Podfile at the project + root directory with the following content: +```bash +target 'YourProjectName' + pod 'TensorFlow-experimental' +``` + + - Then you run ```pod install``` to download and install the + TensorFlow-experimental pod, and finaly perform + ```open YourProjectName.xcworkspace``` and add your code. + + - In your apps "Build Settings", make sure to add $(inherited) to sections + "Other Linker Flags", and "Header Search Paths". + + - That's it. If you want to create your custom TensorFlow iOS library, for + example to reduce binary footprint, see below section. + +## Building the TensorFlow iOS libraries from source + + - You'll need Xcode 7.3 or later, with the command-line tools installed. + + - Follow the instructions at + [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile) + under "iOS" to compile a static library containing the core TensorFlow code. + + - You should see a single-screen app with a "Run Model" button. Tap that, and + you should see some debug output appear below indicating that the example + Grace Hopper image has been analyzed, with a military uniform recognized. + + - Once you have success there, make sure you have a real device connected and + open up the Xcode project in the `camera` subfolder. Once you build and run + that, you should get a live camera view that you can point at objects to get + real-time recognition results. + +### Troubleshooting + +If you're hitting problems, here's a checklist of common things to investigate: + + - Make sure that you've run the `build_all_ios.sh` script. + This will run `download_dependencies.sh`,`compile_ios_protobuf.sh` and `compile_ios_tensorflow.sh`. + (check each one if they have run successful.) + + - Check that you have version 7.3 of Xcode. + + - If there's a complaint about no Sessions registered, that means that the C++ + global constructors that TensorFlow relies on for registration haven't been + linked in properly. You'll have to make sure your project uses force_load, as + described below. + +### Creating your Own App from your source libraries + +You'll need to update various settings in your app to link against +TensorFlow. You can view them in the example projects, but here's a full +rundown: + + - The `compile_ios_tensorflow.sh` script builds a universal static library in + `tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a`. You'll need to add + this to your linking build stage, and in Search Paths add + `tensorflow/contrib/makefile/gen/lib` to the Library Search Paths setting. + + - You'll also need to add `libprotobuf.a` and `libprotobuf-lite.a` from + `tensorflow/contrib/makefile/gen/protobuf_ios/lib` to your _Build Stages_ and + _Library Search Paths_. + + - The _Header Search_ paths needs to contain: + - the root folder of tensorflow, + - `tensorflow/contrib/makefile/downloads/protobuf/src` + - `tensorflow/contrib/makefile/downloads`, + - `tensorflow/contrib/makefile/downloads/eigen`, and + - `tensorflow/contrib/makefile/gen/proto`. + + - In the Linking section, you need to add `-force_load` followed by the path to + the TensorFlow static library in the _Other Linker_ Flags section. This ensures + that the global C++ objects that are used to register important classes + inside the library are not stripped out. To the linker, they can appear + unused because no other code references the variables, but in fact their + constructors have the important side effect of registering the class. + + - You'll need to include the Accelerate framework in the "Link Binary with + Libraries" build phase of your project. + + - C++11 support (or later) should be enabled by setting `C++ Language Dialect` to + `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`. + + - The library doesn't currently support bitcode, so you'll need to disable that + in your project settings. + + - Remove any use of the `-all_load` flag in your project. The protocol buffers + libraries (full and lite versions) contain duplicate symbols, and the `-all_load` + flag will cause these duplicates to become link errors. If you were using + `-all_load` to avoid issues with Objective-C categories in static libraries, + you may be able to replace it with the `-ObjC` flag. + +### Reducing the binary size + +TensorFlow is a comparatively large library for a mobile device, so it will +increase the size of your app. Currently on iOS we see around a 11 MB binary +footprint per CPU architecture, though we're actively working on reducing that. +It can be tricky to set up the right configuration in your own app to keep the +size minimized, so if you do run into this issue we recommend you start by +looking at the simple example to examine its size. Here's how you do that: + + - Open the Xcode project in tensorflow/contrib/ios_examples/simple. + + - Make sure you've followed the steps above to get the data files. + + - Choose "Generic iOS Device" as the build configuration. + + - Select Product->Build. + + - Once the build's complete, open the Report Navigator and select the logs. + + - Near the bottom, you'll see a line saying "Touch tf_simple_example.app". + + - Expand that line using the icon on the right, and copy the first argument to + the Touch command. + + - Go to the terminal, type `ls -lah ` and then paste the path you copied. + + - For example it might look like `ls -lah /Users/petewarden/Library/Developer/Xcode/DerivedData/tf_simple_example-etdbksqytcnzeyfgdwiihzkqpxwr/Build/Products/Debug-iphoneos/tf_simple_example.app` + + - Running this command will show the size of the executable as the + `tf_simple_example` line. + +Right now you'll see a size of around 23 MB, since it's including two +architectures (armv7 and arm64). As a first step, you should make sure the size +increase you see in your own app is similar, and if it's larger, look at the +"Other Linker Flags" used in the Simple Xcode project settings to strip the +executable. + +After that, you can manually look at modifying the list of kernels +included in tensorflow/contrib/makefile/tf_op_files.txt to reduce the number of +implementations to the ones you're actually using in your own model. We're +hoping to automate this step in the future, but for now manually removing them +is the best approach. diff --git a/tensorflow/examples/ios/benchmark/AppDelegate.h b/tensorflow/examples/ios/benchmark/AppDelegate.h new file mode 100644 index 0000000000..94046d9728 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/examples/ios/benchmark/AppDelegate.mm b/tensorflow/examples/ios/benchmark/AppDelegate.mm new file mode 100644 index 0000000000..23ffba0f7b --- /dev/null +++ b/tensorflow/examples/ios/benchmark/AppDelegate.mm @@ -0,0 +1,44 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +#import "BenchmarkViewController.h" + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + + UITabBarController *bar = [[UITabBarController alloc] init]; + [bar setViewControllers: + @[[[BenchmarkViewController alloc] init]]]; + bar.selectedIndex = 0; + self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; + self.window.rootViewController = bar; + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application {} + +- (void)applicationDidEnterBackground:(UIApplication *)application {} + +- (void)applicationWillEnterForeground:(UIApplication *)application {} + +- (void)applicationDidBecomeActive:(UIApplication *)application {} + +- (void)applicationWillTerminate:(UIApplication *)application {} + +@end diff --git a/tensorflow/examples/ios/benchmark/Benchmark-Info.plist b/tensorflow/examples/ios/benchmark/Benchmark-Info.plist new file mode 100644 index 0000000000..0cdbf28a31 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/Benchmark-Info.plist @@ -0,0 +1,47 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tf_benchmark_example + CFBundleExecutable + tf_benchmark_example + CFBundleIdentifier + com.google.tf_benchmark_example + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ios-app + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + UILaunchStoryboardName + BenchmarkViewController + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.h b/tensorflow/examples/ios/benchmark/BenchmarkViewController.h new file mode 100644 index 0000000000..c9cbc49280 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.h @@ -0,0 +1,24 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface BenchmarkViewController : UIViewController + +- (IBAction)getUrl:(id)sender; + +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; + +@end diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm new file mode 100644 index 0000000000..cab7b36f17 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm @@ -0,0 +1,302 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "BenchmarkViewController.h" + +#include +#include +#include +#include +#include +#include +#include + +//#include "google/protobuf/io/coded_stream.h" +//#include "google/protobuf/io/zero_copy_stream_impl.h" +//#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +//#include "google/protobuf/message_lite.h" +#include "tensorflow/core/framework/op_kernel.h" +//#include "tensorflow/core/framework/tensor.h" +//#include "tensorflow/core/framework/types.pb.h" +//#include "tensorflow/core/platform/env.h" +//#include "tensorflow/core/platform/logging.h" +//#include "tensorflow/core/platform/mutex.h" +//#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/stat_summarizer.h" + +#include "ios_image_load.h" + +NSString* RunInferenceOnImage(); + +namespace { +class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { + public: + explicit IfstreamInputStream(const std::string& file_name) + : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} + ~IfstreamInputStream() { ifs_.close(); } + + int Read(void* buffer, int size) { + if (!ifs_) { + return -1; + } + ifs_.read(static_cast(buffer), size); + return (int)ifs_.gcount(); + } + + private: + std::ifstream ifs_; +}; +} // namespace + +@interface BenchmarkViewController () +@end + +@implementation BenchmarkViewController { +} + +- (IBAction)getUrl:(id)sender { + NSString* inference_result = RunInferenceOnImage(); + self.urlContentTextView.text = inference_result; +} + +@end + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN( + const Eigen::TensorMap, + Eigen::Aligned>& prediction, + const int num_results, const float threshold, + std::vector>* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + long count = prediction.size(); + for (int i = 0; i < count; ++i) { + const float value = prediction(i); + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +bool PortableReadFileToProto(const std::string& file_name, + ::google::protobuf::MessageLite* proto) { + ::google::protobuf::io::CopyingInputStreamAdaptor stream( + new IfstreamInputStream(file_name)); + stream.SetOwnsCopyingStream(true); + // TODO(jiayq): the following coded stream is for debugging purposes to allow + // one to parse arbitrarily large messages for MessageLite. One most likely + // doesn't want to put protobufs larger than 64MB on Android, so we should + // eventually remove this and quit loud when a large protobuf is passed in. + ::google::protobuf::io::CodedInputStream coded_stream(&stream); + // Total bytes hard limit / warning limit are set to 1GB and 512MB + // respectively. + coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); + return proto->ParseFromCodedStream(&coded_stream); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = + [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." + << [extension UTF8String] << "' in bundle."; + } + return file_path; +} + +// A utility function to get the current time in seconds, for simple profiling. +double time() { + timeval t; + gettimeofday(&t, nullptr); + return t.tv_sec + 1e-6 * t.tv_usec; +} + +// Runs the session with profiling enabled, and prints out details of the time +// that each node in the graph takes to the debug log. +tensorflow::Status BenchmarkInference( + tensorflow::Session* session, + const std::vector> inputs, + const std::vector& output_layer_names, + std::vector* output_layers, + tensorflow::StatSummarizer* stat_summarizer, double* average_time) { + tensorflow::Status run_status; + const int iterations_count = 20; + double total_time = 0.0; + tensorflow::RunOptions run_options; + run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); + tensorflow::RunMetadata run_metadata; + for (int iteration = 0; iteration < (iterations_count + 1); ++iteration) { + const double start_time = time(); + run_status = session->Run(run_options, inputs, output_layer_names, {}, + output_layers, &run_metadata); + const double end_time = time(); + if (iteration != 0) { + total_time += end_time - start_time; + } + if (!run_status.ok()) { + LOG(ERROR) << "Running model failed: " << run_status; + tensorflow::LogAllRegisteredKernels(); + return run_status; + } + } + assert(run_metadata.has_step_stats()); + const tensorflow::StepStats& step_stats = run_metadata.step_stats(); + stat_summarizer->ProcessStepStats(step_stats); + stat_summarizer->PrintStepStats(); + + *average_time = total_time / iterations_count; + NSLog(@"Took %f seconds", *average_time); + + return tensorflow::Status::OK(); +} + +NSString* RunInferenceOnImage() { + tensorflow::SessionOptions options; + + tensorflow::Session* session_pointer = nullptr; + tensorflow::Status session_status = + tensorflow::NewSession(options, &session_pointer); + if (!session_status.ok()) { + std::string status_string = session_status.ToString(); + return [NSString + stringWithFormat:@"Session create failed - %s", status_string.c_str()]; + } + std::unique_ptr session(session_pointer); + LOG(INFO) << "Session created."; + + tensorflow::GraphDef tensorflow_graph; + LOG(INFO) << "Graph created."; + + NSString* network_path = + FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); + PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); + + LOG(INFO) << "Creating session."; + tensorflow::Status s = session->Create(tensorflow_graph); + if (!s.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << s; + return @""; + } + + // Read the label list + NSString* labels_path = + FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); + std::vector label_strings; + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings.push_back(line); + } + t.close(); + + // Read the Grace Hopper image. + NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); + int image_width; + int image_height; + int image_channels; + std::vector image_data = LoadImageFromFile( + [image_path UTF8String], &image_width, &image_height, &image_channels); + const int wanted_width = 224; + const int wanted_height = 224; + const int wanted_channels = 3; + const float input_mean = 117.0f; + const float input_std = 1.0f; + assert(image_channels >= wanted_channels); + tensorflow::Tensor image_tensor( + tensorflow::DT_FLOAT, + tensorflow::TensorShape( + {1, wanted_height, wanted_width, wanted_channels})); + auto image_tensor_mapped = image_tensor.tensor(); + tensorflow::uint8* in = image_data.data(); + float* out = image_tensor_mapped.data(); + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); + float* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); + float* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + tensorflow::string input_layer = "input"; + tensorflow::string output_layer = "output"; + std::vector outputs; + tensorflow::StatSummarizer stat_summarizer(tensorflow_graph); + double average_time = 0.0; + BenchmarkInference(session.get(), {{input_layer, image_tensor}}, + {output_layer}, &outputs, &stat_summarizer, &average_time); + NSString* result = + [NSString stringWithFormat:@"Average time: %.4f seconds \n\n", average_time]; + + tensorflow::Tensor* output = &outputs[0]; + const int kNumResults = 5; + const float kThreshold = 0.1f; + std::vector> top_results; + GetTopN(output->flat(), kNumResults, kThreshold, &top_results); + + std::stringstream ss; + ss.precision(3); + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + + ss << index << " " << confidence << " "; + + // Write out the result as a string + if (index < label_strings.size()) { + // just for safety: theoretically, the output is under 1000 unless there + // is some numerical issues leading to a wrong prediction. + ss << label_strings[index]; + } else { + ss << "Prediction: " << index; + } + + ss << "\n"; + } + + LOG(INFO) << "Predictions: " << ss.str(); + + tensorflow::string predictions = ss.str(); + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + + return result; +} diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib b/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib new file mode 100644 index 0000000000..56c3708062 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib @@ -0,0 +1,47 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/examples/ios/benchmark/Podfile b/tensorflow/examples/ios/benchmark/Podfile new file mode 100644 index 0000000000..e163d56e8d --- /dev/null +++ b/tensorflow/examples/ios/benchmark/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tf_benchmark_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg b/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg new file mode 100644 index 0000000000..d2a427810f Binary files /dev/null and b/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg differ diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.h b/tensorflow/examples/ios/benchmark/ios_image_load.h new file mode 100644 index 0000000000..78eaded8d7 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/ios_image_load.h @@ -0,0 +1,27 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ + +#include + +#include "tensorflow/core/framework/types.h" + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, + int* out_height, + int* out_channels); + +#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.mm b/tensorflow/examples/ios/benchmark/ios_image_load.mm new file mode 100644 index 0000000000..64d1ea21cf --- /dev/null +++ b/tensorflow/examples/ios/benchmark/ios_image_load.mm @@ -0,0 +1,87 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +using tensorflow::uint8; + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), + bytes_in_file, + kCFAllocatorNull); + CGDataProviderRef image_provider = + CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || + (strcasecmp(suffix, ".jpeg") == 0)) { + image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + CGContextRef context = CGBitmapContextCreate(result.data(), width, height, + bits_per_component, bytes_per_row, color_space, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/examples/ios/benchmark/main.mm b/tensorflow/examples/ios/benchmark/main.mm new file mode 100644 index 0000000000..d70550a730 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/main.mm @@ -0,0 +1,22 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +int main(int argc, char * argv[]) { + @autoreleasepool { + NSString *delegateClassName = @"AppDelegate"; + return UIApplicationMain(argc, argv, nil, delegateClassName); + } +} diff --git a/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000..d61b65ba61 --- /dev/null +++ b/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj @@ -0,0 +1,388 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; + 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; + 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */; }; + 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; + 1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */; }; + 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; + 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; + 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */; }; + 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */; }; + 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_benchmark_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreFoundation.framework; path = System/Library/Frameworks/CoreFoundation.framework; sourceTree = SDKROOT; }; + 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "Benchmark-Info.plist"; sourceTree = ""; }; + 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BenchmarkViewController.h; sourceTree = ""; }; + 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; + 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = BenchmarkViewController.xib; sourceTree = ""; }; + 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.debug.xcconfig"; sourceTree = ""; }; + 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_benchmark_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.release.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 1C8BA9011EC682E700CCCC8C /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */, + 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */, + 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 2BD56010B574F539C2070A57 /* Pods */ = { + isa = PBXGroup; + children = ( + 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */, + DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, + 59A3CFFC1CF4E68100C4259F /* main.mm */, + 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */, + 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */, + 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */, + 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */, + 5911579C1CF4011C00C31E3A /* Products */, + 2BD56010B574F539C2070A57 /* Pods */, + 76A25A27041EB307BDFF0DD1 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, + ); + path = data; + sourceTree = ""; + }; + 76A25A27041EB307BDFF0DD1 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */, + 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */, + 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */, + 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */; + buildPhases = ( + 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */, + 1C8BA8FC1EC682E700CCCC8C /* Sources */, + 1C8BA9011EC682E700CCCC8C /* Frameworks */, + 1C8BA9041EC682E700CCCC8C /* Resources */, + 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */, + A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_benchmark_example; + productName = benchmark; + productReference = 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 1C8BA9041EC682E700CCCC8C /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */, + 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */, + 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */, + 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; + showEnvVarsInLog = 0; + }; + 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 1C8BA8FC1EC682E700CCCC8C /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */, + 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */, + 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */, + 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 1C8BA90A1EC682E700CCCC8C /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */; + buildSettings = { + CODE_SIGN_IDENTITY = "iPhone Developer"; + ENABLE_BITCODE = NO; + HEADER_SEARCH_PATHS = "$(inherited)"; + INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ""; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + }; + name = Debug; + }; + 1C8BA90B1EC682E700CCCC8C /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */; + buildSettings = { + CODE_SIGN_IDENTITY = "iPhone Developer"; + ENABLE_BITCODE = NO; + HEADER_SEARCH_PATHS = "$(inherited)"; + INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ""; + ONLY_ACTIVE_ARCH = YES; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + }; + name = Release; + }; + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 1C8BA90A1EC682E700CCCC8C /* Debug */, + 1C8BA90B1EC682E700CCCC8C /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h new file mode 100644 index 0000000000..0039d5e7ca --- /dev/null +++ b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface CameraExampleAppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m new file mode 100644 index 0000000000..d134c2b591 --- /dev/null +++ b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m @@ -0,0 +1,44 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CameraExampleAppDelegate.h" + +@implementation CameraExampleAppDelegate + +@synthesize window = _window; + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/examples/ios/camera/CameraExampleViewController.h b/tensorflow/examples/ios/camera/CameraExampleViewController.h new file mode 100644 index 0000000000..0aefbc6eed --- /dev/null +++ b/tensorflow/examples/ios/camera/CameraExampleViewController.h @@ -0,0 +1,47 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/memmapped_file_system.h" + +@interface CameraExampleViewController + : UIViewController { + IBOutlet UIView *previewView; + IBOutlet UISegmentedControl *camerasControl; + AVCaptureVideoPreviewLayer *previewLayer; + AVCaptureVideoDataOutput *videoDataOutput; + dispatch_queue_t videoDataOutputQueue; + AVCaptureStillImageOutput *stillImageOutput; + UIView *flashView; + UIImage *square; + BOOL isUsingFrontFacingCamera; + AVSpeechSynthesizer *synth; + NSMutableDictionary *oldPredictionValues; + NSMutableArray *labelLayers; + AVCaptureSession *session; + std::unique_ptr tf_session; + std::unique_ptr tf_memmapped_env; + std::vector labels; +} +@property(strong, nonatomic) CATextLayer *predictionTextLayer; + +- (IBAction)takePicture:(id)sender; +- (IBAction)switchCameras:(id)sender; + +@end diff --git a/tensorflow/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/examples/ios/camera/CameraExampleViewController.mm new file mode 100644 index 0000000000..d113d50ff8 --- /dev/null +++ b/tensorflow/examples/ios/camera/CameraExampleViewController.mm @@ -0,0 +1,621 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import +#import +#import +#import "CameraExampleViewController.h" + +#include + +#include "tensorflow_utils.h" + +// If you have your own model, modify this to the file name, and make sure +// you've added the file to your app resources too. +static NSString* model_file_name = @"tensorflow_inception_graph"; +static NSString* model_file_type = @"pb"; +// This controls whether we'll be loading a plain GraphDef proto, or a +// file created by the convert_graphdef_memmapped_format utility that wraps a +// GraphDef and parameter file that can be mapped into memory from file to +// reduce overall memory usage. +const bool model_uses_memory_mapping = false; +// If you have your own model, point this to the labels file. +static NSString* labels_file_name = @"imagenet_comp_graph_label_strings"; +static NSString* labels_file_type = @"txt"; +// These dimensions need to match those the model was trained with. +const int wanted_input_width = 224; +const int wanted_input_height = 224; +const int wanted_input_channels = 3; +const float input_mean = 117.0f; +const float input_std = 1.0f; +const std::string input_layer_name = "input"; +const std::string output_layer_name = "softmax1"; + +static void *AVCaptureStillImageIsCapturingStillImageContext = + &AVCaptureStillImageIsCapturingStillImageContext; + +@interface CameraExampleViewController (InternalMethods) +- (void)setupAVCapture; +- (void)teardownAVCapture; +@end + +@implementation CameraExampleViewController + +- (void)setupAVCapture { + NSError *error = nil; + + session = [AVCaptureSession new]; + if ([[UIDevice currentDevice] userInterfaceIdiom] == + UIUserInterfaceIdiomPhone) + [session setSessionPreset:AVCaptureSessionPreset640x480]; + else + [session setSessionPreset:AVCaptureSessionPresetPhoto]; + + AVCaptureDevice *device = + [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; + AVCaptureDeviceInput *deviceInput = + [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; + assert(error == nil); + + isUsingFrontFacingCamera = NO; + if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; + + stillImageOutput = [AVCaptureStillImageOutput new]; + [stillImageOutput + addObserver:self + forKeyPath:@"capturingStillImage" + options:NSKeyValueObservingOptionNew + context:(void *)(AVCaptureStillImageIsCapturingStillImageContext)]; + if ([session canAddOutput:stillImageOutput]) + [session addOutput:stillImageOutput]; + + videoDataOutput = [AVCaptureVideoDataOutput new]; + + NSDictionary *rgbOutputSettings = [NSDictionary + dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] + forKey:(id)kCVPixelBufferPixelFormatTypeKey]; + [videoDataOutput setVideoSettings:rgbOutputSettings]; + [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; + videoDataOutputQueue = + dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); + [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; + + if ([session canAddOutput:videoDataOutput]) + [session addOutput:videoDataOutput]; + [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; + + previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; + [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; + [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; + CALayer *rootLayer = [previewView layer]; + [rootLayer setMasksToBounds:YES]; + [previewLayer setFrame:[rootLayer bounds]]; + [rootLayer addSublayer:previewLayer]; + [session startRunning]; + + if (error) { + NSString *title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; + UIAlertController *alertController = + [UIAlertController alertControllerWithTitle:title + message:[error localizedDescription] + preferredStyle:UIAlertControllerStyleAlert]; + UIAlertAction *dismiss = + [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; + [alertController addAction:dismiss]; + [self presentViewController:alertController animated:YES completion:nil]; + [self teardownAVCapture]; + } +} + +- (void)teardownAVCapture { + [stillImageOutput removeObserver:self forKeyPath:@"isCapturingStillImage"]; + [previewLayer removeFromSuperlayer]; +} + +- (void)observeValueForKeyPath:(NSString *)keyPath + ofObject:(id)object + change:(NSDictionary *)change + context:(void *)context { + if (context == AVCaptureStillImageIsCapturingStillImageContext) { + BOOL isCapturingStillImage = + [[change objectForKey:NSKeyValueChangeNewKey] boolValue]; + + if (isCapturingStillImage) { + // do flash bulb like animation + flashView = [[UIView alloc] initWithFrame:[previewView frame]]; + [flashView setBackgroundColor:[UIColor whiteColor]]; + [flashView setAlpha:0.f]; + [[[self view] window] addSubview:flashView]; + + [UIView animateWithDuration:.4f + animations:^{ + [flashView setAlpha:1.f]; + }]; + } else { + [UIView animateWithDuration:.4f + animations:^{ + [flashView setAlpha:0.f]; + } + completion:^(BOOL finished) { + [flashView removeFromSuperview]; + flashView = nil; + }]; + } + } +} + +- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: + (UIDeviceOrientation)deviceOrientation { + AVCaptureVideoOrientation result = + (AVCaptureVideoOrientation)(deviceOrientation); + if (deviceOrientation == UIDeviceOrientationLandscapeLeft) + result = AVCaptureVideoOrientationLandscapeRight; + else if (deviceOrientation == UIDeviceOrientationLandscapeRight) + result = AVCaptureVideoOrientationLandscapeLeft; + return result; +} + +- (IBAction)takePicture:(id)sender { + if ([session isRunning]) { + [session stopRunning]; + [sender setTitle:@"Continue" forState:UIControlStateNormal]; + + flashView = [[UIView alloc] initWithFrame:[previewView frame]]; + [flashView setBackgroundColor:[UIColor whiteColor]]; + [flashView setAlpha:0.f]; + [[[self view] window] addSubview:flashView]; + + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:1.f]; + } + completion:^(BOOL finished) { + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:0.f]; + } + completion:^(BOOL finished) { + [flashView removeFromSuperview]; + flashView = nil; + }]; + }]; + + } else { + [session startRunning]; + [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; + } +} + ++ (CGRect)videoPreviewBoxForGravity:(NSString *)gravity + frameSize:(CGSize)frameSize + apertureSize:(CGSize)apertureSize { + CGFloat apertureRatio = apertureSize.height / apertureSize.width; + CGFloat viewRatio = frameSize.width / frameSize.height; + + CGSize size = CGSizeZero; + if ([gravity isEqualToString:AVLayerVideoGravityResizeAspectFill]) { + if (viewRatio > apertureRatio) { + size.width = frameSize.width; + size.height = + apertureSize.width * (frameSize.width / apertureSize.height); + } else { + size.width = + apertureSize.height * (frameSize.height / apertureSize.width); + size.height = frameSize.height; + } + } else if ([gravity isEqualToString:AVLayerVideoGravityResizeAspect]) { + if (viewRatio > apertureRatio) { + size.width = + apertureSize.height * (frameSize.height / apertureSize.width); + size.height = frameSize.height; + } else { + size.width = frameSize.width; + size.height = + apertureSize.width * (frameSize.width / apertureSize.height); + } + } else if ([gravity isEqualToString:AVLayerVideoGravityResize]) { + size.width = frameSize.width; + size.height = frameSize.height; + } + + CGRect videoBox; + videoBox.size = size; + if (size.width < frameSize.width) + videoBox.origin.x = (frameSize.width - size.width) / 2; + else + videoBox.origin.x = (size.width - frameSize.width) / 2; + + if (size.height < frameSize.height) + videoBox.origin.y = (frameSize.height - size.height) / 2; + else + videoBox.origin.y = (size.height - frameSize.height) / 2; + + return videoBox; +} + +- (void)captureOutput:(AVCaptureOutput *)captureOutput +didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer + fromConnection:(AVCaptureConnection *)connection { + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CFRetain(pixelBuffer); + [self runCNNOnFrame:pixelBuffer]; + CFRelease(pixelBuffer); +} + +- (void)runCNNOnFrame:(CVPixelBufferRef)pixelBuffer { + assert(pixelBuffer != NULL); + + OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + int doReverseChannels; + if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { + doReverseChannels = 1; + } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { + doReverseChannels = 0; + } else { + assert(false); // Unknown source format + } + + const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); + const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); + const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer); + + CVPixelBufferLockFlags unlockFlags = kNilOptions; + CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); + + unsigned char *sourceBaseAddr = + (unsigned char *)(CVPixelBufferGetBaseAddress(pixelBuffer)); + int image_height; + unsigned char *sourceStartAddr; + if (fullHeight <= image_width) { + image_height = fullHeight; + sourceStartAddr = sourceBaseAddr; + } else { + image_height = image_width; + const int marginY = ((fullHeight - image_width) / 2); + sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); + } + const int image_channels = 4; + + assert(image_channels >= wanted_input_channels); + tensorflow::Tensor image_tensor( + tensorflow::DT_FLOAT, + tensorflow::TensorShape( + {1, wanted_input_height, wanted_input_width, wanted_input_channels})); + auto image_tensor_mapped = image_tensor.tensor(); + tensorflow::uint8 *in = sourceStartAddr; + float *out = image_tensor_mapped.data(); + for (int y = 0; y < wanted_input_height; ++y) { + float *out_row = out + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + tensorflow::uint8 *in_pixel = + in + (in_y * image_width * image_channels) + (in_x * image_channels); + float *out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + + CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); + + if (tf_session.get()) { + std::vector outputs; + tensorflow::Status run_status = tf_session->Run( + {{input_layer_name, image_tensor}}, {output_layer_name}, {}, &outputs); + if (!run_status.ok()) { + LOG(ERROR) << "Running model failed:" << run_status; + } else { + tensorflow::Tensor *output = &outputs[0]; + auto predictions = output->flat(); + + NSMutableDictionary *newValues = [NSMutableDictionary dictionary]; + for (int index = 0; index < predictions.size(); index += 1) { + const float predictionValue = predictions(index); + if (predictionValue > 0.05f) { + std::string label = labels[index % predictions.size()]; + NSString *labelObject = [NSString stringWithUTF8String:label.c_str()]; + NSNumber *valueObject = [NSNumber numberWithFloat:predictionValue]; + [newValues setObject:valueObject forKey:labelObject]; + } + } + dispatch_async(dispatch_get_main_queue(), ^(void) { + [self setPredictionValues:newValues]; + }); + } + } + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +} + +- (void)dealloc { + [self teardownAVCapture]; +} + +// use front/back camera +- (IBAction)switchCameras:(id)sender { + AVCaptureDevicePosition desiredPosition; + if (isUsingFrontFacingCamera) + desiredPosition = AVCaptureDevicePositionBack; + else + desiredPosition = AVCaptureDevicePositionFront; + + for (AVCaptureDevice *d in + [AVCaptureDevice devicesWithMediaType:AVMediaTypeVideo]) { + if ([d position] == desiredPosition) { + [[previewLayer session] beginConfiguration]; + AVCaptureDeviceInput *input = + [AVCaptureDeviceInput deviceInputWithDevice:d error:nil]; + for (AVCaptureInput *oldInput in [[previewLayer session] inputs]) { + [[previewLayer session] removeInput:oldInput]; + } + [[previewLayer session] addInput:input]; + [[previewLayer session] commitConfiguration]; + break; + } + } + isUsingFrontFacingCamera = !isUsingFrontFacingCamera; +} + +- (void)didReceiveMemoryWarning { + [super didReceiveMemoryWarning]; +} + +- (void)viewDidLoad { + [super viewDidLoad]; + square = [UIImage imageNamed:@"squarePNG"]; + synth = [[AVSpeechSynthesizer alloc] init]; + labelLayers = [[NSMutableArray alloc] init]; + oldPredictionValues = [[NSMutableDictionary alloc] init]; + + tensorflow::Status load_status; + if (model_uses_memory_mapping) { + load_status = LoadMemoryMappedModel( + model_file_name, model_file_type, &tf_session, &tf_memmapped_env); + } else { + load_status = LoadModel(model_file_name, model_file_type, &tf_session); + } + if (!load_status.ok()) { + LOG(FATAL) << "Couldn't load model: " << load_status; + } + + tensorflow::Status labels_status = + LoadLabels(labels_file_name, labels_file_type, &labels); + if (!labels_status.ok()) { + LOG(FATAL) << "Couldn't load labels: " << labels_status; + } + [self setupAVCapture]; +} + +- (void)viewDidUnload { + [super viewDidUnload]; +} + +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; +} + +- (void)viewDidAppear:(BOOL)animated { + [super viewDidAppear:animated]; +} + +- (void)viewWillDisappear:(BOOL)animated { + [super viewWillDisappear:animated]; +} + +- (void)viewDidDisappear:(BOOL)animated { + [super viewDidDisappear:animated]; +} + +- (BOOL)shouldAutorotateToInterfaceOrientation: + (UIInterfaceOrientation)interfaceOrientation { + return (interfaceOrientation == UIInterfaceOrientationPortrait); +} + +- (BOOL)prefersStatusBarHidden { + return YES; +} + +- (void)setPredictionValues:(NSDictionary *)newValues { + const float decayValue = 0.75f; + const float updateValue = 0.25f; + const float minimumThreshold = 0.01f; + + NSMutableDictionary *decayedPredictionValues = + [[NSMutableDictionary alloc] init]; + for (NSString *label in oldPredictionValues) { + NSNumber *oldPredictionValueObject = + [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float decayedPredictionValue = (oldPredictionValue * decayValue); + if (decayedPredictionValue > minimumThreshold) { + NSNumber *decayedPredictionValueObject = + [NSNumber numberWithFloat:decayedPredictionValue]; + [decayedPredictionValues setObject:decayedPredictionValueObject + forKey:label]; + } + } + oldPredictionValues = decayedPredictionValues; + + for (NSString *label in newValues) { + NSNumber *newPredictionValueObject = [newValues objectForKey:label]; + NSNumber *oldPredictionValueObject = + [oldPredictionValues objectForKey:label]; + if (!oldPredictionValueObject) { + oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; + } + const float newPredictionValue = [newPredictionValueObject floatValue]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float updatedPredictionValue = + (oldPredictionValue + (newPredictionValue * updateValue)); + NSNumber *updatedPredictionValueObject = + [NSNumber numberWithFloat:updatedPredictionValue]; + [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; + } + NSArray *candidateLabels = [NSMutableArray array]; + for (NSString *label in oldPredictionValues) { + NSNumber *oldPredictionValueObject = + [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + if (oldPredictionValue > 0.05f) { + NSDictionary *entry = @{ + @"label" : label, + @"value" : oldPredictionValueObject + }; + candidateLabels = [candidateLabels arrayByAddingObject:entry]; + } + } + NSSortDescriptor *sort = + [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; + NSArray *sortedLabels = [candidateLabels + sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; + + const float leftMargin = 10.0f; + const float topMargin = 10.0f; + + const float valueWidth = 48.0f; + const float valueHeight = 26.0f; + + const float labelWidth = 246.0f; + const float labelHeight = 26.0f; + + const float labelMarginX = 5.0f; + const float labelMarginY = 5.0f; + + [self removeAllLabelLayers]; + + int labelCount = 0; + for (NSDictionary *entry in sortedLabels) { + NSString *label = [entry objectForKey:@"label"]; + NSNumber *valueObject = [entry objectForKey:@"value"]; + const float value = [valueObject floatValue]; + + const float originY = + (topMargin + ((labelHeight + labelMarginY) * labelCount)); + + const int valuePercentage = (int)roundf(value * 100.0f); + + const float valueOriginX = leftMargin; + NSString *valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; + + [self addLabelLayerWithText:valueText + originX:valueOriginX + originY:originY + width:valueWidth + height:valueHeight + alignment:kCAAlignmentRight]; + + const float labelOriginX = (leftMargin + valueWidth + labelMarginX); + + [self addLabelLayerWithText:[label capitalizedString] + originX:labelOriginX + originY:originY + width:labelWidth + height:labelHeight + alignment:kCAAlignmentLeft]; + + if ((labelCount == 0) && (value > 0.5f)) { + [self speak:[label capitalizedString]]; + } + + labelCount += 1; + if (labelCount > 4) { + break; + } + } +} + +- (void)removeAllLabelLayers { + for (CATextLayer *layer in labelLayers) { + [layer removeFromSuperlayer]; + } + [labelLayers removeAllObjects]; +} + +- (void)addLabelLayerWithText:(NSString *)text + originX:(float)originX + originY:(float)originY + width:(float)width + height:(float)height + alignment:(NSString *)alignment { + CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; + const float fontSize = 20.0f; + + const float marginSizeX = 5.0f; + const float marginSizeY = 2.0f; + + const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); + + const CGRect textBounds = + CGRectMake((originX + marginSizeX), (originY + marginSizeY), + (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); + + CATextLayer *background = [CATextLayer layer]; + [background setBackgroundColor:[UIColor blackColor].CGColor]; + [background setOpacity:0.5f]; + [background setFrame:backgroundBounds]; + background.cornerRadius = 5.0f; + + [[self.view layer] addSublayer:background]; + [labelLayers addObject:background]; + + CATextLayer *layer = [CATextLayer layer]; + [layer setForegroundColor:[UIColor whiteColor].CGColor]; + [layer setFrame:textBounds]; + [layer setAlignmentMode:alignment]; + [layer setWrapped:YES]; + [layer setFont:font]; + [layer setFontSize:fontSize]; + layer.contentsScale = [[UIScreen mainScreen] scale]; + [layer setString:text]; + + [[self.view layer] addSublayer:layer]; + [labelLayers addObject:layer]; +} + +- (void)setPredictionText:(NSString *)text withDuration:(float)duration { + if (duration > 0.0) { + CABasicAnimation *colorAnimation = + [CABasicAnimation animationWithKeyPath:@"foregroundColor"]; + colorAnimation.duration = duration; + colorAnimation.fillMode = kCAFillModeForwards; + colorAnimation.removedOnCompletion = NO; + colorAnimation.fromValue = (id)[UIColor darkGrayColor].CGColor; + colorAnimation.toValue = (id)[UIColor whiteColor].CGColor; + colorAnimation.timingFunction = + [CAMediaTimingFunction functionWithName:kCAMediaTimingFunctionLinear]; + [self.predictionTextLayer addAnimation:colorAnimation + forKey:@"colorAnimation"]; + } else { + self.predictionTextLayer.foregroundColor = [UIColor whiteColor].CGColor; + } + + [self.predictionTextLayer removeFromSuperlayer]; + [[self.view layer] addSublayer:self.predictionTextLayer]; + [self.predictionTextLayer setString:text]; +} + +- (void)speak:(NSString *)words { + if ([synth isSpeaking]) { + return; + } + AVSpeechUtterance *utterance = + [AVSpeechUtterance speechUtteranceWithString:words]; + utterance.voice = [AVSpeechSynthesisVoice voiceWithLanguage:@"en-US"]; + utterance.rate = 0.75 * AVSpeechUtteranceDefaultSpeechRate; + [synth speakUtterance:utterance]; +} + +@end diff --git a/tensorflow/examples/ios/camera/Info.plist b/tensorflow/examples/ios/camera/Info.plist new file mode 100644 index 0000000000..772fb38dcc --- /dev/null +++ b/tensorflow/examples/ios/camera/Info.plist @@ -0,0 +1,44 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tf_camera_example + CFBundleExecutable + ${EXECUTABLE_NAME} + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ${PRODUCT_NAME} + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + NSCameraUsageDescription + Capture images to detect object + UIMainStoryboardFile + MainStoryboard_iPhone + UIRequiresFullScreen + + UIStatusBarHidden + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard new file mode 100644 index 0000000000..0f10a22e41 --- /dev/null +++ b/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/examples/ios/camera/Podfile b/tensorflow/examples/ios/camera/Podfile new file mode 100644 index 0000000000..117828f071 --- /dev/null +++ b/tensorflow/examples/ios/camera/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tf_camera_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/camera/data/grace_hopper.jpg b/tensorflow/examples/ios/camera/data/grace_hopper.jpg new file mode 100644 index 0000000000..d2a427810f Binary files /dev/null and b/tensorflow/examples/ios/camera/data/grace_hopper.jpg differ diff --git a/tensorflow/examples/ios/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h new file mode 100644 index 0000000000..87a847e145 --- /dev/null +++ b/tensorflow/examples/ios/camera/ios_image_load.h @@ -0,0 +1,27 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ +#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ + +#include + +#include "tensorflow/core/framework/types.h" + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, + int* out_height, + int* out_channels); + +#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/camera/ios_image_load.mm b/tensorflow/examples/ios/camera/ios_image_load.mm new file mode 100644 index 0000000000..64d1ea21cf --- /dev/null +++ b/tensorflow/examples/ios/camera/ios_image_load.mm @@ -0,0 +1,87 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +using tensorflow::uint8; + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), + bytes_in_file, + kCFAllocatorNull); + CGDataProviderRef image_provider = + CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || + (strcasecmp(suffix, ".jpeg") == 0)) { + image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + CGContextRef context = CGBitmapContextCreate(result.data(), width, height, + bits_per_component, bytes_per_row, color_space, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/examples/ios/camera/main.mm b/tensorflow/examples/ios/camera/main.mm new file mode 100644 index 0000000000..42eff697ef --- /dev/null +++ b/tensorflow/examples/ios/camera/main.mm @@ -0,0 +1,27 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "CameraExampleAppDelegate.h" + +int main(int argc, char *argv[]) { + int retVal = 0; + + @autoreleasepool { + retVal = UIApplicationMain( + argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); + } + return retVal; +} diff --git a/tensorflow/examples/ios/camera/tensorflow_utils.h b/tensorflow/examples/ios/camera/tensorflow_utils.h new file mode 100644 index 0000000000..78bdb82aae --- /dev/null +++ b/tensorflow/examples/ios/camera/tensorflow_utils.h @@ -0,0 +1,52 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ +#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ + +#include +#include + +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/memmapped_file_system.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// Reads a serialized GraphDef protobuf file from the bundle, typically +// created with the freeze_graph script. Populates the session argument with a +// Session object that has the model loaded. +tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, + std::unique_ptr* session); + +// Loads a model from a file that has been created using the +// convert_graphdef_memmapped_format tool. This bundles together a GraphDef +// proto together with a file that can be memory-mapped, containing the weight +// parameters for the model. This is useful because it reduces the overall +// memory pressure, since the read-only parameter regions can be easily paged +// out and don't count toward memory limits on iOS. +tensorflow::Status LoadMemoryMappedModel( + NSString* file_name, NSString* file_type, + std::unique_ptr* session, + std::unique_ptr* memmapped_env); + +// Takes a text file with a single label on each line, and returns a list. +tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, + std::vector* label_strings); + +// Sorts the results from a model execution, and returns the highest scoring. +void GetTopN(const Eigen::TensorMap, + Eigen::Aligned>& prediction, + const int num_results, const float threshold, + std::vector >* top_results); + +#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ diff --git a/tensorflow/examples/ios/camera/tensorflow_utils.mm b/tensorflow/examples/ios/camera/tensorflow_utils.mm new file mode 100644 index 0000000000..56d1e53081 --- /dev/null +++ b/tensorflow/examples/ios/camera/tensorflow_utils.mm @@ -0,0 +1,219 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#include "tensorflow_utils.h" + +#include +#include +#include +#include +#include +#include + +namespace { + +// Helper class used to load protobufs efficiently. +class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { + public: + explicit IfstreamInputStream(const std::string& file_name) + : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} + ~IfstreamInputStream() { ifs_.close(); } + + int Read(void* buffer, int size) { + if (!ifs_) { + return -1; + } + ifs_.read(static_cast(buffer), size); + return ifs_.gcount(); + } + + private: + std::ifstream ifs_; +}; +} // namespace + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +void GetTopN(const Eigen::TensorMap, + Eigen::Aligned>& prediction, + const int num_results, const float threshold, + std::vector >* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, + std::vector >, + std::greater > > + top_result_pq; + + const int count = prediction.size(); + for (int i = 0; i < count; ++i) { + const float value = prediction(i); + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +bool PortableReadFileToProto(const std::string& file_name, + ::google::protobuf::MessageLite* proto) { + ::google::protobuf::io::CopyingInputStreamAdaptor stream( + new IfstreamInputStream(file_name)); + stream.SetOwnsCopyingStream(true); + ::google::protobuf::io::CodedInputStream coded_stream(&stream); + // Total bytes hard limit / warning limit are set to 1GB and 512MB + // respectively. + coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); + return proto->ParseFromCodedStream(&coded_stream); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = + [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." + << [extension UTF8String] << "' in bundle."; + return nullptr; + } + return file_path; +} + +tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, + std::unique_ptr* session) { + tensorflow::SessionOptions options; + + tensorflow::Session* session_pointer = nullptr; + tensorflow::Status session_status = + tensorflow::NewSession(options, &session_pointer); + if (!session_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; + return session_status; + } + session->reset(session_pointer); + + tensorflow::GraphDef tensorflow_graph; + + NSString* model_path = FilePathForResourceName(file_name, file_type); + if (!model_path) { + LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] + << [file_type UTF8String]; + return tensorflow::errors::NotFound([file_name UTF8String], + [file_type UTF8String]); + } + const bool read_proto_succeeded = + PortableReadFileToProto([model_path UTF8String], &tensorflow_graph); + if (!read_proto_succeeded) { + LOG(ERROR) << "Failed to load model proto from" << [model_path UTF8String]; + return tensorflow::errors::NotFound([model_path UTF8String]); + } + + tensorflow::Status create_status = (*session)->Create(tensorflow_graph); + if (!create_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; + return create_status; + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status LoadMemoryMappedModel( + NSString* file_name, NSString* file_type, + std::unique_ptr* session, + std::unique_ptr* memmapped_env) { + NSString* network_path = FilePathForResourceName(file_name, file_type); + memmapped_env->reset( + new tensorflow::MemmappedEnv(tensorflow::Env::Default())); + tensorflow::Status mmap_status = + (memmapped_env->get())->InitializeFromFile([network_path UTF8String]); + if (!mmap_status.ok()) { + LOG(ERROR) << "MMap failed with " << mmap_status.error_message(); + return mmap_status; + } + + tensorflow::GraphDef tensorflow_graph; + tensorflow::Status load_graph_status = ReadBinaryProto( + memmapped_env->get(), + tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &tensorflow_graph); + if (!load_graph_status.ok()) { + LOG(ERROR) << "MMap load graph failed with " + << load_graph_status.error_message(); + return load_graph_status; + } + + tensorflow::SessionOptions options; + // Disable optimizations on this graph so that constant folding doesn't + // increase the memory footprint by creating new constant copies of the weight + // parameters. + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(::tensorflow::OptimizerOptions::L0); + options.env = memmapped_env->get(); + + tensorflow::Session* session_pointer = nullptr; + tensorflow::Status session_status = + tensorflow::NewSession(options, &session_pointer); + if (!session_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; + return session_status; + } + + tensorflow::Status create_status = session_pointer->Create(tensorflow_graph); + if (!create_status.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; + return create_status; + } + + session->reset(session_pointer); + + return tensorflow::Status::OK(); +} + +tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, + std::vector* label_strings) { + // Read the label list + NSString* labels_path = FilePathForResourceName(file_name, file_type); + if (!labels_path) { + LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] + << [file_type UTF8String]; + return tensorflow::errors::NotFound([file_name UTF8String], + [file_type UTF8String]); + } + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings->push_back(line); + } + t.close(); + return tensorflow::Status::OK(); +} diff --git a/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000..ee9fe57c79 --- /dev/null +++ b/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj @@ -0,0 +1,412 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */; }; + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; }; + 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; + 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; }; + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; }; + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; }; + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; }; + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; }; + 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */; }; + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; + 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; }; + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; }; + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = ""; }; + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = ""; }; + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = ""; }; + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = ""; }; + 1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = tensorflow_utils.h; sourceTree = ""; }; + 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = tensorflow_utils.mm; sourceTree = ""; }; + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.debug.xcconfig"; sourceTree = ""; }; + 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.release.xcconfig"; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 1C564C0A1ED3A92E00087306 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, + 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 3E9FC355632FB928EA23BEED /* Pods */ = { + isa = PBXGroup; + children = ( + 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */, + 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */, + 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */, + 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */, + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */, + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */, + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */, + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */, + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */, + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */, + 1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */, + 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 5911579C1CF4011C00C31E3A /* Products */, + 3E9FC355632FB928EA23BEED /* Pods */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 1C564C0C1ED3A92E00087306 /* tf_camera_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */; + buildPhases = ( + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */, + 1C564C091ED3A92E00087306 /* Sources */, + 1C564C0A1ED3A92E00087306 /* Frameworks */, + 1C564C0B1ED3A92E00087306 /* Resources */, + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */, + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_camera_example; + productName = tf_camera_example; + productReference = 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 0830; + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 1C564C0C1ED3A92E00087306 = { + CreatedOnToolsVersion = 8.3.2; + DevelopmentTeam = 5DRPWFQSHP; + ProvisioningStyle = Automatic; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 1C564C0C1ED3A92E00087306 /* tf_camera_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 1C564C0B1ED3A92E00087306 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */, + 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */, + 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */, + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */, + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 1C564C091ED3A92E00087306 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */, + 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */, + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */, + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */, + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 1C564C361ED3A92E00087306 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = 5DRPWFQSHP; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 3.0; + }; + name = Debug; + }; + 1C564C371ED3A92E00087306 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = 5DRPWFQSHP; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule"; + SWIFT_VERSION = 3.0; + }; + name = Release; + }; + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 1C564C361ED3A92E00087306 /* Debug */, + 1C564C371ED3A92E00087306 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/examples/ios/simple/AppDelegate.h b/tensorflow/examples/ios/simple/AppDelegate.h new file mode 100644 index 0000000000..75b1f1da38 --- /dev/null +++ b/tensorflow/examples/ios/simple/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property (strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/examples/ios/simple/AppDelegate.mm b/tensorflow/examples/ios/simple/AppDelegate.mm new file mode 100644 index 0000000000..1e808eb976 --- /dev/null +++ b/tensorflow/examples/ios/simple/AppDelegate.mm @@ -0,0 +1,44 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +#import "RunModelViewController.h" + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + + UITabBarController *bar = [[UITabBarController alloc] init]; + [bar setViewControllers: + @[[[RunModelViewController alloc] init]]]; + bar.selectedIndex = 0; + self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; + self.window.rootViewController = bar; + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application {} + +- (void)applicationDidEnterBackground:(UIApplication *)application {} + +- (void)applicationWillEnterForeground:(UIApplication *)application {} + +- (void)applicationDidBecomeActive:(UIApplication *)application {} + +- (void)applicationWillTerminate:(UIApplication *)application {} + +@end diff --git a/tensorflow/examples/ios/simple/Podfile b/tensorflow/examples/ios/simple/Podfile new file mode 100644 index 0000000000..1740ad6457 --- /dev/null +++ b/tensorflow/examples/ios/simple/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tf_simple_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/simple/RunModel-Info.plist b/tensorflow/examples/ios/simple/RunModel-Info.plist new file mode 100644 index 0000000000..d0a8742456 --- /dev/null +++ b/tensorflow/examples/ios/simple/RunModel-Info.plist @@ -0,0 +1,47 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tf_simple_example + CFBundleExecutable + tf_simple_example + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ios-app + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + UILaunchStoryboardName + RunModelViewController + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/examples/ios/simple/RunModelViewController.h b/tensorflow/examples/ios/simple/RunModelViewController.h new file mode 100644 index 0000000000..4e1a83ccf5 --- /dev/null +++ b/tensorflow/examples/ios/simple/RunModelViewController.h @@ -0,0 +1,24 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface RunModelViewController : UIViewController + +- (IBAction)getUrl:(id)sender; + +@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property (weak, nonatomic) IBOutlet UITextField *urlTextField; + +@end diff --git a/tensorflow/examples/ios/simple/RunModelViewController.mm b/tensorflow/examples/ios/simple/RunModelViewController.mm new file mode 100644 index 0000000000..c8ccb5c77b --- /dev/null +++ b/tensorflow/examples/ios/simple/RunModelViewController.mm @@ -0,0 +1,253 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "RunModelViewController.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/session.h" + +#include "ios_image_load.h" + +NSString* RunInferenceOnImage(); + +namespace { +class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { + public: + explicit IfstreamInputStream(const std::string& file_name) + : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} + ~IfstreamInputStream() { ifs_.close(); } + + int Read(void* buffer, int size) { + if (!ifs_) { + return -1; + } + ifs_.read(static_cast(buffer), size); + return (int)ifs_.gcount(); + } + + private: + std::ifstream ifs_; +}; +} // namespace + +@interface RunModelViewController () +@end + +@implementation RunModelViewController { +} + +- (IBAction)getUrl:(id)sender { + NSString* inference_result = RunInferenceOnImage(); + self.urlContentTextView.text = inference_result; +} + +@end + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN( + const Eigen::TensorMap, + Eigen::Aligned>& prediction, + const int num_results, const float threshold, + std::vector >* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, + std::vector >, + std::greater > > top_result_pq; + + const long count = prediction.size(); + for (int i = 0; i < count; ++i) { + const float value = prediction(i); + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + + +bool PortableReadFileToProto(const std::string& file_name, + ::google::protobuf::MessageLite* proto) { + ::google::protobuf::io::CopyingInputStreamAdaptor stream( + new IfstreamInputStream(file_name)); + stream.SetOwnsCopyingStream(true); + // TODO(jiayq): the following coded stream is for debugging purposes to allow + // one to parse arbitrarily large messages for MessageLite. One most likely + // doesn't want to put protobufs larger than 64MB on Android, so we should + // eventually remove this and quit loud when a large protobuf is passed in. + ::google::protobuf::io::CodedInputStream coded_stream(&stream); + // Total bytes hard limit / warning limit are set to 1GB and 512MB + // respectively. + coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); + return proto->ParseFromCodedStream(&coded_stream); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." + << [extension UTF8String] << "' in bundle."; + } + return file_path; +} + +NSString* RunInferenceOnImage() { + tensorflow::SessionOptions options; + + tensorflow::Session* session_pointer = nullptr; + tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer); + if (!session_status.ok()) { + std::string status_string = session_status.ToString(); + return [NSString stringWithFormat: @"Session create failed - %s", + status_string.c_str()]; + } + std::unique_ptr session(session_pointer); + LOG(INFO) << "Session created."; + + tensorflow::GraphDef tensorflow_graph; + LOG(INFO) << "Graph created."; + + NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); + PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); + + LOG(INFO) << "Creating session."; + tensorflow::Status s = session->Create(tensorflow_graph); + if (!s.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << s; + return @""; + } + + // Read the label list + NSString* labels_path = FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); + std::vector label_strings; + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while(t){ + std::getline(t, line); + label_strings.push_back(line); + } + t.close(); + + // Read the Grace Hopper image. + NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); + int image_width; + int image_height; + int image_channels; + std::vector image_data = LoadImageFromFile( + [image_path UTF8String], &image_width, &image_height, &image_channels); + const int wanted_width = 224; + const int wanted_height = 224; + const int wanted_channels = 3; + const float input_mean = 117.0f; + const float input_std = 1.0f; + assert(image_channels >= wanted_channels); + tensorflow::Tensor image_tensor( + tensorflow::DT_FLOAT, + tensorflow::TensorShape({ + 1, wanted_height, wanted_width, wanted_channels})); + auto image_tensor_mapped = image_tensor.tensor(); + tensorflow::uint8* in = image_data.data(); + // tensorflow::uint8* in_end = (in + (image_height * image_width * image_channels)); + float* out = image_tensor_mapped.data(); + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); + float* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); + float* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + + NSString* result = [network_path stringByAppendingString: @" - loaded!"]; + result = [NSString stringWithFormat: @"%@ - %lu, %s - %dx%d", result, + label_strings.size(), label_strings[0].c_str(), image_width, image_height]; + + std::string input_layer = "input"; + std::string output_layer = "output"; + std::vector outputs; + tensorflow::Status run_status = session->Run({{input_layer, image_tensor}}, + {output_layer}, {}, &outputs); + if (!run_status.ok()) { + LOG(ERROR) << "Running model failed: " << run_status; + tensorflow::LogAllRegisteredKernels(); + result = @"Error running model"; + return result; + } + tensorflow::string status_string = run_status.ToString(); + result = [NSString stringWithFormat: @"%@ - %s", result, + status_string.c_str()]; + + tensorflow::Tensor* output = &outputs[0]; + const int kNumResults = 5; + const float kThreshold = 0.1f; + std::vector > top_results; + GetTopN(output->flat(), kNumResults, kThreshold, &top_results); + + std::stringstream ss; + ss.precision(3); + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + + ss << index << " " << confidence << " "; + + // Write out the result as a string + if (index < label_strings.size()) { + // just for safety: theoretically, the output is under 1000 unless there + // is some numerical issues leading to a wrong prediction. + ss << label_strings[index]; + } else { + ss << "Prediction: " << index; + } + + ss << "\n"; + } + + LOG(INFO) << "Predictions: " << ss.str(); + + tensorflow::string predictions = ss.str(); + result = [NSString stringWithFormat: @"%@ - %s", result, + predictions.c_str()]; + + return result; +} diff --git a/tensorflow/examples/ios/simple/RunModelViewController.xib b/tensorflow/examples/ios/simple/RunModelViewController.xib new file mode 100644 index 0000000000..93f334b985 --- /dev/null +++ b/tensorflow/examples/ios/simple/RunModelViewController.xib @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/examples/ios/simple/data/grace_hopper.jpg new file mode 100644 index 0000000000..d2a427810f Binary files /dev/null and b/tensorflow/examples/ios/simple/data/grace_hopper.jpg differ diff --git a/tensorflow/examples/ios/simple/ios_image_load.h b/tensorflow/examples/ios/simple/ios_image_load.h new file mode 100644 index 0000000000..0e0b771118 --- /dev/null +++ b/tensorflow/examples/ios/simple/ios_image_load.h @@ -0,0 +1,27 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ + +#include + +#include "tensorflow/core/framework/types.h" + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, + int* out_height, + int* out_channels); + +#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/simple/ios_image_load.mm b/tensorflow/examples/ios/simple/ios_image_load.mm new file mode 100644 index 0000000000..64d1ea21cf --- /dev/null +++ b/tensorflow/examples/ios/simple/ios_image_load.mm @@ -0,0 +1,87 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +using tensorflow::uint8; + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), + bytes_in_file, + kCFAllocatorNull); + CGDataProviderRef image_provider = + CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || + (strcasecmp(suffix, ".jpeg") == 0)) { + image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + CGContextRef context = CGBitmapContextCreate(result.data(), width, height, + bits_per_component, bytes_per_row, color_space, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/examples/ios/simple/main.mm b/tensorflow/examples/ios/simple/main.mm new file mode 100644 index 0000000000..d70550a730 --- /dev/null +++ b/tensorflow/examples/ios/simple/main.mm @@ -0,0 +1,22 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +int main(int argc, char * argv[]) { + @autoreleasepool { + NSString *delegateClassName = @"AppDelegate"; + return UIApplicationMain(argc, argv, nil, delegateClassName); + } +} diff --git a/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000..55c06e28fb --- /dev/null +++ b/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj @@ -0,0 +1,404 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; + 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */; }; + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; + 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; + 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.release.xcconfig"; sourceTree = ""; }; + 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.debug.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 591157981CF4011C00C31E3A /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 3E9FC355632FB928EA23BEED /* Pods */ = { + isa = PBXGroup; + children = ( + 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */, + 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, + 59A3CFFC1CF4E68100C4259F /* main.mm */, + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, + 5911579C1CF4011C00C31E3A /* Products */, + 3E9FC355632FB928EA23BEED /* Pods */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, + 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildPhases = ( + 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */, + 591157971CF4011C00C31E3A /* Sources */, + 591157981CF4011C00C31E3A /* Frameworks */, + 591157991CF4011C00C31E3A /* Resources */, + 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */, + 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_simple_example; + productName = tf_ios_makefile_example; + productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 5911579A1CF4011C00C31E3A = { + CreatedOnToolsVersion = 7.2; + DevelopmentTeam = 85Z3VXS37U; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 5911579A1CF4011C00C31E3A /* tf_simple_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 591157991CF4011C00C31E3A /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, + 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */, + 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */, + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; + showEnvVarsInLog = 0; + }; + 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 591157971CF4011C00C31E3A /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D0091CF4E68100C4259F /* main.mm in Sources */, + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 591157B31CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = "$(inherited)"; + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SEPARATE_STRIP = NO; + }; + name = Debug; + }; + 591157B41CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = "$(inherited)"; + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ""; + ONLY_ACTIVE_ARCH = YES; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SEPARATE_STRIP = NO; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B31CF4011D00C31E3A /* Debug */, + 591157B41CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} -- cgit v1.2.3 From 7b4c01794fbc2e6dc46e93a42416dd80929ce1e5 Mon Sep 17 00:00:00 2001 From: RJ Ryan Date: Wed, 7 Jun 2017 11:04:18 -0700 Subject: Support numpy-style padding and slicing of tf.spectral.rfft/irfft to match the desired FFT length. Fixes incorrect RFFT/IRFFT results when fft_length does not match the input dimension. PiperOrigin-RevId: 158289991 --- tensorflow/core/kernels/fft_ops.cc | 50 +++++++++++++++---- tensorflow/core/ops/spectral_ops.cc | 26 ++++++++++ tensorflow/python/kernel_tests/fft_ops_test.py | 69 +++++++++++++++++++++++--- tensorflow/python/ops/spectral_ops.py | 52 +++++++++++++++++++ 4 files changed, 179 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 639f6a76de..b479956632 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -39,15 +39,15 @@ class FFTBase : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& in = ctx->input(0); - const TensorShape& shape = in.shape(); + const TensorShape& input_shape = in.shape(); const int fft_rank = Rank(); OP_REQUIRES( - ctx, shape.dims() >= fft_rank, + ctx, input_shape.dims() >= fft_rank, errors::InvalidArgument("Input must have rank of at least ", fft_rank, - " but got: ", shape.DebugString())); + " but got: ", input_shape.DebugString())); Tensor* out; - TensorShape output_shape = shape; + TensorShape output_shape = input_shape; uint64 fft_shape[3] = {0, 0, 0}; // In R2C or C2R mode, we use a second input to specify the FFT length @@ -57,13 +57,29 @@ class FFTBase : public OpKernel { OP_REQUIRES(ctx, fft_length.shape().dims() == 1 && fft_length.shape().dim_size(0) == fft_rank, - errors::InvalidArgument("fft_length must have shape [", + errors::InvalidArgument("fft_length must have shape [", fft_rank, "]")); auto fft_length_as_vec = fft_length.vec(); for (int i = 0; i < fft_rank; ++i) { fft_shape[i] = fft_length_as_vec(i); - uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0 + // Each input dimension must have length of at least fft_shape[i]. For + // IRFFTs, the inner-most input dimension must have length of at least + // fft_shape[i] / 2 + 1. + bool inner_most = (i == fft_rank - 1); + uint64 min_input_dim_length = + !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i]; + auto input_index = input_shape.dims() - fft_rank + i; + OP_REQUIRES( + ctx, + // We pass through empty tensors, so special case them here. + input_shape.dim_size(input_index) == 0 || + input_shape.dim_size(input_index) >= min_input_dim_length, + errors::InvalidArgument( + "Input dimension ", input_index, + " must have length of at least ", min_input_dim_length, + " but got: ", input_shape.dim_size(input_index))); + uint64 dim = IsForward() && inner_most && fft_shape[i] != 0 ? fft_shape[i] / 2 + 1 : fft_shape[i]; output_shape.set_dim(output_shape.dims() - fft_rank + i, dim); @@ -76,7 +92,7 @@ class FFTBase : public OpKernel { } OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out)); - if (shape.num_elements() == 0) { + if (input_shape.num_elements() == 0) { return; } @@ -120,20 +136,32 @@ class FFTCPU : public FFTBase { } else { if (IsForward()) { auto input = (Tensor(in)).flat_inner_dims(); + auto input_dims = input.dimensions(); + + // Slice input to fft_shape on its inner-most dimensions. + Eigen::DSizes input_slice_sizes; + input_slice_sizes[0] = input_dims[0]; + TensorShape temp_shape{input_dims[0]}; + for (int i = 1; i <= FFTRank; ++i) { + input_slice_sizes[i] = fft_shape[i - 1]; + temp_shape.AddDim(fft_shape[i - 1]); + } + auto output = out->flat_inner_dims(); - Eigen::DSizes startIndices; + const Eigen::DSizes zero_start_indices; // Compute the full FFT using a temporary tensor. Tensor temp; OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::v(), - in.shape(), &temp)); + temp_shape, &temp)); auto full_fft = temp.flat_inner_dims(); full_fft.device(device) = - input.template fft(axes); + input.slice(zero_start_indices, input_slice_sizes) + .template fft(axes); // Slice away the negative frequency components. output.device(device) = - full_fft.slice(startIndices, output.dimensions()); + full_fft.slice(zero_start_indices, output.dimensions()); } else { // TODO: reconstruct the full fft and take the inverse. ctx->CtxFailureWithWarning( diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc index 09b460fd14..592aaa25c3 100644 --- a/tensorflow/core/ops/spectral_ops.cc +++ b/tensorflow/core/ops/spectral_ops.cc @@ -201,6 +201,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, followed by the `fft_length / 2` positive-frequency terms. +Along the axis `RFFT` is computed on, if `fft_length` is smaller than the +corresponding dimension of `input`, the dimension is cropped. If it is larger, +the dimension is padded with zeros. + input: A float32 tensor. fft_length: An int32 tensor of shape [1]. The FFT length. output: A complex64 tensor of the same rank as `input`. The inner-most @@ -230,6 +234,10 @@ dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to compute `input` is odd, it should be provided since it cannot be inferred properly. +Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller +than the corresponding dimension of `input`, the dimension is cropped. If it is +larger, the dimension is padded with zeros. + input: A complex64 tensor. fft_length: An int32 tensor of shape [1]. The FFT length. output: A float32 tensor of the same rank as `input`. The inner-most @@ -257,6 +265,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the of `output`: the zero-frequency term, followed by the `fft_length / 2` positive-frequency terms. +Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +corresponding dimension of `input`, the dimension is cropped. If it is larger, +the dimension is padded with zeros. + input: A float32 tensor. fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. output: A complex64 tensor of the same rank as `input`. The inner-most 2 @@ -287,6 +299,11 @@ from the size of the inner-most 2 dimensions of `input`. If the FFT length used to compute `input` is odd, it should be provided since it cannot be inferred properly. +Along each axis `IRFFT2D` is computed on, if `fft_length` (or +`fft_length / 2 + 1` for the inner-most dimension) is smaller than the +corresponding dimension of `input`, the dimension is cropped. If it is larger, +the dimension is padded with zeros. + input: A complex64 tensor. fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. output: A float32 tensor of the same rank as `input`. The inner-most 2 @@ -314,6 +331,10 @@ Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the of `output`: the zero-frequency term, followed by the `fft_length / 2` positive-frequency terms. +Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +corresponding dimension of `input`, the dimension is cropped. If it is larger, +the dimension is padded with zeros. + input: A float32 tensor. fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. output: A complex64 tensor of the same rank as `input`. The inner-most 3 @@ -344,6 +365,11 @@ from the size of the inner-most 3 dimensions of `input`. If the FFT length used to compute `input` is odd, it should be provided since it cannot be inferred properly. +Along each axis `IRFFT3D` is computed on, if `fft_length` (or +`fft_length / 2 + 1` for the inner-most dimension) is smaller than the +corresponding dimension of `input`, the dimension is cropped. If it is larger, +the dimension is padded with zeros. + input: A complex64 tensor. fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. output: A float32 tensor of the same rank as `input`. The inner-most 3 diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py index 84928bd2e1..2f3c5a6c33 100644 --- a/tensorflow/python/kernel_tests/fft_ops_test.py +++ b/tensorflow/python/kernel_tests/fft_ops_test.py @@ -22,8 +22,10 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops @@ -297,6 +299,38 @@ class RFFTOpsTest(BaseFFTOpsTest): self._CompareBackward(c2r.astype(np.complex64), rank, (size,) * rank, use_placeholder=True) + def testFftLength(self): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + for size in (5, 6): + inner_dim = size // 2 + 1 + r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( + (size,) * dims) + c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), + 10).reshape((size,) * (dims - 1) + (inner_dim,)) + + # Test truncation (FFT size < dimensions). + fft_length = (size - 2,) * rank + self._CompareForward(r2c.astype(np.float32), rank, fft_length) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) + + # Confirm it works with unknown shapes as well. + self._CompareForward(r2c.astype(np.float32), rank, fft_length, + use_placeholder=True) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, + use_placeholder=True) + + # Test padding (FFT size > dimensions). + fft_length = (size + 2,) * rank + self._CompareForward(r2c.astype(np.float32), rank, fft_length) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) + + # Confirm it works with unknown shapes as well. + self._CompareForward(r2c.astype(np.float32), rank, fft_length, + use_placeholder=True) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, + use_placeholder=True) + def testRandom(self): np.random.seed(12345) @@ -326,10 +360,10 @@ class RFFTOpsTest(BaseFFTOpsTest): for dims in xrange(0, rank): x = np.zeros((1,) * dims).astype(np.complex64) with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): + ValueError, "Shape .* must have rank at least {}".format(rank)): self._tfFFT(x, rank) with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): + ValueError, "Shape .* must have rank at least {}".format(rank)): self._tfIFFT(x, rank) for dims in xrange(rank, rank + 2): x = np.zeros((1,) * rank) @@ -337,10 +371,10 @@ class RFFTOpsTest(BaseFFTOpsTest): # Test non-rank-1 fft_length produces an error. fft_length = np.zeros((1, 1)).astype(np.int32) with self.assertRaisesWithPredicateMatch(ValueError, - "Shape must be .*rank 1"): + "Shape .* must have rank 1"): self._tfFFT(x, rank, fft_length) with self.assertRaisesWithPredicateMatch(ValueError, - "Shape must be .*rank 1"): + "Shape .* must have rank 1"): self._tfIFFT(x, rank, fft_length) # Test wrong fft_length length. @@ -352,6 +386,29 @@ class RFFTOpsTest(BaseFFTOpsTest): ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): self._tfIFFT(x, rank, fft_length) + # Test that calling the kernel directly without padding to fft_length + # produces an error. + rffts_for_rank = {1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft], + 2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d], + 3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]} + rfft_fn, irfft_fn = rffts_for_rank[rank] + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "Input dimension .* must have length of at least 6 but got: 5"): + x = np.zeros((5,) * rank).astype(np.float32) + fft_length = [6] * rank + with self.test_session(): + rfft_fn(x, fft_length).eval() + # TODO(rjryan): Remove when CPU-based IRFFT is supported. + if test.is_gpu_available(cuda_only=True): + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "Input dimension .* must have length of at least .* but got: 3"): + x = np.zeros((3,) * rank).astype(np.complex64) + fft_length = [6] * rank + with self.test_session(): + irfft_fn(x, fft_length).eval() + def testGrad_Simple(self): if test.is_gpu_available(cuda_only=True): for rank in VALID_FFT_RANKS: @@ -359,9 +416,7 @@ class RFFTOpsTest(BaseFFTOpsTest): if rank == 3: continue for dims in xrange(rank, rank + 2): - for size in ( - 5, - 6,): + for size in (5, 6): re = np.ones(shape=(size,) * dims, dtype=np.float32) im = -np.ones(shape=(size,) * dims, dtype=np.float32) self._checkGradReal(self._tfFFTForRank(rank), re, use_gpu=True) diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 95a2806330..47ff7018f2 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -33,6 +33,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework import tensor_util as _tensor_util from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import math_ops as _math_ops @@ -70,6 +71,52 @@ def _infer_fft_length_for_irfft(input_tensor, fft_rank): return _ops.convert_to_tensor(fft_length, _dtypes.int32) +def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False): + """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims.""" + fft_shape = _tensor_util.constant_value_as_shape(fft_length) + + # Edge case: skip padding empty tensors. + if (input_tensor.shape.ndims is not None and + any(dim.value == 0 for dim in input_tensor.shape)): + return input_tensor + + # If we know the shapes ahead of time, we can either skip or pre-compute the + # appropriate paddings. Otherwise, fall back to computing paddings in + # TensorFlow. + if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None: + # Slice the last FFT-rank dimensions from input_tensor's shape. + input_fft_shape = input_tensor.shape[-fft_shape.ndims:] + + if input_fft_shape.is_fully_defined(): + # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1. + if is_reverse: + fft_shape = fft_shape[:-1].concatenate(fft_shape[-1].value // 2 + 1) + + paddings = [[0, max(fft_dim.value - input_dim.value, 0)] + for fft_dim, input_dim in zip(fft_shape, input_fft_shape)] + if any(pad > 0 for _, pad in paddings): + outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims - + fft_shape.ndims), 0) + return _array_ops.pad(input_tensor, outer_paddings + paddings) + return input_tensor + + # If we can't determine the paddings ahead of time, then we have to pad. If + # the paddings end up as zero, tf.pad has a special-case that does no work. + input_rank = _array_ops.rank(input_tensor) + input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:] + outer_dims = _math_ops.maximum(0, input_rank - fft_rank) + outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype) + # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1. + if is_reverse: + fft_length = _array_ops.concat([fft_length[:-1], + fft_length[-1:] // 2 + 1], 0) + fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape) + paddings = _array_ops.concat([outer_paddings, fft_paddings], 0) + paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings], + axis=1) + return _array_ops.pad(input_tensor, paddings) + + def _rfft_wrapper(fft_fn, fft_rank, default_name): """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument.""" @@ -77,10 +124,12 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name): with _ops.name_scope(name, default_name, [input_tensor, fft_length]) as name: input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32) + input_tensor.shape.with_rank_at_least(fft_rank) if fft_length is None: fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank) else: fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32) + input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length) return fft_fn(input_tensor, fft_length, name) _rfft.__doc__ = fft_fn.__doc__ return _rfft @@ -93,10 +142,13 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name): with _ops.name_scope(name, default_name, [input_tensor, fft_length]) as name: input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64) + input_tensor.shape.with_rank_at_least(fft_rank) if fft_length is None: fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank) else: fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32) + input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, + is_reverse=True) return ifft_fn(input_tensor, fft_length, name) _irfft.__doc__ = ifft_fn.__doc__ return _irfft -- cgit v1.2.3 From ebae3deba801b55debbf67205c43ee54f80a7494 Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 7 Jun 2017 11:13:12 -0700 Subject: Switch back to max_num_rows_to_load instead of reading slice by slice due to performance regression from network overhead. Add check when using initializing values to avoid seg fault PiperOrigin-RevId: 158291218 --- .../framework/kernels/load_and_remap_matrix_op.cc | 218 +++++++++++-------- tensorflow/contrib/framework/ops/checkpoint_ops.cc | 4 + .../contrib/framework/python/ops/checkpoint_ops.py | 50 ++++- .../framework/python/ops/checkpoint_ops_test.py | 239 ++++++++++++++++++--- 4 files changed, 383 insertions(+), 128 deletions(-) diff --git a/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc b/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc index 12a4d36bf1..a74ad98663 100644 --- a/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc +++ b/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc @@ -31,6 +31,29 @@ limitations under the License. namespace tensorflow { +namespace { +// Returning a Status instead of using OP_REQUIRES directly since that doesn't +// seem to work outside the main OpKernel functions. +Status RemapVectorToMap(const TTypes::Vec& remapping, + std::vector* id_present, + std::unordered_map* old_id_to_new_id) { + id_present->clear(); + id_present->resize(remapping.size(), false); + for (int i = 0; i < remapping.size(); ++i) { + const int64 old_id = remapping(i); + if (old_id < 0) continue; + (*id_present)[i] = true; + if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) { + return errors::Unimplemented( + strings::StrCat("Old ID ", old_id, " is mapped to both new ID ", + old_id_to_new_id->at(old_id), " and ", i, + ", which is not supported.")); + } + } + return Status::OK(); +} +} // anonymous namespace + // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and // swaps around the rows/columns according to row_remapping/col_remapping. // "Missing" cells are initialized with values from initializing_values. @@ -40,13 +63,15 @@ class LoadAndRemapMatrixOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_)); OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_)); + OP_REQUIRES_OK( + context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_)); } void Compute(OpKernelContext* context) override { // Checks what we're remapping and inverts the relevant remapping Tensors to // be maps with key = old ID, value = new ID. - std::vector> old_row_to_new_row_pairs; - std::vector row_id_present(num_rows_); + std::unordered_map old_row_to_new_row_map; + std::vector row_id_present; const Tensor* row_remapping_t; OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t)); const auto row_remapping = row_remapping_t->vec(); @@ -54,16 +79,27 @@ class LoadAndRemapMatrixOp : public OpKernel { errors::InvalidArgument(strings::StrCat( "Size of row_remapping is ", row_remapping.size(), " intead of being equal to num_rows=", num_rows_))); - old_row_to_new_row_pairs.reserve(num_rows_); + OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present, + &old_row_to_new_row_map)); + + // Calculates the min/max old row ID that we need to read, to save us from + // reading some unnecessary slices of the old tensor. + int64 min_old_row = -1; + int64 max_old_row = -1; for (int i = 0; i < row_remapping.size(); ++i) { - if (row_remapping(i) < 0) continue; - row_id_present[i] = true; - old_row_to_new_row_pairs.push_back(std::make_pair(row_remapping(i), i)); + if (min_old_row < 0 || + (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) { + min_old_row = row_remapping(i); + } + if (max_old_row < 0 || + (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) { + max_old_row = row_remapping(i); + } } // Processes the remapping for columns. std::unordered_map old_col_to_new_col_map; - std::vector col_id_present(num_cols_); + std::vector col_id_present; const Tensor* col_remapping_t; OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t)); const auto col_remapping = col_remapping_t->vec(); @@ -77,19 +113,8 @@ class LoadAndRemapMatrixOp : public OpKernel { errors::InvalidArgument(strings::StrCat( "Provided col_remapping, but its size is ", col_remapping.size(), " instead of being equal to num_cols=", num_cols_))); - for (int i = 0; i < col_remapping.size(); ++i) { - const int64 old_col = col_remapping(i); - if (old_col < 0) continue; - col_id_present[i] = true; - OP_REQUIRES( - context, - gtl::InsertIfNotPresent(&old_col_to_new_col_map, old_col, i), - errors::Unimplemented(strings::StrCat( - "Old column ID ", old_col, " is mapped to both new column ID ", - old_col_to_new_col_map[old_col], " and ", i, - ", which is not currently supported - but could be " - "implemented."))); - } + OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present, + &old_col_to_new_col_map)); } else { col_id_present.clear(); col_id_present.resize(num_cols_, true); @@ -139,29 +164,27 @@ class LoadAndRemapMatrixOp : public OpKernel { " instead of being equal to num_cols=", num_cols_))); } - // Uses TensorSlice to selectively read rows of interest from the old - // tensor. Given BundleReader's use of RandomAccessFile and InputBuffer, - // there shouldn't too many more additional disk seeks when compared to - // loading the old tensor in chunks, once we sort the row IDs. Even if there - // are locality concerns with some reading patterns, that just means if we - // had read it in chunks, then we would have had to read, copy, and process - // then discard many redundant rows - so we should come out ahead this way. - // In addition, this frees us from having to hold the entire old tensor in - // memory. - std::sort(old_row_to_new_row_pairs.begin(), old_row_to_new_row_pairs.end()); + // Uses TensorSlice to potentially load the old tensor in chunks in case + // memory usage is a concern. std::vector tensor_slices; - tensor_slices.reserve(old_row_to_new_row_pairs.size()); TensorSlice slice(tensor_shape.dims()); - for (const auto& pair : old_row_to_new_row_pairs) { - OP_REQUIRES( - context, pair.first < tensor_shape.dim_size(0), - errors::InvalidArgument(strings::StrCat( - "Trying to read row ", pair.first, " from tensor ", - old_tensor_name, ", which only has ", tensor_shape.dim_size(0), - " rows (with shape ", tensor_shape.DebugString(), ")."))); - slice.set_start(0, pair.first); - slice.set_length(0, 1); - tensor_slices.push_back(slice); + if (min_old_row >= 0 && max_old_row >= 0) { + int64 row_start = min_old_row; + // TODO(weiho): Given the list of old row IDs of interest (the keys of + // old_row_to_new_row_map), we could also try something smarter to + // find some minimal set of covering ranges for the list of old row IDs + // such that the size of each range is less than max_rows_in_memory_. + while (row_start <= max_old_row) { + const int64 slice_length = + max_rows_in_memory_ <= 0 + // If max_rows_in_memory_ <= 0, we just load the entire chunk. + ? max_old_row - row_start + 1 + : std::min(max_rows_in_memory_, max_old_row - row_start + 1); + slice.set_start(0, row_start); + slice.set_length(0, slice_length); + tensor_slices.push_back(slice); + row_start += slice_length; + } } // Allocates the output matrix. @@ -174,52 +197,72 @@ class LoadAndRemapMatrixOp : public OpKernel { // Iterates through tensor slices and copies over values from the old tensor // to the output matrix. - Tensor loaded_tensor_t(DT_FLOAT, - TensorShape({1, tensor_shape.dim_size(1)})); - for (int i = 0; i < tensor_slices.size(); ++i) { - const int64 new_row = old_row_to_new_row_pairs[i].second; - if (i % 500000 == 0) { - LOG(INFO) << "Processing slice " << i << " of " << tensor_slices.size() - << " - corresponding to old row " - << old_row_to_new_row_pairs[i].first << " of " - << tensor_shape.dim_size(0); - } + int64 row_index = min_old_row; + int64 rows_copied = 0; + Tensor loaded_tensor_t; + for (const TensorSlice& tensor_slice : tensor_slices) { + LOG(INFO) << "Loading slice " << tensor_slice.DebugString(); + TensorShape slice_shape; OP_REQUIRES_OK(context, - reader.LookupSlice(old_tensor_name, tensor_slices[i], - &loaded_tensor_t)); + tensor_slice.SliceTensorShape(tensor_shape, &slice_shape)); + // Potentially re-allocates the tensor buffer since the last slice may + // have fewer rows than the other slices. + if (loaded_tensor_t.shape() != slice_shape) { + loaded_tensor_t = Tensor(DT_FLOAT, slice_shape); + } + OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice, + &loaded_tensor_t)); - // Copies over the row element-by-element, in case remapping is needed - // along the column axis. - const auto& loaded_tensor = loaded_tensor_t.flat(); - for (int old_col = 0; old_col < loaded_tensor.size(); ++old_col) { - int64 new_col = old_col; - if (remap_cols) { - const int64* new_col_ptr = - gtl::FindOrNull(old_col_to_new_col_map, old_col); - if (new_col_ptr == nullptr) { - // Column remapping is specified, but this column is not found in - // old_col_to_new_col_map, so we leave it uninitialized, to be - // filled in with initializing_values later. - continue; - } - new_col = *new_col_ptr; + // Iterates through the old loaded tensor slice row-by-row. + for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) { + if (row_index % 500000 == min_old_row) { + LOG(INFO) << "Processing old row " << row_index; + } + + // If the old row ID is not found in old_row_to_new_row_map, continue + // to the next row; otherwise, copy it to the output matrix. + const int64* new_row_ptr = + gtl::FindOrNull(old_row_to_new_row_map, row_index); + if (new_row_ptr == nullptr) { + continue; } + ++rows_copied; + const int64 new_row = *new_row_ptr; - OP_REQUIRES(context, - new_row < num_rows_ && new_col < num_cols_ && - new_row >= 0 && new_col >= 0, - errors::Internal(strings::StrCat( - "new_row=", new_row, " and new_col=", new_col, - " should have been less than num_rows_=", num_rows_, - " and num_cols_=", num_cols_, - " and non-negative. This should never have happened " - "if the code were correct. Please file a bug."))); - output_matrix(new_row, new_col) = loaded_tensor(old_col); + // Copies over the row element-by-element, in case remapping is needed + // along the column axis. + const auto& loaded_tensor = loaded_tensor_t.matrix(); + for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1); + ++old_col) { + int64 new_col = old_col; + if (remap_cols) { + const int64* new_col_ptr = + gtl::FindOrNull(old_col_to_new_col_map, old_col); + if (new_col_ptr == nullptr) { + // Column remapping is specified, but this column is not found in + // old_col_to_new_col_map, so we leave it uninitialized, to be + // filled in with initializing_values later. + continue; + } + new_col = *new_col_ptr; + } + + OP_REQUIRES(context, + new_row < num_rows_ && new_col < num_cols_ && + new_row >= 0 && new_col >= 0, + errors::Internal(strings::StrCat( + "new_row=", new_row, " and new_col=", new_col, + " should have been less than num_rows_=", num_rows_, + " and num_cols_=", num_cols_, + " and non-negative. This should never have happened " + "if the code were correct. Please file a bug."))); + output_matrix(new_row, new_col) = loaded_tensor(row, old_col); + } } } - LOG(INFO) << "Copied " << tensor_slices.size() - << " rows from old matrix (with " << tensor_shape.dim_size(0) - << " rows) to new matrix (with " << num_rows_ << " rows)."; + LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with " + << tensor_shape.dim_size(0) << " rows) to new matrix (with " + << num_rows_ << " rows)."; // At this point, there are potentially whole rows/columns uninitialized // (corresponding to the indices where row_id_present/col_id_present are @@ -232,10 +275,14 @@ class LoadAndRemapMatrixOp : public OpKernel { int64 initializing_values_index = 0; for (int i = 0; i < num_rows_; ++i) { for (int j = 0; j < num_cols_; ++j) { - if (!row_id_present[i] || !col_id_present[j]) { - output_matrix(i, j) = initializing_values(initializing_values_index); - ++initializing_values_index; - } + if (row_id_present[i] && col_id_present[j]) continue; + OP_REQUIRES( + context, initializing_values_index < initializing_values.size(), + errors::InvalidArgument( + "initializing_values contained ", initializing_values.size(), + " elements, but more missing values remain.")); + output_matrix(i, j) = initializing_values(initializing_values_index); + ++initializing_values_index; } } @@ -251,6 +298,7 @@ class LoadAndRemapMatrixOp : public OpKernel { private: int64 num_rows_; int64 num_cols_; + int64 max_rows_in_memory_; }; REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/framework/ops/checkpoint_ops.cc b/tensorflow/contrib/framework/ops/checkpoint_ops.cc index 09d487dd64..b49d7b4d40 100644 --- a/tensorflow/contrib/framework/ops/checkpoint_ops.cc +++ b/tensorflow/contrib/framework/ops/checkpoint_ops.cc @@ -83,6 +83,7 @@ REGISTER_OP("LoadAndRemapMatrix") .Input("initializing_values: float") .Attr("num_rows: int >= 0") .Attr("num_cols: int >= 1") + .Attr("max_rows_in_memory: int = -1") .Output("output_matrix: float") // TODO(b/30502450): Setting the op as being stateful prevents it from being // executed more often than expected (possibly due to stateful ops not being @@ -154,6 +155,9 @@ initializing_values: A float `Tensor` containing values to fill in for cells exactly the same as the number of missing / new cells. num_rows: Number of rows (length of the 1st dimension) in the output matrix. num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +max_rows_in_memory: The maximum number of rows to load from the checkpoint at + once. If less than or equal to 0, the entire matrix will be loaded into + memory. Setting this arg trades increased disk reads for lower memory usage. output_matrix: Output matrix containing existing values loaded from the checkpoint, and with any missing values filled in from initializing_values. )doc"); diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py index fdb834f46b..92228f8916 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py @@ -46,7 +46,8 @@ def _load_and_remap_matrix(ckpt_path, old_col_vocab_file=None, new_col_vocab_file=None, num_row_oov_buckets=0, - num_col_oov_buckets=0): + num_col_oov_buckets=0, + max_rows_in_memory=-1): """Loads a 2-D (matrix) `Tensor` from checkpoint. Generates 1D-remappings for rows and columns using the @@ -99,6 +100,10 @@ def _load_and_remap_matrix(ckpt_path, to append. Must be >= 0. num_col_oov_buckets: `int` specifying the number of out-of-vocabulary columns to append. Must be >= 0. + max_rows_in_memory: `int` specifying the maximum number of rows to load from + the checkpoint at once. If less than or equal to 0, the entire matrix will + be loaded into memory. Setting this arg trades increased disk reads for + lower memory usage. Returns: A Tensor of shape `[num_rows_to_load + num_row_oov_buckets, @@ -177,7 +182,8 @@ def _load_and_remap_matrix(ckpt_path, col_remapping=col_remapping, initializing_values=init_vals, num_rows=num_rows_to_load, - num_cols=new_col_vocab_size) + num_cols=new_col_vocab_size, + max_rows_in_memory=max_rows_in_memory) # Add OOV row(s) and column(s). if num_row_oov_buckets > 0: @@ -204,7 +210,8 @@ def load_and_remap_matrix_initializer(ckpt_path, new_col_vocab_file=None, num_row_oov_buckets=0, num_col_oov_buckets=0, - initializer=None): + initializer=None, + max_rows_in_memory=-1): r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor. The returned initializer loads a 2-D (matrix) `Tensor` with name @@ -297,6 +304,10 @@ def load_and_remap_matrix_initializer(ckpt_path, initializer: Initializer function to initialize missing values. Accepts a 1-D tensor as the arg to specify the shape of the returned tensor. If `None`, defaults to using `zeros_initializer()`. + max_rows_in_memory: `int` specifying the maximum number of rows to load from + the checkpoint at once. If less than or equal to 0, the entire matrix will + be loaded into memory. Setting this arg trades increased disk reads for + lower memory usage. Returns: A variable initializer function that should be used to initialize a @@ -378,7 +389,8 @@ def load_and_remap_matrix_initializer(ckpt_path, old_col_vocab_file=old_col_vocab_file, new_col_vocab_file=new_col_vocab_file, num_row_oov_buckets=row_oov_buckets_to_use, - num_col_oov_buckets=num_col_oov_buckets) + num_col_oov_buckets=num_col_oov_buckets, + max_rows_in_memory=max_rows_in_memory) return _initializer @@ -390,7 +402,8 @@ def load_embedding_initializer(ckpt_path, old_vocab_file, new_vocab_file, num_oov_buckets=0, - initializer=None): + initializer=None, + max_rows_in_memory=-1): """Returns a variable initializer for loading pre-trained embeddings. Wrapper around `load_and_remap_matrix_initializer()` specialized for loading @@ -416,6 +429,10 @@ def load_embedding_initializer(ckpt_path, initializer: Initializer function that accepts a 1-D tensor as the arg to specify the shape of the returned tensor. If `None`, defaults to using `truncated_normal_initializer()`. + max_rows_in_memory: `int` specifying the maximum number of rows to load from + the checkpoint at once. If less than or equal to 0, the entire matrix will + be loaded into memory. Setting this arg trades increased disk reads for + lower memory usage. Returns: A variable initializer function. @@ -437,7 +454,8 @@ def load_embedding_initializer(ckpt_path, new_col_vocab_file=None, num_row_oov_buckets=num_oov_buckets, num_col_oov_buckets=0, - initializer=initializer) + initializer=initializer, + max_rows_in_memory=max_rows_in_memory) def load_linear_multiclass_bias_initializer(ckpt_path, @@ -446,7 +464,8 @@ def load_linear_multiclass_bias_initializer(ckpt_path, old_class_vocab_file, new_class_vocab_file, num_class_oov_buckets=0, - initializer=None): + initializer=None, + max_rows_in_memory=-1): """Loads pre-trained multi-class biases for linear models from checkpoint. Wrapper around `load_and_remap_matrix_initializer()` specialized for loading @@ -469,6 +488,10 @@ def load_linear_multiclass_bias_initializer(ckpt_path, initializer: Initializer function that accepts a 1-D tensor as the arg to specify the shape of the returned tensor. If `None`, defaults to using `zeros_initializer()`. + max_rows_in_memory: `int` specifying the maximum number of rows to load from + the checkpoint at once. If less than or equal to 0, the entire matrix will + be loaded into memory. Setting this arg trades increased disk reads for + lower memory usage. Returns: A variable initializer function. @@ -488,7 +511,8 @@ def load_linear_multiclass_bias_initializer(ckpt_path, new_col_vocab_file=None, num_row_oov_buckets=num_class_oov_buckets, num_col_oov_buckets=0, - initializer=initializer) + initializer=initializer, + max_rows_in_memory=max_rows_in_memory) def load_variable_slot_initializer(ckpt_path, @@ -502,7 +526,8 @@ def load_variable_slot_initializer(ckpt_path, new_col_vocab_file=None, num_row_oov_buckets=0, num_col_oov_buckets=0, - initializer=None): + initializer=None, + max_rows_in_memory=-1): """Loads pre-trained multi-class slots for linear models from checkpoint. Wrapper around `load_and_remap_matrix_initializer()` specialized for loading @@ -549,6 +574,10 @@ def load_variable_slot_initializer(ckpt_path, initializer: Initializer function to initialize missing values. Accepts a 1-D tensor as the arg to specify the shape of the returned tensor. If `None`, defaults to using `zeros_initializer()`. + max_rows_in_memory: `int` specifying the maximum number of rows to load from + the checkpoint at once. If less than or equal to 0, the entire matrix will + be loaded into memory. Setting this arg trades increased disk reads for + lower memory usage. Returns: A variable initializer function that should be used to initialize a @@ -570,7 +599,8 @@ def load_variable_slot_initializer(ckpt_path, new_col_vocab_file=new_col_vocab_file, num_row_oov_buckets=num_row_oov_buckets, num_col_oov_buckets=num_col_oov_buckets, - initializer=initializer) + initializer=initializer, + max_rows_in_memory=max_rows_in_memory) def _initializer(shape, dtype=dtypes.float32, partition_info=None): del partition_info # Unused by this override. diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py index 321375ddfc..911c5a210c 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py @@ -118,7 +118,7 @@ class LoadAndRemapMatrixTest(test.TestCase): # No column remapping, new weight matrix has second row, then first row. row_remapping = [1, 0] - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -128,12 +128,12 @@ class LoadAndRemapMatrixTest(test.TestCase): num_cols=self.old_num_cols) with self.test_session(): self.assertAllClose(self.matrix_value[row_remapping], - remapped_weight_matrix.eval()) + remapped_matrix.eval()) # No row remapping, new weight matrix has third col, then first col. row_remapping = list(range(self.old_num_rows)) col_remapping = [2, 0] - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -143,12 +143,12 @@ class LoadAndRemapMatrixTest(test.TestCase): num_cols=len(col_remapping)) with self.test_session(): self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], - remapped_weight_matrix.eval()) + remapped_matrix.eval()) # Both row and column remappings. row_remapping = [1, 0, 4] col_remapping = [1, 15] - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -158,12 +158,12 @@ class LoadAndRemapMatrixTest(test.TestCase): num_cols=len(col_remapping)) with self.test_session(): self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping], - remapped_weight_matrix.eval()) + remapped_matrix.eval()) def test_load_and_remap_with_init(self): """Tests the op's load and remap where there are missing entries.""" init_val = 42 - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[2, -1, 0], @@ -172,18 +172,17 @@ class LoadAndRemapMatrixTest(test.TestCase): num_rows=3, num_cols=2) - expected_remapped_weight_matrix = np.reshape( + expected_remapped_matrix = np.reshape( [33, init_val, init_val, init_val, 1, init_val], [3, 2]) with self.test_session(): - self.assertAllClose(expected_remapped_weight_matrix, - remapped_weight_matrix.eval()) + self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval()) def test_load_and_remap_all_missing_rows(self): """Tests when all the rows are missing and need to be initialized.""" num_rows = 7 initializing_values = [42] * num_rows * self.old_num_cols - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[-1] * num_rows, @@ -194,14 +193,14 @@ class LoadAndRemapMatrixTest(test.TestCase): with self.test_session(): self.assertAllClose( np.reshape(initializing_values, (num_rows, self.old_num_cols)), - remapped_weight_matrix.eval()) + remapped_matrix.eval()) def test_load_and_remap_all_missing_rows_and_cols(self): """Tests when all the rows & cols are missing and need to be initialized.""" num_rows = 7 num_cols = 4 initializing_values = [42] * num_rows * num_cols - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[-1] * num_rows, @@ -212,42 +211,216 @@ class LoadAndRemapMatrixTest(test.TestCase): with self.test_session(): self.assertAllClose( np.reshape(initializing_values, (num_rows, num_cols)), - remapped_weight_matrix.eval()) + remapped_matrix.eval()) - def test_load_and_remap_duplicate_row_remapping(self): - """Tests when an old row maps to multiple new rows. + def test_load_and_remap_invalid_remapping(self): + """Tests that errors are raised when an ID maps to multiple new IDs. (This should usually not happen when using public APIs). """ - row_remapping = [1, 0, 0, 0, 1, 2] - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + invalid_remapping = [1, 0, 0, 0, 1, 2] + + # Invalid row remapping. + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, - row_remapping=row_remapping, + row_remapping=invalid_remapping, col_remapping=[], initializing_values=[], - num_rows=len(row_remapping), + num_rows=len(invalid_remapping), num_cols=self.old_num_cols) - with self.test_session(): - self.assertAllClose(self.matrix_value[row_remapping], - remapped_weight_matrix.eval()) - - def test_load_and_remap_invalid_col_remapping(self): - """Tests that an error is raised when an old col maps to multiple new cols. + with self.test_session(), self.assertRaises(errors.UnimplementedError): + remapped_matrix.eval() - (This should usually not happen when using public APIs). - """ - col_remapping = [1, 0, 0, 0, 1, 2] - remapped_weight_matrix = gen_checkpoint_ops.load_and_remap_matrix( + # Invalid column remapping. + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=list(range(self.old_num_rows)), - col_remapping=col_remapping, + col_remapping=invalid_remapping, initializing_values=[], num_rows=self.old_num_rows, - num_cols=len(col_remapping)) + num_cols=len(invalid_remapping)) with self.test_session(), self.assertRaises(errors.UnimplementedError): - remapped_weight_matrix.eval() + remapped_matrix.eval() + + def test_load_and_remap_incorrect_initializing_values(self): + """Tests that errors are raised with incorrect number of init values.""" + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( + ckpt_path=[self.bundle_file], + old_tensor_name=self.old_tensor_name, + row_remapping=[2, -1, 0], + col_remapping=[1, -1], + # Too few initializing values - there should be 4. For some reason, + # initializing_values must contain no element (instead of 3 or fewer) to + # ensure that a seg fault would reliably occur if the check raising the + # InvalidArgumentError were not present. + initializing_values=[], + num_rows=3, + num_cols=2) + with self.test_session(), self.assertRaises(errors.InvalidArgumentError): + remapped_matrix.eval() + + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( + ckpt_path=[self.bundle_file], + old_tensor_name=self.old_tensor_name, + row_remapping=[2, -1, 0], + col_remapping=[1, -1], + # Too many initializing values - there should be 4. + initializing_values=[0] * 5, + num_rows=3, + num_cols=2) + with self.test_session(), self.assertRaises(errors.InvalidArgumentError): + remapped_matrix.eval() + + +class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): + """Tests for the load_and_remap_matrix() op. + + (Specifically focused on the max_rows_in_memory arg and its effects on + TensorBundle's BundleReader and TensorSlice logic). + """ + + def _test_loading_variable_with_max_rows(self, np_value, partitioner, + max_rows_in_memory): + """Helper function for various tests using max_rows_in_memory.""" + ops.reset_default_graph() + old_tensor_name = 'matrix_to_load_and_remap' + matrix = variable_scope.get_variable( + old_tensor_name, + dtype=dtypes.float32, + initializer=constant_op.constant(np_value, dtype=dtypes.float32), + partitioner=partitioner) + + with self.test_session() as sess: + ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt') + save = saver.Saver([matrix]) + variables.global_variables_initializer().run() + save.save(sess, ckpt_path) + num_rows, num_cols = np_value.shape + + # Tests loading the entire tensor (except reversed). + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( + ckpt_path=ckpt_path, + old_tensor_name=old_tensor_name, + # Simply reverses the rows of the matrix. + row_remapping=list(range(num_rows - 1, -1, -1)), + col_remapping=[], + initializing_values=[], + num_rows=num_rows, + num_cols=num_cols, + max_rows_in_memory=max_rows_in_memory) + self.assertAllClose(np_value[::-1], remapped_matrix.eval()) + + # Tests loading the tensor (except for the first and last rows), with + # uninitialized values. Requires num_rows to be at least 3 since we're + # skipping the first and last rows. + self.assertGreater(num_rows, 2) + prefix_rows = 2 + suffix_rows = 3 + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( + ckpt_path=ckpt_path, + old_tensor_name=old_tensor_name, + # Reverses the rows of the matrix, then prepends and appends + # uninitialized rows. + row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) + + [-1] * suffix_rows), + col_remapping=[], + initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols, + num_rows=num_rows - 2 + prefix_rows + suffix_rows, + num_cols=num_cols, + max_rows_in_memory=max_rows_in_memory) + self.assertAllClose( + np.vstack([ + np.tile(42, [prefix_rows, num_cols]), np_value[1:-1], + np.tile(42, [suffix_rows, num_cols]) + ]), remapped_matrix.eval()) + + # Tests when everything is taken from initializing_values. + new_rows = 7 + initializing_values = [42] * new_rows * num_cols + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( + ckpt_path=ckpt_path, + old_tensor_name=old_tensor_name, + # Nothing is loaded from the old tensor. + row_remapping=[-1] * new_rows, + col_remapping=[], + initializing_values=initializing_values, + num_rows=new_rows, + num_cols=num_cols, + max_rows_in_memory=max_rows_in_memory) + self.assertAllClose( + np.reshape(initializing_values, (new_rows, num_cols)), + remapped_matrix.eval()) + + def test_loading_rows_divisible_by_max_rows(self): + """Tests loading normal var when rows are evenly divisible by max_rows.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=None, + # 9 is evenly divisible by 3. + max_rows_in_memory=3) + + def test_loading_rows_not_divisible_by_max_rows(self): + """Tests loading normal var when rows aren't divisible by max_rows.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=None, + # 9 is not evenly divisible by 4. + max_rows_in_memory=4) + + def test_loading_rows_less_than_max_rows(self): + """Tests loading normal var as a single slice. + + (When the specified max_rows_in_memory is larger than the number of rows) + """ + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=None, + # 10 > 9. + max_rows_in_memory=10) + + def test_loading_no_max_rows(self): + """Tests loading normal var as a single slice with no valid max_rows.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 18)), (6, 3)), + partitioner=None, + max_rows_in_memory=-1) + + def test_loading_partitions_equals_max_rows(self): + """Tests loading partitioned var sliced on partition boundary.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=partitioned_variables.fixed_size_partitioner(3), + # With a tensor of shape [9, 3] and 3 partitions, each partition has + # exactly 3 rows. + max_rows_in_memory=3) + + def test_loading_partitions_greater_than_max_rows(self): + """Tests loading partitioned var with more slices than partitions.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=partitioned_variables.fixed_size_partitioner(3), + # Even though each partition has 3 rows, we'll only load the tensor one + # row at a time. + max_rows_in_memory=1) + + def test_loading_partitions_less_than_max_rows(self): + """Tests loading partitioned var as a single slice. + + (When the specified max_rows_in_memory is larger than the number of rows) + """ + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=partitioned_variables.fixed_size_partitioner(3), + max_rows_in_memory=10) + + def test_loading_partitions_no_max_rows(self): + """Tests loading partitioned var as single slice with no valid max_rows.""" + self._test_loading_variable_with_max_rows( + np_value=np.reshape(list(range(0, 36)), (9, 4)), + partitioner=partitioned_variables.fixed_size_partitioner(3), + max_rows_in_memory=-1) class LoadAndRemapWrappersTest(test.TestCase): -- cgit v1.2.3 From 55f987692a25645a9db06e915c3fa248c3e5193c Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Wed, 7 Jun 2017 11:16:59 -0700 Subject: Make tf.contrib.lookup python functions use the kernels v2 that uses the resource tensor as handler. PiperOrigin-RevId: 158291836 --- .../learn/python/learn/estimators/estimator.py | 4 +- tensorflow/contrib/lookup/lookup_ops.py | 1130 ++------------------ tensorflow/core/public/version.h | 3 +- tensorflow/python/estimator/estimator.py | 4 +- tensorflow/python/estimator/estimator_test.py | 2 +- tensorflow/python/training/saver_test_utils.py | 12 +- 6 files changed, 90 insertions(+), 1065 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index d7ba2209ad..534aac644a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -209,7 +209,9 @@ def _get_replica_device_setter(config): """ ps_ops = [ 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', - 'MutableHashTableOfTensors', 'MutableDenseHashTable' + 'MutableHashTableV2', 'MutableHashTableOfTensors', + 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', + 'MutableDenseHashTableV2' ] if config.task_type: diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 7600d30539..d5d413c56a 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -18,807 +18,32 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import functools - -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_lookup_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops +from tensorflow.python.ops import lookup_ops +# pylint: disable=unused-import +from tensorflow.python.ops.lookup_ops import FastHashSpec +from tensorflow.python.ops.lookup_ops import HasherSpec +from tensorflow.python.ops.lookup_ops import HashTable +from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets +from tensorflow.python.ops.lookup_ops import index_table_from_file +from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file +from tensorflow.python.ops.lookup_ops import InitializableLookupTableBase +from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer +from tensorflow.python.ops.lookup_ops import LookupInterface +from tensorflow.python.ops.lookup_ops import StrongHashSpec +from tensorflow.python.ops.lookup_ops import TableInitializerBase +from tensorflow.python.ops.lookup_ops import TextFileIdTableInitializer +from tensorflow.python.ops.lookup_ops import TextFileIndex +from tensorflow.python.ops.lookup_ops import TextFileInitializer +from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer +# pylint: enable=unused-import from tensorflow.python.training.saver import BaseSaverBuilder -from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated -class LookupInterface(object): - """Represent a lookup table that persists across different steps.""" - - def __init__(self, key_dtype, value_dtype, name): - """Construct a lookup table interface. - - Args: - key_dtype: The table key type. - value_dtype: The table value type. - name: A name for the operation (optional). - """ - self._key_dtype = dtypes.as_dtype(key_dtype) - self._value_dtype = dtypes.as_dtype(value_dtype) - self._name = name - - @property - def key_dtype(self): - """The table key dtype.""" - return self._key_dtype - - @property - def value_dtype(self): - """The table value dtype.""" - return self._value_dtype - - @property - def name(self): - """The name of the table.""" - return self._name - - @property - def init(self): - """The table initialization op.""" - raise NotImplementedError - - def size(self, name=None): - """Compute the number of elements in this table.""" - raise NotImplementedError - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values.""" - raise NotImplementedError - - def check_table_dtypes(self, key_dtype, value_dtype): - """Check that the given key_dtype and value_dtype matches the table dtypes. - - Args: - key_dtype: The key data type to check. - value_dtype: The value data type to check. - - Raises: - TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data - types. - """ - if key_dtype != self.key_dtype: - raise TypeError("Invalid key dtype, expected %s but got %s." % - (self.key_dtype, key_dtype)) - if value_dtype != self.value_dtype: - raise TypeError("Invalid value dtype, expected %s but got %s." % - (self.value_dtype, value_dtype)) - - -class InitializableLookupTableBase(LookupInterface): - """Initializable lookup table interface. - - An initializable lookup tables persist across different steps. - """ - - def __init__(self, table_ref, default_value, initializer): - """Construct a table object from a table reference. - - If requires a table initializer object (subclass of `TableInitializerBase`). - It provides the table key and value types, as well as the op to initialize - the table. The caller is responsible to execute the initialization op. - - Args: - table_ref: The table reference, i.e. the output of the lookup table ops. - default_value: The value to use if a key is missing in the table. - initializer: The table initializer to use. - """ - super(InitializableLookupTableBase, self).__init__( - initializer.key_dtype, initializer.value_dtype, - table_ref.op.name.split("/")[-1]) - self._table_ref = table_ref - self._default_value = ops.convert_to_tensor(default_value, - dtype=self._value_dtype) - self._default_value.get_shape().merge_with(tensor_shape.scalar()) - self._init = initializer.initialize(self) - - @property - def table_ref(self): - """Get the underlying table reference.""" - return self._table_ref - - @property - def default_value(self): - """The default value of the table.""" - return self._default_value - - @property - def init(self): - """The table initialization op.""" - return self._init - - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self._name, - [self._table_ref]) as scope: - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope) - # pylint: enable=protected-access - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. - name: A name for the operation (optional). - - Returns: - A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`. - - Raises: - TypeError: when `keys` or `default_value` doesn't match the table data - types. - """ - key_tensor = keys - if isinstance(keys, sparse_tensor.SparseTensor): - key_tensor = keys.values - - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_Lookup" % self._name, - (self._table_ref, key_tensor, self._default_value)) as scope: - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find( - self._table_ref, key_tensor, self._default_value, name=scope) - # pylint: enable=protected-access - - values.set_shape(key_tensor.get_shape()) - if isinstance(keys, sparse_tensor.SparseTensor): - return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) - else: - return values - - -class HashTable(InitializableLookupTableBase): - """A generic hash table implementation. - - Example usage: - - ```python - table = tf.contrib.lookup.HashTable( - tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1) - out = table.lookup(input_tensor). - table.init.run() - print out.eval() - ``` - """ - - def __init__(self, initializer, default_value, shared_name=None, name=None): - """Creates a non-initialized `HashTable` object. - - Creates a table, the type of its keys and values are specified by the - initializer. - Before using the table you will have to initialize it. After initialization - the table will be immutable. - - Args: - initializer: The table initializer to use. See `HashTable` kernel for - supported key and value types. - default_value: The value to use if a key is missing in the table. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. - name: A name for the operation (optional). - - Returns: - A `HashTable` object. - """ - with ops.name_scope( - name, "hash_table", (initializer, default_value)) as scope: - # pylint: disable=protected-access - table_ref = gen_lookup_ops._hash_table( - shared_name=shared_name, - key_dtype=initializer.key_dtype, - value_dtype=initializer.value_dtype, - name=scope) - # pylint: enable=protected-access - - super(HashTable, self).__init__(table_ref, default_value, initializer) - - -class TableInitializerBase(object): - """Base class for lookup table initializers.""" - - def __init__(self, key_dtype, value_dtype): - """Construct a table initializer object. - - Args: - key_dtype: Type of the table keys. - value_dtype: Type of the table values. - """ - self._key_dtype = dtypes.as_dtype(key_dtype) - self._value_dtype = dtypes.as_dtype(value_dtype) - - @property - def key_dtype(self): - """The expected table key dtype.""" - return self._key_dtype - - @property - def value_dtype(self): - """The expected table value dtype.""" - return self._value_dtype - - def initialize(self, table): - """Returns the table initialization op.""" - raise NotImplementedError - - -class KeyValueTensorInitializer(TableInitializerBase): - """Table initializers given `keys` and `values` tensors.""" - - def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): - """Constructs a table initializer object based on keys and values tensors. - - Args: - keys: The tensor for the keys. - values: The tensor for the values. - key_dtype: The `keys` data type. Used when `keys` is a python array. - value_dtype: The `values` data type. Used when `values` is a python array. - name: A name for the operation (optional). - """ - with ops.name_scope(name, "key_value_init", [keys, values]) as scope: - self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") - self._values = ops.convert_to_tensor(values, - dtype=value_dtype, - name="values") - self._name = scope - - super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, - self._values.dtype) - - def initialize(self, table): - """Initializes the given `table` with `keys` and `values` tensors. - - Args: - table: The table to initialize. - - Returns: - The operation that initializes the table. - - Raises: - TypeError: when the keys and values data types do not match the table - key and value data types. - """ - table.check_table_dtypes(self._keys.dtype, self._values.dtype) - with ops.name_scope( - self._name, - values=(table.table_ref, self._keys, self._values)) as scope: - # pylint: disable=protected-access - init_op = gen_lookup_ops._initialize_table( - table.table_ref, self._keys, self._values, name=scope) - # pylint: enable=protected-access - ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) - return init_op - - -class TextFileIndex(object): - WHOLE_LINE = -2 - LINE_NUMBER = -1 - - -class TextFileInitializer(TableInitializerBase): - """Table initializers from a text file. - - This initializer assigns one entry in the table for each line in the file. - - The key and value type of the table to initialize is given by `key_dtype` and - `value_dtype`. - - The key and value content to get from each line is specified by - the `key_index` and `value_index`. - - * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero, - expects data type int64. - * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data - type string. - * A value `>=0` means use the index (starting at zero) of the split line based - on `delimiter`. - - For example if we have a file with the following content: - - ``` - emerson 10 - lake 20 - palmer 30 - ``` - - The following snippet initializes a table with the first column as keys and - second column as values: - - * `emerson -> 10` - * `lake -> 20` - * `palmer -> 30` - - ```python - table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer( - "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1) - ... - table.init.run() - ``` - - Similarly to initialize the whole line as keys and the line number as values. - - * `emerson 10 -> 0` - * `lake 20 -> 1` - * `palmer 30 -> 2` - - ```python - table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer( - "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE, - tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1) - ... - table.init.run() - ``` - """ - - def __init__(self, - filename, - key_dtype, - key_index, - value_dtype, - value_index, - vocab_size=None, - delimiter="\t", - name=None): - """Constructs a table initializer object to populate from a text file. - - It generates one key-value pair per line. The type of table key and - value are specified by `key_dtype` and `value_dtype`, respectively. - Similarly the content of the key and value are specified by the key_index - and value_index. - - - TextFileIndex.LINE_NUMBER means use the line number starting from zero, - expects data type int64. - - TextFileIndex.WHOLE_LINE means use the whole line content, expects data - type string. - - A value >=0 means use the index (starting at zero) of the split line based - on `delimiter`. - - Args: - filename: The filename of the text file to be used for initialization. - The path must be accessible from wherever the graph is initialized - (eg. trainer or eval workers). The filename may be a scalar `Tensor`. - key_dtype: The `key` data type. - key_index: the index that represents information of a line to get the - table 'key' values from. - value_dtype: The `value` data type. - value_index: the index that represents information of a line to get the - table 'value' values from.' - vocab_size: The number of elements in the file, if known. - delimiter: The delimiter to separate fields in a line. - name: A name for the operation (optional). - - Raises: - ValueError: when the filename is empty, or when the table key and value - data types do not match the expected data types. - """ - if not isinstance(filename, ops.Tensor) and not filename: - raise ValueError("Filename required for %s." % name) - - key_dtype = dtypes.as_dtype(key_dtype) - value_dtype = dtypes.as_dtype(value_dtype) - - if key_index < -2: - raise ValueError("Invalid key index %s." % (key_index)) - - if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64: - raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." % - (dtypes.int64, key_dtype)) - if ((key_index == TextFileIndex.WHOLE_LINE) and - (not key_dtype.is_integer) and (key_dtype != dtypes.string)): - raise ValueError( - "Signature mismatch. Keys must be integer or string, got %s." % - key_dtype) - if value_index < -2: - raise ValueError("Invalid value index %s." % (value_index)) - - if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: - raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % - (dtypes.int64, value_dtype)) - if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string: - raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % - (dtypes.string, value_dtype)) - - if (vocab_size is not None) and (vocab_size <= 0): - raise ValueError("Invalid vocab_size %s." % vocab_size) - - self._filename = filename - self._key_index = key_index - self._value_index = value_index - self._vocab_size = vocab_size - self._delimiter = delimiter - self._name = name - - super(TextFileInitializer, self).__init__(key_dtype, value_dtype) - - def initialize(self, table): - """Initializes the table from a text file. - - Args: - table: The table to be initialized. - - Returns: - The operation that initializes the table. - - Raises: - TypeError: when the keys and values data types do not match the table - key and value data types. - """ - table.check_table_dtypes(self.key_dtype, self.value_dtype) - with ops.name_scope( - self._name, "text_file_init", (table.table_ref,)) as scope: - filename = ops.convert_to_tensor(self._filename, - dtypes.string, - name="asset_filepath") - # pylint: disable=protected-access - init_op = gen_lookup_ops._initialize_table_from_text_file( - table.table_ref, - filename, - self._key_index, - self._value_index, - -1 if self._vocab_size is None else self._vocab_size, - self._delimiter, - name=scope) - # pylint: enable=protected-access - ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) - # If the filename tensor is anything other than a string constant (e.g., if - # it is a placeholder) then it does not make sense to track it as an asset. - if constant_op.is_constant(filename): - ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) - return init_op - - -class TextFileStringTableInitializer(TextFileInitializer): - """Table initializer for `int64` IDs to string tables from a text file.""" - - def __init__(self, - filename, - key_column_index=TextFileIndex.LINE_NUMBER, - value_column_index=TextFileIndex.WHOLE_LINE, - vocab_size=None, - delimiter="\t", - name="text_file_string_table_init"): - """Constructs an initializer for an id-to-string table from a text file. - - It populates a table that its key and value types are int64 and string, - respectively. It generates one key-value pair per line. - The content of the key and value are specified by `key_column_index` - and `value_column_index`. - - - TextFileIndex.LINE_NUMBER means use the line number starting from zero, - expects data type int64. - - TextFileIndex.WHOLE_LINE means use the whole line content, expects data - type string. - - A value >=0 means use the index (starting at zero) of the split line based - on `delimiter`. - - Args: - filename: The filename of the text file to be used for initialization. - The path must be accessible from wherever the graph is initialized - (eg. trainer or eval workers). The filename may be a scalar `Tensor`. - key_column_index: The column index from the text file to get the keys - from. The default is 0 that represents the whole line content. - value_column_index: The column index from the text file to get the - values from. The default is to use the line number, starting from zero. - vocab_size: The number of elements in the file, if known. - delimiter: The delimiter to separate fields in a line. - name: Optional name for the op. - - Raises: - TypeError: when the filename is empty, or when the table key and value - data types do not match the expected data types. - """ - super(TextFileStringTableInitializer, self).__init__(filename, - dtypes.int64, - key_column_index, - dtypes.string, - value_column_index, - vocab_size=vocab_size, - delimiter=delimiter, - name=name) - - -class TextFileIdTableInitializer(TextFileInitializer): - """Table initializer for string to `int64` IDs tables from a text file.""" - - def __init__(self, - filename, - key_column_index=TextFileIndex.WHOLE_LINE, - value_column_index=TextFileIndex.LINE_NUMBER, - vocab_size=None, - delimiter="\t", - name="text_file_id_table_init", - key_dtype=dtypes.string): - """Constructs an initializer for an string-to-id table from a text file. - - It populates a table that its key and value types are string and int64, - respectively. It generates one key-value pair per line. - The content of the key and value are specified by the key_index - and value_index. - - - TextFileIndex.LINE_NUMBER means use the line number starting from zero, - expects data type int64. - - TextFileIndex.WHOLE_LINE means use the whole line content, expects data - type string. - - A value >=0 means use the index (starting at zero) of the split line based - on `delimiter`. - - Args: - filename: The filename of the text file to be used for initialization. - The path must be accessible from wherever the graph is initialized - (eg. trainer or eval workers). The filename may be a scalar `Tensor`. - key_column_index: The column index from the text file to get the `key` - values from. The default is to use the line number, starting from zero. - value_column_index: The column index from the text file ro get the `value` - values from. The default is 0 that represents the whole line content. - vocab_size: The number of elements in the file, if known. - delimiter: The delimiter to separate fields in a line. - name: Optional name for the op. - key_dtype: The `key` data type. - - Raises: - TypeError: when the filename is empty, or when the table key and value - data types do not match the expected data types. - """ - super(TextFileIdTableInitializer, self).__init__(filename, - key_dtype, - key_column_index, - dtypes.int64, - value_column_index, - vocab_size=vocab_size, - delimiter=delimiter, - name=name) - - -class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): - """A structure for the spec of the hashing function to use for hash buckets. - - `hasher` is the name of the hashing function to use (eg. "fasthash", - "stronghash"). - `key` is optional and specify the key to use for the hash function if - supported, currently only used by a strong hash. - - Fields: - hasher: The hasher name to use. - key: The key to be used by the hashing function, if required. - """ - __slots__ = () - - -FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name - - -class StrongHashSpec(HasherSpec): - """A structure to specify a key of the strong keyed hash spec. - - The strong hash requires a `key`, which is a list of 2 unsigned integer - numbers. These should be non-zero; random numbers generated from random.org - would be a fine choice. - - Fields: - key: The key to be used by the keyed hashing function. - """ - __slots__ = () - - def __new__(cls, key): - if len(key) != 2: - raise ValueError("key must have size 2, got %s." % len(key)) - - if not isinstance(key[0], compat.integral_types) or not isinstance( - key[1], compat.integral_types): - raise TypeError("Invalid key %s. Must be unsigned integer values." % key) - - return super(cls, StrongHashSpec).__new__(cls, "stronghash", key) - - -def _as_string(tensor): - if dtypes.string == tensor.dtype.base_dtype: - return tensor - return string_ops.as_string(tensor) - - -class IdTableWithHashBuckets(LookupInterface): - """String to Id table wrapper that assigns out-of-vocabulary keys to buckets. - - For example, if an instance of `IdTableWithHashBuckets` is initialized with a - string-to-id table that maps: - - - emerson -> 0 - - lake -> 1 - - palmer -> 2 - - The `IdTableWithHashBuckets` object will performs the following mapping: - - - emerson -> 0 - - lake -> 1 - - palmer -> 2 - - -> bucket id between 3 and 3 + num_oov_buckets, calculated by: - hash() % num_oov_buckets + vocab_size - - If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, - the lookup result is `[0, 1, 2, 4, 7]` - - If `table` is None, only out-of-vocabulary buckets are used. - - Example usage: - - ```python - num_oov_buckets = 3 - input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) - table = tf.IdTableWithHashBuckets( - tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value), - num_oov_buckets) - out = table.lookup(input_tensor). - table.init.run() - print out.eval() - ``` - - The hash function used for generating out-of-vocabulary buckets ID is handled - by `hasher_spec`. - """ - - def __init__(self, - table, - num_oov_buckets, - hasher_spec=FastHashSpec, - name=None, - key_dtype=None): - """Construct a `IdTableWithHashBuckets` object. - - Args: - table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. - num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. - hasher_spec: A `HasherSpec` to specify the hash function to use for - assignation of out-of-vocabulary buckets (optional). - name: A name for the operation (optional). - key_dtype: Data type of keys passed to `lookup`. Defaults to - `table.key_dtype` if `table` is specified, otherwise `tf.string`. - Must be string or integer, and must be castable to `table.key_dtype`. - - Raises: - ValueError: when `table` in None and `num_oov_buckets` is not positive. - TypeError: when `hasher_spec` is invalid. - """ - # If a name ends with a '/' it is a "name scope", remove all trailing '/' - # characters to use as table name. - if name: - name = name.rstrip("/") - if table: - if key_dtype is None: - key_dtype = table.key_dtype - supported_table_key_dtypes = (dtypes.int64, dtypes.string) - if table.key_dtype not in supported_table_key_dtypes: - raise TypeError("Invalid key dtype, expected one of %s, but got %s." % - (supported_table_key_dtypes, key_dtype)) - if table.key_dtype.is_integer != key_dtype.is_integer: - raise TypeError("Invalid key dtype, expected %s but got %s." % - ("integer" if key_dtype.is_integer else "non-integer", - table.key_dtype)) - if table.value_dtype != dtypes.int64: - raise TypeError("Invalid value dtype, expected %s but got %s." % - (dtypes.int64, table.value_dtype)) - self._table = table - name = name or self._table.name - else: - if num_oov_buckets <= 0: - raise ValueError("oov_buckets must be > 0 if no table is supplied.") - key_dtype = dtypes.string if key_dtype is None else key_dtype - self._table = None - name = name or "hash_bucket" - if (not key_dtype.is_integer) and (dtypes.string != key_dtype): - raise TypeError( - "Invalid key_dtype, expected integer or string, got %s." % key_dtype) - self._num_oov_buckets = num_oov_buckets - - if not isinstance(hasher_spec, HasherSpec): - raise TypeError("hasher_spec must be of type HasherSpec, got %s" % - hasher_spec) - self._hasher_spec = hasher_spec - super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, - name.split("/")[-1]) - - @property - def init(self): - """The table initialization op.""" - if self._table: - return self._table.init - with ops.name_scope(None, "init"): - return control_flow_ops.no_op() - - def size(self, name=None): - """Compute the number of elements in this table.""" - with ops.name_scope(name, "%s_Size" % self.name) as scope: - if self._table: - tsize = self._table.size(scope) - else: - tsize = ops.convert_to_tensor(0, dtype=dtypes.int64) - return tsize + self._num_oov_buckets - - def _get_string_to_hash_bucket_fn(self, hasher_spec): - """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" - if not isinstance(hasher_spec, HasherSpec): - raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) - if hasher_spec.hasher == "fasthash": - return string_ops.string_to_hash_bucket_fast - if hasher_spec.hasher == "legacy": - return string_ops.string_to_hash_bucket - if hasher_spec.hasher == "stronghash": - return functools.partial( - string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) - raise ValueError("Unknown hasher %s" % hasher_spec.hasher) - - def lookup(self, keys, name=None): - """Looks up `keys` in the table, outputs the corresponding values. - - It assigns out-of-vocabulary keys to buckets based in their hashes. - - Args: - keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`. - name: Optional name for the op. - - Returns: - A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`. - - Raises: - TypeError: when `keys` doesn't match the table key data type. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - values = keys - if isinstance(keys, sparse_tensor.SparseTensor): - values = keys.values - if self._table and (self._table.key_dtype.base_dtype == dtypes.int64): - values = math_ops.to_int64(values) - - if self._num_oov_buckets == 0: - ids = self._table.lookup(values, name=name) - else: - # TODO(yleon): Consider moving this functionality to its own kernel. - with ops.name_scope(name, "%s_Lookup" % self.name) as scope: - str_to_hash_bucket = self._get_string_to_hash_bucket_fn( - self._hasher_spec) - buckets = str_to_hash_bucket( - _as_string(values), - num_buckets=self._num_oov_buckets, - name="hash_bucket") - if self._table: - ids = self._table.lookup(values) - buckets = math_ops.add(buckets, self._table.size()) - is_id_non_default = math_ops.not_equal(ids, self._table.default_value) - ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) - else: - ids = buckets - if isinstance(keys, sparse_tensor.SparseTensor): - return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape) - return ids - - @deprecated("2017-04-10", "Use `index_table_from_file`.") def string_to_index_table_from_file(vocabulary_file=None, num_oov_buckets=0, @@ -831,113 +56,6 @@ def string_to_index_table_from_file(vocabulary_file=None, key_dtype=dtypes.string, name=name) -def index_table_from_file(vocabulary_file=None, - num_oov_buckets=0, - vocab_size=None, - default_value=-1, - hasher_spec=FastHashSpec, - key_dtype=dtypes.string, - name=None): - """Returns a lookup table that converts a string tensor into int64 IDs. - - This operation constructs a lookup table to convert tensor of strings into - int64 IDs. The mapping can be initialized from a vocabulary file specified in - `vocabulary_file`, where the whole line is the key and the zero-based line - number is the ID. - - Any lookup of an out-of-vocabulary token will return a bucket ID based on its - hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the - `default_value`. - The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`. - - The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.init.run()` once. - - Sample Usages: - - If we have a vocabulary file "test.txt" with the following content: - - ``` - emerson - lake - palmer - ``` - - ```python - features = tf.constant(["emerson", "lake", "and", "palmer"]) - table = tf.contrib.lookup.index_table_from_file( - vocabulary_file="test.txt", num_oov_buckets=1) - ids = table.lookup(features) - ... - tf.tables_initializer().run() - - ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket - ``` - - Args: - vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. - num_oov_buckets: The number of out-of-vocabulary buckets. - vocab_size: Number of the elements in the vocabulary, if known. - default_value: The value to use for out-of-vocabulary feature values. - Defaults to -1. - hasher_spec: A `HasherSpec` to specify the hash function to use for - assignation of out-of-vocabulary buckets. - key_dtype: The `key` data type. - name: A name for this op (optional). - - Returns: - The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`. - - Raises: - ValueError: If `vocabulary_file` is not set. - ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater - than zero. - """ - if vocabulary_file is None or ( - isinstance(vocabulary_file, str) and not vocabulary_file): - raise ValueError("vocabulary_file must be specified and must not be empty.") - if num_oov_buckets < 0: - raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." - % num_oov_buckets) - if vocab_size is not None and vocab_size < 1: - raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size) - if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype): - raise TypeError("Only integer and string keys are supported.") - - with ops.name_scope(name, "string_to_index") as feat_to_id_scope: - table = None - shared_name = "" - with ops.name_scope(None, "hash_table") as hash_table_scope: - if vocab_size: - # Keep the shared_name: - # ____ - shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) - else: - # Keep the shared_name - # ___ - shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) - init = TextFileIdTableInitializer( - vocabulary_file, vocab_size=vocab_size, - key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype, - name="table_init") - - table = HashTable( - init, default_value, shared_name=shared_name, name=hash_table_scope) - if num_oov_buckets: - table = IdTableWithHashBuckets( - table, - num_oov_buckets=num_oov_buckets, - hasher_spec=hasher_spec, - name=feat_to_id_scope, - key_dtype=key_dtype) - - return table - - @deprecated("2017-04-10", "Use `index_table_from_tensor`.") def string_to_index_table_from_tensor(mapping, num_oov_buckets=0, @@ -1011,41 +129,13 @@ def index_table_from_tensor(mapping, """ if mapping is None: raise ValueError("mapping must be specified.") - - if num_oov_buckets < 0: - raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." - % num_oov_buckets) - - if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype): - raise TypeError("Only integer and string keys are supported.") - - with ops.name_scope(name, "string_to_index") as feat_to_id_scope: - keys = ops.convert_to_tensor(mapping) - if keys.dtype.is_integer != dtype.is_integer: - raise ValueError("Expected %s, got %s." % ( - "integer" if dtype.is_integer else "non-integer", keys.dtype)) - if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype): - raise ValueError("Expected %s, got %s." % (dtype, keys.dtype)) - num_elements = array_ops.size(keys) - values = math_ops.to_int64(math_ops.range(num_elements)) - - shared_name = "" - with ops.name_scope(None, "hash_table") as hash_table_scope: - table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys - init = KeyValueTensorInitializer( - table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, - name="table_init") - table = HashTable( - init, default_value, shared_name=shared_name, name=hash_table_scope) - if num_oov_buckets: - table = IdTableWithHashBuckets( - table, - num_oov_buckets=num_oov_buckets, - hasher_spec=hasher_spec, - name=feat_to_id_scope, - key_dtype=dtype) - - return table + return lookup_ops.index_table_from_tensor( + vocabulary_list=mapping, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + hasher_spec=hasher_spec, + dtype=dtype, + name=name) @deprecated( @@ -1098,83 +188,6 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): return table.lookup(tensor) -def index_to_string_table_from_file(vocabulary_file, - vocab_size=None, - default_value="UNK", - name=None): - """Returns a lookup table that maps a `Tensor` of indices into strings. - - This operation constructs a lookup table to map int64 indices into string - values. The table is initialized from a vocabulary file specified in - `vocabulary_file`, where the whole line is the value and the - zero-based line number is the index. - - Any input which does not have a corresponding index in the vocabulary file - (an out-of-vocabulary entry) is assigned the `default_value` - - The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.init.run()` once. - - Sample Usages: - - If we have a vocabulary file "test.txt" with the following content: - - ``` - emerson - lake - palmer - ``` - - ```python - indices = tf.constant([1, 5], tf.int64) - table = tf.contrib.lookup.index_to_string_table_from_file( - vocabulary_file="test.txt", default_value="UNKNOWN") - values = table.lookup(indices) - ... - tf.tables_initializer().run() - - values.eval() ==> ["lake", "UNKNOWN"] - ``` - - Args: - vocabulary_file: The vocabulary filename. - vocab_size: Number of the elements in the vocabulary, if known. - default_value: The value to use for out-of-vocabulary indices. - name: A name for this op (optional). - - Returns: - The lookup table to map a string values associated to a given index `int64` - `Tensors`. - - Raises: - ValueError: when `vocabulary_file` is empty. - ValueError: when `vocab_size` is invalid. - """ - if not vocabulary_file: - raise ValueError("vocabulary_file must be specified.") - if vocab_size is not None and vocab_size < 1: - raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size) - - with ops.name_scope(name, "index_to_string") as scope: - shared_name = "" - if vocab_size: - # Keep a shared_name - # ____ - shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) - else: - # Keep a shared_name ___ - shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) - init = TextFileStringTableInitializer( - vocabulary_file, vocab_size=vocab_size, name="table_init") - - # TODO(yleon): Use a more effienct structure. - return HashTable(init, default_value, shared_name=shared_name, name=scope) - - def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): """Returns a lookup table that maps a `Tensor` of indices into strings. @@ -1223,16 +236,8 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): if mapping is None: raise ValueError("mapping must be specified.") - with ops.name_scope(name, "index_to_string") as scope: - values = ops.convert_to_tensor(mapping, dtypes.string) - num_elements = array_ops.size(values) - keys = math_ops.to_int64(math_ops.range(num_elements)) - - shared_name = "" - init = KeyValueTensorInitializer( - keys, values, dtypes.int64, dtypes.string, name="table_init") - # TODO(yleon): Use a more effienct structure. - return HashTable(init, default_value, shared_name=shared_name, name=scope) + return lookup_ops.index_to_string_table_from_tensor( + vocabulary_list=mapping, default_value=default_value, name=name) @deprecated( @@ -1338,14 +343,14 @@ class MutableHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None # pylint: disable=protected-access if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_lookup_ops._mutable_hash_table( + self._table_ref = gen_lookup_ops._mutable_hash_table_v2( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, name=name) else: - self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors( + self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors_v2( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, @@ -1372,8 +377,10 @@ class MutableHashTable(LookupInterface): """ with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) + with ops.colocate_with(self._table_ref): + + # pylint: disable=protected-access + return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1398,11 +405,12 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find( - self._table_ref, keys, self._default_value, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + values = gen_lookup_ops._lookup_table_find_v2( + self._table_ref, keys, self._default_value, name=name) - values.set_shape(keys.get_shape().concatenate(self._value_shape)) + values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values def insert(self, keys, values, name=None): @@ -1422,13 +430,16 @@ class MutableHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. """ - self.check_table_dtypes(keys.dtype, values.dtype) + # pylint: disable=protected-access + lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) + # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: - # pylint: disable=protected-access - op = gen_lookup_ops._lookup_table_insert( - self._table_ref, keys, values, name=name) - return op + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + op = gen_lookup_ops._lookup_table_insert_v2( + self._table_ref, keys, values, name=name) + return op def export(self, name=None): """Returns tensors of all keys and values in the table. @@ -1442,9 +453,10 @@ class MutableHashTable(LookupInterface): """ with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: - # pylint: disable=protected-access - exported_keys, exported_values = gen_lookup_ops._lookup_table_export( - self._table_ref, self._key_dtype, self._value_dtype, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2( + self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -1464,8 +476,9 @@ class MutableHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_import( - self.op._table_ref, restored_tensors[0], restored_tensors[1]) + with ops.colocate_with(self.op._table_ref): + return gen_lookup_ops._lookup_table_import_v2( + self.op._table_ref, restored_tensors[0], restored_tensors[1]) class MutableDenseHashTable(LookupInterface): @@ -1539,7 +552,7 @@ class MutableDenseHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) # pylint: disable=protected-access - self._table_ref = gen_lookup_ops._mutable_dense_hash_table( + self._table_ref = gen_lookup_ops._mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, @@ -1566,8 +579,9 @@ class MutableDenseHashTable(LookupInterface): """ with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1592,9 +606,10 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find( - self._table_ref, keys, self._default_value, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + values = gen_lookup_ops._lookup_table_find_v2( + self._table_ref, keys, self._default_value, name=name) if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: values.set_shape( @@ -1619,12 +634,15 @@ class MutableDenseHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. """ - self.check_table_dtypes(keys.dtype, values.dtype) + # pylint: disable=protected-access + lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) + # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: - # pylint: disable=protected-access - op = gen_lookup_ops._lookup_table_insert( - self._table_ref, keys, values, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + op = gen_lookup_ops._lookup_table_insert_v2( + self._table_ref, keys, values, name=name) return op def export(self, name=None): @@ -1639,9 +657,10 @@ class MutableDenseHashTable(LookupInterface): """ with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: - # pylint: disable=protected-access - exported_keys, exported_values = gen_lookup_ops._lookup_table_export( - self._table_ref, self._key_dtype, self._value_dtype, name=name) + with ops.colocate_with(self._table_ref): + # pylint: disable=protected-access + exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2( + self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -1661,5 +680,6 @@ class MutableDenseHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_import( - self.op._table_ref, restored_tensors[0], restored_tensors[1]) + with ops.colocate_with(self.op._table_ref): + return gen_lookup_ops._lookup_table_import_v2( + self.op._table_ref, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 566d9aa908..57ff12dcd7 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -88,10 +88,11 @@ limitations under the License. // shapes, particularly when restoring a graph from GraphDef // produced at version 22 or later. (04/10/2016) // 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2. +// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017) #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 23 +#define TF_GRAPH_DEF_VERSION 24 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 944073a4ca..f424598ccb 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -725,7 +725,9 @@ def _get_replica_device_setter(config): """ ps_ops = [ 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', - 'MutableHashTableOfTensors', 'MutableDenseHashTable' + 'MutableHashTableV2', 'MutableHashTableOfTensors', + 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', + 'MutableDenseHashTableV2' ] if config.task_type: diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index c8aab5dac8..4119a07bd8 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1286,7 +1286,7 @@ class EstimatorExportTest(test.TestCase): self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops) # Note that the SavedModel builder replaced the Saver with a new one - self.assertTrue('save_1/LookupTableImport' in graph_ops) + self.assertTrue('save_1/LookupTableImportV2' in graph_ops) # Clean up. gfile.DeleteRecursively(tmpdir) diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py index 6a73565f82..bcabb41304 100644 --- a/tensorflow/python/training/saver_test_utils.py +++ b/tensorflow/python/training/saver_test_utils.py @@ -34,7 +34,7 @@ class CheckpointedOp(object): # pylint: disable=protected-access def __init__(self, name, table_ref=None): if table_ref is None: - self.table_ref = gen_lookup_ops._mutable_hash_table( + self.table_ref = gen_lookup_ops._mutable_hash_table_v2( key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) else: self.table_ref = table_ref @@ -52,10 +52,10 @@ class CheckpointedOp(object): return self._saveable def insert(self, keys, values): - return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values) + return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values) def lookup(self, keys, default): - return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default) + return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default) def keys(self): return self._export()[0] @@ -64,8 +64,8 @@ class CheckpointedOp(object): return self._export()[1] def _export(self): - return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string, - dtypes.float32) + return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string, + dtypes.float32) class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): """A custom saveable for CheckpointedOp.""" @@ -81,6 +81,6 @@ class CheckpointedOp(object): super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) def restore(self, restore_tensors, shapes): - return gen_lookup_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import_v2( self.op.table_ref, restore_tensors[0], restore_tensors[1]) # pylint: enable=protected-access -- cgit v1.2.3 From ba656b261141554f33b96c655e3a0c76eb0d837d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:18:16 -0700 Subject: Use template specialization instead of overloaded methods. This is a more appropriate tool here. NFC PiperOrigin-RevId: 158292035 --- .../layers/kernels/sparse_feature_cross_kernel.cc | 92 ++++++++++++---------- tensorflow/core/kernels/sparse_cross_op.cc | 92 ++++++++++++---------- 2 files changed, 98 insertions(+), 86 deletions(-) diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 219473153b..72df272af8 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -41,13 +41,7 @@ class ColumnInterface { virtual int64 FeatureCount(int64 batch) const = 0; // Returns the fingerprint of nth feature from the specified batch. - InternalType Feature(int64 batch, int64 n) const { - InternalType not_used = InternalType(); - return DoFeature(batch, n, not_used); - } - - virtual InternalType DoFeature(int64 batch, int64 n, - InternalType not_used) const = 0; + virtual InternalType Feature(int64 batch, int64 n) const = 0; virtual ~ColumnInterface() {} }; @@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface { return feature_counts_[batch]; } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); - return values_.vec().data()[start + n]; - } - - // InternalType is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; - return std::to_string(values_.vec().data()[start + n]); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; - } + InternalType Feature(int64 batch, int64 n) const override; ~SparseTensorColumn() override {} @@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface { std::vector feature_start_indices_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return Fingerprint64(values_.vec().data()[start + n]); + return values_.vec().data()[start + n]; +} + +// InternalType is string or StringPiece when using StringCrosser. +template <> +string SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return values_.vec().data()[start + n]; + return std::to_string(values_.vec().data()[start + n]); +} + +template <> +StringPiece SparseTensorColumn::Feature(int64 batch, + int64 n) const { + const int64 start = feature_start_indices_[batch]; + return values_.vec().data()[start + n]; +} + // A column that is backed by a dense tensor. template class DenseTensorColumn : public ColumnInterface { @@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface { int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); - return tensor_.matrix()(batch, n); - } - - // Internal type is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); - return std::to_string(tensor_.matrix()(batch, n)); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - return tensor_.matrix()(batch, n); - } + InternalType Feature(int64 batch, int64 n) const override; ~DenseTensorColumn() override {} @@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface { const Tensor& tensor_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) + return Fingerprint64(tensor_.matrix()(batch, n)); + return tensor_.matrix()(batch, n); +} + +// Internal type is string or StringPiece when using StringCrosser. +template <> +string DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + return std::to_string(tensor_.matrix()(batch, n)); +} + +template <> +StringPiece DenseTensorColumn::Feature(int64 batch, + int64 n) const { + return tensor_.matrix()(batch, n); +} + // Updates Output tensors with sparse crosses. template class OutputUpdater { diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index ed93caad33..c7bf250fad 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -41,13 +41,7 @@ class ColumnInterface { virtual int64 FeatureCount(int64 batch) const = 0; // Returns the fingerprint of nth feature from the specified batch. - InternalType Feature(int64 batch, int64 n) const { - InternalType not_used = InternalType(); - return DoFeature(batch, n, not_used); - } - - virtual InternalType DoFeature(int64 batch, int64 n, - InternalType not_used) const = 0; + virtual InternalType Feature(int64 batch, int64 n) const = 0; virtual ~ColumnInterface() {} }; @@ -68,26 +62,7 @@ class SparseTensorColumn : public ColumnInterface { return feature_counts_[batch]; } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); - return values_.vec().data()[start + n]; - } - - // InternalType is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - const int64 start = feature_start_indices_[batch]; - if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; - return std::to_string(values_.vec().data()[start + n]); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; - } + InternalType Feature(int64 batch, int64 n) const override; ~SparseTensorColumn() override {} @@ -97,6 +72,31 @@ class SparseTensorColumn : public ColumnInterface { std::vector feature_start_indices_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return Fingerprint64(values_.vec().data()[start + n]); + return values_.vec().data()[start + n]; +} + +// InternalType is string or StringPiece when using StringCrosser. +template <> +string SparseTensorColumn::Feature(int64 batch, int64 n) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return values_.vec().data()[start + n]; + return std::to_string(values_.vec().data()[start + n]); +} + +template <> +StringPiece SparseTensorColumn::Feature(int64 batch, + int64 n) const { + const int64 start = feature_start_indices_[batch]; + return values_.vec().data()[start + n]; +} + // A column that is backed by a dense tensor. template class DenseTensorColumn : public ColumnInterface { @@ -105,22 +105,7 @@ class DenseTensorColumn : public ColumnInterface { int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } - // InternalType is int64 only when using HashCrosser. - int64 DoFeature(int64 batch, int64 n, int64 not_used) const { - if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); - return tensor_.matrix()(batch, n); - } - - // Internal type is string or StringPiece when using StringCrosser. - string DoFeature(int64 batch, int64 n, string not_used) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); - return std::to_string(tensor_.matrix()(batch, n)); - } - - StringPiece DoFeature(int64 batch, int64 n, StringPiece not_used) const { - return tensor_.matrix()(batch, n); - } + InternalType Feature(int64 batch, int64 n) const override; ~DenseTensorColumn() override {} @@ -128,6 +113,27 @@ class DenseTensorColumn : public ColumnInterface { const Tensor& tensor_; }; +// InternalType is int64 only when using HashCrosser. +template <> +int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) + return Fingerprint64(tensor_.matrix()(batch, n)); + return tensor_.matrix()(batch, n); +} + +// Internal type is string or StringPiece when using StringCrosser. +template <> +string DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); + return std::to_string(tensor_.matrix()(batch, n)); +} + +template <> +StringPiece DenseTensorColumn::Feature(int64 batch, + int64 n) const { + return tensor_.matrix()(batch, n); +} + // Updates Output tensors with sparse crosses. template class OutputUpdater { -- cgit v1.2.3 From 94085bee74557f34fd7ad3bef969eecf6c8c4f4e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:20:44 -0700 Subject: Replace std::function object with regular function. The function is called recursively, and the std::function object had only existed to allow recursion from within a lambda expression. A regular function should be cheaper than a polymorphic function wrapper. PiperOrigin-RevId: 158292415 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 109 +++++++++++++-------- tensorflow/compiler/xla/service/hlo_instruction.h | 3 + 2 files changed, 71 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d713d826fb..ecbf1dd1e5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1306,7 +1306,7 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { void HloInstruction::DetachFromOperands() { CHECK_EQ(0, user_count()); - // An intruction may be repeated as an operand. To avoid calling RemoveUser + // An instruction may be repeated as an operand. To avoid calling RemoveUser // twice on the same operand, keep a set of already detached operands. std::set detached_operands; for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { @@ -2162,6 +2162,70 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { return true; } +// A helper class for memoized, recursive computation of HloOpcode::kFusion +// in HloInstruction::OperandElementUse below. +class HloInstruction::FusionReusesParamElements { + public: + using UseKind = HloInstruction::UseKind; + + // We could rather iterate backwards thru fused_instructions_ here, as it is + // in reverse postorder, and compute whether each fused instruction reuses the + // value of this parameter, which would save stack space but not allow us to + // finish early if we find a reuse. + static UseKind Compute(int64 i, const HloInstruction& hlo) { + tensorflow::gtl::FlatMap memoization_cache; + return ComputeInternal(i, hlo, &memoization_cache); + } + + private: + static UseKind ComputeInternal( + int64 i, const HloInstruction& hlo, + tensorflow::gtl::FlatMap* cache) { + if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { + return UseKind::kUse; + } + + auto p = cache->emplace(&hlo, UseKind{}); + auto value_it = p.first; + const bool key_is_new = p.second; + + if (key_is_new) { + for (int64 j = 0; j < hlo.operands_.size(); ++j) { + UseKind old_val = value_it->second; + + // The next operation invalidates iterators. + UseKind new_val = + Plus(old_val, std::min(hlo.OperandElementUse(j), + ComputeInternal(i, *hlo.operand(j), cache))); + + // Re-acquire the iterator. We could work harder to do this only if + // absolutely necessary, but this code is not hot enough to warrant + // that. + value_it = cache->find(&hlo); + value_it->second = new_val; + } + } + return value_it->second; + } + + // Fold operation for UseKinds. + static UseKind Plus(UseKind a, UseKind b) { + if (a == UseKind::kNoUse) { + return b; + } else if (b == UseKind::kNoUse) { + return a; + } else if (a == UseKind::kReuse || b == UseKind::kReuse) { + return UseKind::kReuse; + } else if (a == UseKind::kUsePermutingElements || + b == UseKind::kUsePermutingElements) { + return UseKind::kReuse; + } else { + CHECK(a == UseKind::kUse && b == UseKind::kUse); + return UseKind::kUse; + } + } +}; + HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { switch (opcode_) { case HloOpcode::kBitcast: @@ -2176,46 +2240,9 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { // Pad reuses the padding value but not the padded array elements. // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; - case HloOpcode::kFusion: { - tensorflow::gtl::FlatMap cache; - // We could rather iterate backwards thru fused_instructions_ here, as it - // is in reverse postorder, and compute whether each fused instruction - // reuses the value of this parameter, which would save stack space but - // not allow us to finish early if we find a reuse. - std::function reuses_parameter_elements = - [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) { - auto plus = [](const UseKind& a, const UseKind& b) { - if (a == UseKind::kNoUse) { - return b; - } else if (b == UseKind::kNoUse) { - return a; - } else if (a == UseKind::kReuse || b == UseKind::kReuse) { - return UseKind::kReuse; - } else if (a == UseKind::kUsePermutingElements || - b == UseKind::kUsePermutingElements) { - return UseKind::kReuse; - } - CHECK(UseKind::kUse == a && UseKind::kUse == b); - return UseKind::kUse; - }; - - if (hlo.opcode_ == HloOpcode::kParameter && - hlo.parameter_number_ == i) { - return UseKind::kUse; - } - if (!ContainsKey(cache, &hlo)) { - for (int64 j = 0; j < hlo.operands_.size(); ++j) { - UseKind old = cache[&hlo]; - UseKind updated = plus( - old, std::min(hlo.OperandElementUse(j), - reuses_parameter_elements(*hlo.operand(j)))); - cache[&hlo] = updated; - } - } - return cache[&hlo]; - }; - return reuses_parameter_elements(*fused_expression_root()); - } + case HloOpcode::kFusion: + // Uses the memoizing, recursive computation defined above. + return FusionReusesParamElements::Compute(i, *fused_expression_root()); default: return IsElementwise() ? UseKind::kUse : UseKind::kReuse; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3bf46341be..522414325e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -775,6 +775,9 @@ class HloInstruction { private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, -- cgit v1.2.3 From b702e7e79d4cf2abb87146d6593e3d3ee3c2bba7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:32:47 -0700 Subject: Update ops-related pbtxt files. PiperOrigin-RevId: 158294289 --- tensorflow/core/ops/ops.pbtxt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 568c7e0216..31d47c1ab3 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -9268,7 +9268,7 @@ op { type: DT_FLOAT } summary: "Inverse real-valued fast Fourier transform." - description: "Computes the inverse 1-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly." + description: "Computes the inverse 1-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most dimension of `input`.\n\nThe inner-most dimension of `input` is assumed to be the result of `RFFT`: the\n`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If\n`fft_length` is not provided, it is computed from the size of the inner-most\ndimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to\ncompute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller\nthan the corresponding dimension of `input`, the dimension is cropped. If it is\nlarger, the dimension is padded with zeros." } op { name: "IRFFT2D" @@ -9288,7 +9288,7 @@ op { type: DT_FLOAT } summary: "Inverse 2D real-valued fast Fourier transform." - description: "Computes the inverse 2-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." + description: "Computes the inverse 2-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 2 dimensions of `input`.\n\nThe inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 2 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong each axis `IRFFT2D` is computed on, if `fft_length` (or\n`fft_length / 2 + 1` for the inner-most dimension) is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros." } op { name: "IRFFT3D" @@ -9308,7 +9308,7 @@ op { type: DT_FLOAT } summary: "Inverse 3D real-valued fast Fourier transform." - description: "Computes the inverse 3-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly." + description: "Computes the inverse 3-dimensional discrete Fourier transform of a real-valued\nsignal over the inner-most 3 dimensions of `input`.\n\nThe inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:\nThe inner-most dimension contains the `fft_length / 2 + 1` unique components of\nthe DFT of a real-valued signal. If `fft_length` is not provided, it is computed\nfrom the size of the inner-most 3 dimensions of `input`. If the FFT length used\nto compute `input` is odd, it should be provided since it cannot be inferred\nproperly.\n\nAlong each axis `IRFFT3D` is computed on, if `fft_length` (or\n`fft_length / 2 + 1` for the inner-most dimension) is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros." } op { name: "Identity" @@ -16180,7 +16180,7 @@ op { type: DT_COMPLEX64 } summary: "Real-valued fast Fourier transform." - description: "Computes the 1-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms." + description: "Computes the 1-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most dimension of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the\n`fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,\nfollowed by the `fft_length / 2` positive-frequency terms.\n\nAlong the axis `RFFT` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros." } op { name: "RFFT2D" @@ -16200,7 +16200,7 @@ op { type: DT_COMPLEX64 } summary: "2D real-valued fast Fourier transform." - description: "Computes the 2-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." + description: "Computes the 2-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 2 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms.\n\nAlong each axis `RFFT2D` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros." } op { name: "RFFT3D" @@ -16220,7 +16220,7 @@ op { type: DT_COMPLEX64 } summary: "3D real-valued fast Fourier transform." - description: "Computes the 3-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms." + description: "Computes the 3-dimensional discrete Fourier transform of a real-valued signal\nover the inner-most 3 dimensions of `input`.\n\nSince the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the\n`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension\nof `output`: the zero-frequency term, followed by the `fft_length / 2`\npositive-frequency terms.\n\nAlong each axis `RFFT3D` is computed on, if `fft_length` is smaller than the\ncorresponding dimension of `input`, the dimension is cropped. If it is larger,\nthe dimension is padded with zeros." } op { name: "RGBToHSV" -- cgit v1.2.3 From abe0877ef1abfbd678ba05f195c553303cc8263c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:34:47 -0700 Subject: Add bazel version check to .configure PiperOrigin-RevId: 158294569 --- configure | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/configure b/configure index 472f099796..71c14345f5 100755 --- a/configure +++ b/configure @@ -3,6 +3,8 @@ set -e set -o pipefail +MIN_BAZEL_VERSION=0.4.5 + # Find out the absolute path to where ./configure resides pushd `dirname $0` > /dev/null SOURCE_BASE_DIR=`pwd -P` @@ -163,6 +165,22 @@ function setup_python { echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh } +function version { + echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; +} + + +bazel version > bazel.version +curr_bazel_version=$(head -n 1 bazel.version | cut -d ' ' -f3) +rm -f bazel.version + +echo "You have bazel $curr_bazel_version installed." +if [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then + echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!" + echo "Exiting..." + exit 1 +fi + # This file contains customized config settings. rm -f .tf_configure.bazelrc touch .tf_configure.bazelrc -- cgit v1.2.3 From 492afc2e37c254bf4d97de84d581fe09d2d3dfe4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:38:51 -0700 Subject: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 158295169 --- tensorflow/go/op/wrappers.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 72c28f95df..c4af3a60a8 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -13463,6 +13463,11 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { // to compute `input` is odd, it should be provided since it cannot be inferred // properly. // +// Along each axis `IRFFT2D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// // Arguments: // input: A complex64 tensor. // fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. @@ -16691,6 +16696,10 @@ func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { // compute `input` is odd, it should be provided since it cannot be inferred // properly. // +// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller +// than the corresponding dimension of `input`, the dimension is cropped. If it is +// larger, the dimension is padded with zeros. +// // Arguments: // input: A complex64 tensor. // fft_length: An int32 tensor of shape [1]. The FFT length. @@ -16874,6 +16883,10 @@ func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o * // `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, // followed by the `fft_length / 2` positive-frequency terms. // +// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// // Arguments: // input: A float32 tensor. // fft_length: An int32 tensor of shape [1]. The FFT length. @@ -17169,6 +17182,10 @@ func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, // of `output`: the zero-frequency term, followed by the `fft_length / 2` // positive-frequency terms. // +// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// // Arguments: // input: A float32 tensor. // fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. @@ -17514,6 +17531,11 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output) (indices tf // to compute `input` is odd, it should be provided since it cannot be inferred // properly. // +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// // Arguments: // input: A complex64 tensor. // fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. @@ -18936,6 +18958,10 @@ func Erfc(scope *Scope, x tf.Output) (y tf.Output) { // of `output`: the zero-frequency term, followed by the `fft_length / 2` // positive-frequency terms. // +// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// // Arguments: // input: A float32 tensor. // fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. -- cgit v1.2.3 From f105df0478cea110129811062ca3d29f289492c0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 11:45:30 -0700 Subject: In the CUDA path of depthwise_conv2d, optimize backward filter convolution for images 2 or 4 times smaller than 16x16. Also initialize in_cols from blockDim, to fix the regression caused in CL 157906773. PiperOrigin-RevId: 158296136 --- .../core/kernels/depthwise_conv_op_gpu.cu.cc | 107 ++++++++++++--------- 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index e4d7c3d11e..319dbb68e6 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -1308,9 +1308,11 @@ __global__ void __launch_bounds__(640, 2) // a partial convolution for two elements, one each in the lower and upper half // of a tile. The intermediate result of 4 consecutive columns are then // accumulated and written to shared memory. Finally, the values in shared -// memory are warp-accumulated (in chunks of 32 elements) and summed up in -// global memory using atomics. -template +// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed +// up in global memory using atomics. +template = args.in_rows * args.in_cols + int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1321,7 +1323,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int batches = args.batch; const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_cols = blockDim.y; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; const int filter_rows = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; @@ -1352,8 +1354,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int tensor_offset = block_rows * in_row_size; // The accumulator has a fixed number of pixels that can be reduced by one // warp. Pixels beyond block_pixels/4 are never written. - const int accum_pixels = 32; - const int accum_increment = accum_pixels * block_slices; + const int accum_increment = kAccumPixels * block_slices; const int accum_size = filter_pixels * accum_increment; const int thread_depth = threadIdx.x; @@ -1383,7 +1384,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( // Position in accumulator (1 per 4 threads, depth major). const int accum_pix = thread_pix / 4; - const int accum_idx = thread_depth * accum_pixels + accum_pix; + const int accum_idx = thread_depth * kAccumPixels + accum_pix; const int max_depth = in_depth - thread_depth; const int accum_offset = tile_size + accum_idx; @@ -1438,19 +1439,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { - const int filter_idx = i / accum_pixels; + const int filter_idx = i / kAccumPixels; const int filter_pix = filter_idx / block_slices; const int filter_depth = filter_idx % block_slices + start_depth; const int filter_offset = filter_pix * in_depth + filter_depth; if (filter_depth < in_depth) { T val = accum_data[i]; - // Sum up the 32 pixels of the same depth from the accumulator. - val += CudaShuffleDown(val, 16); - val += CudaShuffleDown(val, 8); - val += CudaShuffleDown(val, 4); - val += CudaShuffleDown(val, 2); - val += CudaShuffleDown(val, 1); - if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) { + // Warp-accumulate the pixels of the same depth from the accumulator. + for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) { + val += CudaShuffleDown(val, delta); + } + if (!(thread_idx & kAccumPixels - 1)) { CudaAtomicAdd(filter_offset + filter, val); } } @@ -1567,9 +1566,11 @@ __global__ void __launch_bounds__(640, 2) // a partial convolution for two elements, one each in the lower and upper half // of a tile. The intermediate result of 4 consecutive columns are then // accumulated and written to shared memory. Finally, the values in shared -// memory are warp-accumulated (in chunks of 32 elements) and summed up in -// global memory using atomics. -template +// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed +// up in global memory using atomics. +template = args.in_rows * args.in_cols + int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1580,7 +1581,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int batches = args.batch; const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_cols = blockDim.x; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; const int filter_rows = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; @@ -1610,8 +1611,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int in_blocks = (in_slices + block_slices - 1) / block_slices; // The accumulator has a fixed number of pixels that can be reduced by one // warp. Pixels beyond block_pixels/4 are never written. - const int accum_pixels = 32; - const int accum_increment = accum_pixels * block_slices; + const int accum_increment = kAccumPixels * block_slices; const int accum_size = filter_pixels * accum_increment; const int thread_col = threadIdx.x; @@ -1640,7 +1640,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( // Position in accumulator (1 per 4 threads, depth major). const int accum_pix = thread_pix / 4; - const int accum_idx = thread_depth * accum_pixels + accum_pix; + const int accum_idx = thread_depth * kAccumPixels + accum_pix; const int max_slice = in_slices - thread_depth; const int accum_offset = tile_size + accum_idx; @@ -1692,19 +1692,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { - const int filter_idx = i / accum_pixels; + const int filter_idx = i / kAccumPixels; const int filter_pix = filter_idx / block_slices; const int filter_depth = (slice + filter_idx % block_slices) % in_depth; const int filter_offset = filter_pix * in_depth + filter_depth; if (filter_depth < in_depth) { T val = accum_data[i]; - // Sum up 32 pixels of the same depth from the accumulator. - val += CudaShuffleDown(val, 16); - val += CudaShuffleDown(val, 8); - val += CudaShuffleDown(val, 4); - val += CudaShuffleDown(val, 2); - val += CudaShuffleDown(val, 1); - if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) { + // Warp-accumulate pixels of the same depth from the accumulator. + for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) { + val += CudaShuffleDown(val, delta); + } + if (!(thread_idx & kAccumPixels - 1)) { CudaAtomicAdd(filter_offset + filter, val); } } @@ -1712,7 +1710,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( } } -template +template void LaunchDepthwiseConv2dBackpropFilterGPUSmall( const GpuDevice& d, const DepthwiseArgs args, int block_rows, int shared_memory_size, const T* out_backprop, const T* input, @@ -1724,22 +1723,22 @@ void LaunchDepthwiseConv2dBackpropFilterGPUSmall( dim3 block_dim = dim3(block_slices, args.in_cols, block_rows); CudaLaunchConfig config = GetCudaLaunchConfig( num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall, + DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>, shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall + DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels> <<>>( args, out_backprop, input, filter_backprop); } else if (data_format == FORMAT_NCHW) { dim3 block_dim = dim3(args.in_cols, block_rows, block_slices); CudaLaunchConfig config = GetCudaLaunchConfig( num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall, + DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>, shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall + DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels> <<>>( args, out_backprop, input, filter_backprop); } else { @@ -1759,21 +1758,39 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( return false; } + const int in_pixels = args.in_rows * args.in_cols; + int accum_pixels = 8; + while (accum_pixels * 8 < in_pixels) { + accum_pixels *= 2; + } + const int block_slices = 8; const int tile_cols = args.in_cols + args.filter_cols - 1; const int tile_rows = block_rows * 2 + args.filter_rows - 1; const int tile_pixels = tile_rows * tile_cols; - const int accum_size = args.filter_rows * args.filter_cols * 32; + const int filter_pixels = args.filter_rows * args.filter_cols; const int shared_memory_size = - block_slices * (tile_pixels + accum_size) * sizeof(T); + block_slices * (tile_pixels + filter_pixels * accum_pixels) * sizeof(T); if (shared_memory_size > d.sharedMemPerBlock()) { return false; } - LaunchDepthwiseConv2dBackpropFilterGPUSmall( - d, args, block_rows, shared_memory_size, out_backprop, input, - filter_backprop, data_format); + if (accum_pixels == 8) { + LaunchDepthwiseConv2dBackpropFilterGPUSmall( + d, args, block_rows, shared_memory_size, out_backprop, input, + filter_backprop, data_format); + } else if (accum_pixels == 16) { + LaunchDepthwiseConv2dBackpropFilterGPUSmall( + d, args, block_rows, shared_memory_size, out_backprop, input, + filter_backprop, data_format); + } else { + LaunchDepthwiseConv2dBackpropFilterGPUSmall( + d, args, block_rows, shared_memory_size, out_backprop, input, + filter_backprop, data_format); + } return true; } -- cgit v1.2.3 From edb5fed7fcb23f2f7ad8f556eb44c0a8213184ca Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Wed, 7 Jun 2017 13:15:04 -0700 Subject: Add label-vocab support to binary logistic head. Add assertion that binary classifier label is in range [0., 1.] Fixed Classifier Integration tests. PiperOrigin-RevId: 158307521 --- .../estimator/canned/dnn_linear_combined_test.py | 31 +++-- tensorflow/python/estimator/canned/dnn_test.py | 51 ++++--- tensorflow/python/estimator/canned/head.py | 139 +++++++++++++------ tensorflow/python/estimator/canned/head_test.py | 149 ++++++++++----------- 4 files changed, 220 insertions(+), 150 deletions(-) diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py index 16b4be7b24..dd89f780e6 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py @@ -322,6 +322,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase): if self._model_dir: shutil.rmtree(self._model_dir) + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + def _test_complete_flow( self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, n_classes, batch_size): @@ -363,12 +366,13 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase): def test_numpy_input_fn(self): """Tests complete flow with numpy_input_fn.""" - n_classes = 2 + n_classes = 3 input_dimension = 2 batch_size = 10 - data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) x_data = data.reshape(batch_size, input_dimension) - y_data = np.reshape(data[:batch_size], (batch_size, 1)) + y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1))) # learn y = x train_input_fn = numpy_io.numpy_input_fn( x={'x': x_data}, @@ -401,9 +405,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase): input_dimension = 1 n_classes = 2 batch_size = 10 - data = np.linspace(0., 2., batch_size, dtype=np.float32) + data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32) x = pd.DataFrame({'x': data}) - y = pd.Series(data) + y = pd.Series(self._as_label(data)) train_input_fn = pandas_io.pandas_input_fn( x=x, y=y, @@ -431,25 +435,28 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase): def test_input_fn_from_parse_example(self): """Tests complete flow with input_fn constructed from parse_example.""" input_dimension = 2 - n_classes = 2 + n_classes = 3 batch_size = 10 - data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = np.linspace(0., n_classes-1., batch_size * input_dimension, + dtype=np.float32) data = data.reshape(batch_size, input_dimension) serialized_examples = [] for datum in data: example = example_pb2.Example(features=feature_pb2.Features( feature={ - 'x': feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum)), - 'y': feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum[:1])), + 'x': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=datum)), + 'y': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=self._as_label(datum[:1]))), })) serialized_examples.append(example.SerializeToString()) feature_spec = { 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32), - 'y': parsing_ops.FixedLenFeature([1], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([1], dtypes.int64), } def _train_input_fn(): feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py index 48cd0b7263..dd334059a9 100644 --- a/tensorflow/python/estimator/canned/dnn_test.py +++ b/tensorflow/python/estimator/canned/dnn_test.py @@ -300,12 +300,18 @@ class DNNClassifierPredictTest(test.TestCase): # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597 # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597] # class_ids = argmax(probabilities) = [0] - self.assertAllClose({ - prediction_keys.PredictionKeys.LOGITS: [-2.08], - prediction_keys.PredictionKeys.LOGISTIC: [0.11105597], - prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597], - prediction_keys.PredictionKeys.CLASS_IDS: [0], - }, next(dnn_classifier.predict(input_fn=input_fn))) + predictions = next(dnn_classifier.predict(input_fn=input_fn)) + self.assertAllClose([-2.08], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose([0.11105597], + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [0.88894403, + 0.11105597], predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([0], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertAllEqual([b'0'], + predictions[prediction_keys.PredictionKeys.CLASSES]) def test_multi_dim(self): """Asserts predictions for multi-dimensional input and logits.""" @@ -535,6 +541,9 @@ class DNNClassifierIntegrationTest(test.TestCase): if self._model_dir: shutil.rmtree(self._model_dir) + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + def _test_complete_flow( self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, n_classes, batch_size): @@ -572,12 +581,13 @@ class DNNClassifierIntegrationTest(test.TestCase): def test_numpy_input_fn(self): """Tests complete flow with numpy_input_fn.""" - n_classes = 2 + n_classes = 3 input_dimension = 2 batch_size = 10 - data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) x_data = data.reshape(batch_size, input_dimension) - y_data = np.reshape(data[:batch_size], (batch_size, 1)) + y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) # learn y = x train_input_fn = numpy_io.numpy_input_fn( x={'x': x_data}, @@ -608,11 +618,11 @@ class DNNClassifierIntegrationTest(test.TestCase): if not HAS_PANDAS: return input_dimension = 1 - n_classes = 2 + n_classes = 3 batch_size = 10 - data = np.linspace(0., 2., batch_size, dtype=np.float32) + data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32) x = pd.DataFrame({'x': data}) - y = pd.Series(data) + y = pd.Series(self._as_label(data)) train_input_fn = pandas_io.pandas_input_fn( x=x, y=y, @@ -640,25 +650,28 @@ class DNNClassifierIntegrationTest(test.TestCase): def test_input_fn_from_parse_example(self): """Tests complete flow with input_fn constructed from parse_example.""" input_dimension = 2 - n_classes = 2 + n_classes = 3 batch_size = 10 - data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) data = data.reshape(batch_size, input_dimension) serialized_examples = [] for datum in data: example = example_pb2.Example(features=feature_pb2.Features( feature={ - 'x': feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum)), - 'y': feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum[:1])), + 'x': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=datum)), + 'y': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=self._as_label(datum[:1]))), })) serialized_examples.append(example.SerializeToString()) feature_spec = { 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32), - 'y': parsing_ops.FixedLenFeature([1], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([1], dtypes.int64), } def _train_input_fn(): feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 631ddfc5df..8da1e5104c 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -302,7 +302,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes, Raises: ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. """ - if label_vocabulary is not None and not isinstance(label_vocabulary, list): + if label_vocabulary is not None and not isinstance(label_vocabulary, + (list, tuple)): raise ValueError('label_vocabulary should be a list. Given type: {}'.format( type(label_vocabulary))) @@ -356,14 +357,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): label_ids = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) - assert_less = check_ops.assert_less( - label_ids, - ops.convert_to_tensor(self._n_classes, dtype=label_ids.dtype), - message='Label IDs must < n_classes') - assert_greater = check_ops.assert_non_negative( - label_ids, message='Label Ids must >= 0') - with ops.control_dependencies((assert_less, assert_greater)): - return array_ops.identity(label_ids) + return _assert_range(label_ids, self._n_classes) def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): @@ -459,7 +453,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( - weight_feature_key=None, thresholds=None): + weight_feature_key=None, thresholds=None, label_vocabulary=None): """Creates a `Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -475,6 +469,11 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( generated for each threshold value. This threshold is applied to the logistic values to determine the binary classification (i.e., above the threshold is `true`, below is `false`. + label_vocabulary: A list of strings represents possible label values. If it + is not given, that means labels are already encoded within [0, 1]. If + given, labels must be string type and have any value in + `label_vocabulary`. Also there will be errors if vocabulary is not + provided and labels are string. Returns: An instance of `Head` for binary classification. @@ -483,50 +482,81 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( ValueError: if `thresholds` contains a value outside of `(0, 1)`. """ thresholds = tuple(thresholds) if thresholds else tuple() + if label_vocabulary is not None and not isinstance(label_vocabulary, + (list, tuple)): + raise ValueError('label_vocabulary should be a list. Given type: {}'.format( + type(label_vocabulary))) + for threshold in thresholds: if (threshold <= 0.0) or (threshold >= 1.0): raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,)) return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss( - weight_feature_key=weight_feature_key, thresholds=thresholds) + weight_feature_key=weight_feature_key, + thresholds=thresholds, + label_vocabulary=label_vocabulary) class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`.""" - def __init__(self, weight_feature_key=None, thresholds=None): + def __init__(self, + weight_feature_key=None, + thresholds=None, + label_vocabulary=None): self._weight_feature_key = weight_feature_key self._thresholds = thresholds + self._label_vocabulary = label_vocabulary @property def logits_dimension(self): return 1 - def _eval_metric_ops( - self, labels, logits, logistic, scores, classes, unweighted_loss, - weights=None): - with ops.name_scope( - None, 'metrics', - (labels, logits, logistic, scores, classes, unweighted_loss, weights)): + def _eval_metric_ops(self, + labels, + logits, + logistic, + scores, + class_ids, + unweighted_loss, + weights=None): + with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores, + class_ids, unweighted_loss, weights)): keys = metric_keys.MetricKeys labels_mean = _indicator_labels_mean( labels=labels, weights=weights, name=keys.LABEL_MEAN) metric_ops = { # Estimator already adds a metric for loss. - keys.LOSS_MEAN: metrics_lib.mean( - unweighted_loss, weights=weights, name=keys.LOSS_MEAN), - keys.ACCURACY: metrics_lib.accuracy( - labels=labels, predictions=classes, weights=weights, - name=keys.ACCURACY), - keys.PREDICTION_MEAN: _predictions_mean( - predictions=logistic, weights=weights, name=keys.PREDICTION_MEAN), - keys.LABEL_MEAN: labels_mean, - keys.ACCURACY_BASELINE: _accuracy_baseline(labels_mean), - keys.AUC: _auc( - labels=labels, predictions=logistic, weights=weights, - name=keys.AUC), - keys.AUC_PR: _auc( - labels=labels, predictions=logistic, weights=weights, curve='PR', - name=keys.AUC_PR) + keys.LOSS_MEAN: + metrics_lib.mean( + unweighted_loss, weights=weights, name=keys.LOSS_MEAN), + keys.ACCURACY: + metrics_lib.accuracy( + labels=labels, + predictions=class_ids, + weights=weights, + name=keys.ACCURACY), + keys.PREDICTION_MEAN: + _predictions_mean( + predictions=logistic, + weights=weights, + name=keys.PREDICTION_MEAN), + keys.LABEL_MEAN: + labels_mean, + keys.ACCURACY_BASELINE: + _accuracy_baseline(labels_mean), + keys.AUC: + _auc( + labels=labels, + predictions=logistic, + weights=weights, + name=keys.AUC), + keys.AUC_PR: + _auc( + labels=labels, + predictions=logistic, + weights=weights, + curve='PR', + name=keys.AUC_PR) } for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold @@ -559,27 +589,39 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): two_class_logits = array_ops.concat( (array_ops.zeros_like(logits), logits), 1, name='two_class_logits') scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES) - classes = array_ops.reshape( + class_ids = array_ops.reshape( math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes') + if self._label_vocabulary: + table = lookup_ops.index_to_string_table_from_tensor( + vocabulary_list=self._label_vocabulary, name='class_string_lookup') + classes = table.lookup(class_ids) + else: + classes = string_ops.as_string(class_ids, name='str_classes') predictions = { pred_keys.LOGITS: logits, pred_keys.LOGISTIC: logistic, pred_keys.PROBABILITIES: scores, - pred_keys.CLASS_IDS: classes + pred_keys.CLASS_IDS: class_ids, + pred_keys.CLASSES: classes, } if mode == model_fn.ModeKeys.PREDICT: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, - export_outputs={'': export_output.ClassificationOutput( - scores=scores, - # `ClassificationOutput` requires string classes. - # TODO(ptucker): Support label_keys. - classes=string_ops.as_string(classes, name='str_classes'))}) + export_outputs={ + '': + export_output.ClassificationOutput( + scores=scores, classes=classes) + }) # Eval. - labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)), - self.logits_dimension) + labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension) + if self._label_vocabulary is not None: + labels = lookup_ops.index_table_from_tensor( + vocabulary_list=tuple(self._label_vocabulary), + name='class_id_lookup').lookup(labels) + labels = math_ops.to_float(labels) + labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits, name='loss') weights = ( @@ -598,7 +640,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): logits=logits, logistic=logistic, scores=scores, - classes=classes, + class_ids=class_ids, unweighted_loss=unweighted_loss, weights=weights)) @@ -721,3 +763,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): predictions=predictions, loss=training_loss, train_op=train_op_fn(training_loss)) + + +def _assert_range(labels, n_classes): + assert_less = check_ops.assert_less( + labels, + ops.convert_to_tensor(n_classes, dtype=labels.dtype), + message='Label IDs must < n_classes') + assert_greater = check_ops.assert_non_negative( + labels, message='Label IDs must >= 0') + with ops.control_dependencies((assert_less, assert_greater)): + return array_ops.identity(labels) diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 0efafac87a..e3d9258466 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -206,7 +206,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): }) with self.test_session(): - with self.assertRaisesOpError('Label Ids must >= 0'): + with self.assertRaisesOpError('Label IDs must >= 0'): spec.loss.eval({ labels_placeholder: labels_2x1_with_negative_id, logits_placeholder: logits_2x3 @@ -743,8 +743,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertEqual(1, head.logits_dimension) # Both logits and labels should be shape (batch_size, 1). - values_2x1 = np.array(((43.,), (44.,),)) - values_3x1 = np.array(((45.,), (46.,), (47.,),)) + values_2x1 = np.array(((0.,), (1.,),)) + values_3x1 = np.array(((0.,), (1.,), (0.,),)) # Static shape. with self.assertRaisesRegexp( @@ -788,28 +788,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertEqual(1, head.logits_dimension) # Create estimator spec. - logits = np.array(((45,), (-41,),), dtype=np.int32) + logits = [[45.], [-41.]] spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.PREDICT, logits=logits) - expected_predictions = { - prediction_keys.PredictionKeys.LOGITS: - logits.astype(np.float32), - prediction_keys.PredictionKeys.LOGISTIC: - _sigmoid(logits).astype(np.float32), - prediction_keys.PredictionKeys.PROBABILITIES: - np.array(((0., 1.), (1., 0.),), dtype=np.float32), - prediction_keys.PredictionKeys.CLASS_IDS: - np.array(((1,), (0,)), dtype=np.int64), - } - # Assert spec contains expected tensors. - self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys()) - self.assertEqual( - {k: v.dtype for k, v in six.iteritems(expected_predictions)}, - {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)}) self.assertIsNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) self.assertIsNone(spec.train_op) @@ -821,7 +806,37 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): with self.test_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNone(spec.scaffold.summary_op) - self.assertAllClose(expected_predictions, sess.run(spec.predictions)) + predictions = sess.run(spec.predictions) + self.assertAllClose(logits, + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + _sigmoid(np.array(logits)), + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [[0., 1.], + [1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([[1], [0]], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertAllEqual([[b'1'], [b'0']], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + def test_predict_with_vocabulary_list(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + label_vocabulary=['aang', 'iroh']) + + logits = [[1.], [0.]] + expected_classes = [[b'iroh'], [b'aang']] + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertAllEqual( + expected_classes, + sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES])) def test_eval(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() @@ -834,17 +849,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=np.array(((1,), (1,),), dtype=np.int32)) - expected_predictions = { - prediction_keys.PredictionKeys.LOGITS: - logits.astype(np.float32), - prediction_keys.PredictionKeys.LOGISTIC: - _sigmoid(logits).astype(np.float32), - prediction_keys.PredictionKeys.PROBABILITIES: - np.array(((0., 1.), (1., 0.),), dtype=np.float32), - # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)? - prediction_keys.PredictionKeys.CLASS_IDS: - np.array(((1,), (0,)), dtype=np.int64), - } keys = metric_keys.MetricKeys expected_metrics = { # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41 @@ -859,10 +863,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): } # Assert spec contains expected tensors. - self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys()) - self.assertEqual( - {k: v.dtype for k, v in six.iteritems(expected_predictions)}, - {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)}) self.assertIsNotNone(spec.loss) self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) self.assertIsNone(spec.train_op) @@ -875,15 +875,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertIsNone(spec.scaffold.summary_op) value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} - predictions, loss, metrics = sess.run(( - spec.predictions, spec.loss, update_ops)) - self.assertAllClose(expected_predictions, predictions) + loss, metrics = sess.run((spec.loss, update_ops)) self.assertAllClose(41., loss) # Check results of both update (in `metrics`) and value ops. self.assertAllClose(expected_metrics, metrics) self.assertAllClose( expected_metrics, {k: value_ops[k].eval() for k in value_ops}) + def test_eval_with_vocabulary_list(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + label_vocabulary=['aang', 'iroh']) + + # Create estimator spec. + logits = np.array(((45,), (-41,),), dtype=np.float32) + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.float32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=[[b'iroh'], [b'iroh']]) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} + update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} + sess.run(update_ops) + self.assertAllClose(1. / 2, + value_ops[metric_keys.MetricKeys.ACCURACY].eval()) + def test_eval_with_thresholds(self): thresholds = [0.25, 0.5, 0.75] head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( @@ -942,23 +961,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): labels=np.array(((1,), (1,),), dtype=np.float64), train_op_fn=_train_op_fn) - expected_predictions = { - prediction_keys.PredictionKeys.LOGITS: - logits.astype(np.float32), - prediction_keys.PredictionKeys.LOGISTIC: - _sigmoid(logits).astype(np.float32), - prediction_keys.PredictionKeys.PROBABILITIES: - np.array(((0., 1.), (1., 0.),), dtype=np.float32), - # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)? - prediction_keys.PredictionKeys.CLASS_IDS: - np.array(((1,), (0,)), dtype=np.int64), - } - # Assert spec contains expected tensors. - self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys()) - self.assertEqual( - {k: v.dtype for k, v in six.iteritems(expected_predictions)}, - {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)}) self.assertIsNotNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) self.assertIsNotNone(spec.train_op) @@ -969,9 +972,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): with self.test_session() as sess: _initialize_variables(self, spec.scaffold) self.assertIsNotNone(spec.scaffold.summary_op) - predictions, loss, train_result, summary_str = sess.run(( - spec.predictions, spec.loss, spec.train_op, spec.scaffold.summary_op)) - self.assertAllClose(expected_predictions, predictions) + loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, + spec.scaffold.summary_op)) self.assertAllClose(expected_loss, loss) self.assertEqual(expected_train_result, train_result) _assert_simple_summaries(self, { @@ -995,28 +997,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): mode=model_fn.ModeKeys.PREDICT, logits=logits) - expected_predictions = { - prediction_keys.PredictionKeys.LOGITS: - logits.astype(np.float32), - prediction_keys.PredictionKeys.LOGISTIC: - _sigmoid(logits).astype(np.float32), - prediction_keys.PredictionKeys.PROBABILITIES: - np.array(((0., 1.), (1., 0.), (0., 1.)), dtype=np.float32), - # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)? - prediction_keys.PredictionKeys.CLASS_IDS: - np.array(((1,), (0,), (1,)), dtype=np.int64), - } - - # Assert spec contains expected tensors. - self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys()) - self.assertEqual( - {k: v.dtype for k, v in six.iteritems(expected_predictions)}, - {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)}) - # Assert predictions, loss, and metrics. with self.test_session() as sess: _initialize_variables(self, spec.scaffold) - self.assertAllClose(expected_predictions, sess.run(spec.predictions)) + predictions = sess.run(spec.predictions) + self.assertAllClose( + logits.astype(np.float32), + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + _sigmoid(logits).astype(np.float32), + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [[0., 1.], [1., 0.], + [0., 1.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([[1], [0], [1]], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertAllEqual([[b'1'], [b'0'], [b'1']], + predictions[prediction_keys.PredictionKeys.CLASSES]) def test_weighted_multi_example_eval(self): """3 examples, 1 batch.""" -- cgit v1.2.3 From 85e832201e463fea15bf7f05d4f2b37244b5fd10 Mon Sep 17 00:00:00 2001 From: RJ Ryan Date: Wed, 7 Jun 2017 13:18:59 -0700 Subject: Support unknown emit shapes in tf.nn.raw_rnn. PiperOrigin-RevId: 158308002 --- tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py | 9 ++++++--- tensorflow/python/ops/rnn.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index bf24347c43..d250af9037 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase): inputs_ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) inputs_ta = inputs_ta.unstack(inputs) + # Verify emit shapes may be unknown by feeding a placeholder that + # determines an emit shape. + unknown_dim = array_ops.placeholder(dtype=dtypes.int32) cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) @@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase): if cell_output is None: emit_output = (array_ops.zeros( [2, 3], dtype=dtypes.int32), array_ops.zeros( - [1], dtype=dtypes.int64)) + [unknown_dim], dtype=dtypes.int64)) next_state = cell.zero_state(batch_size, dtypes.float32) else: emit_output = (array_ops.ones( [batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones( - [batch_size, 1], dtype=dtypes.int64)) + [batch_size, unknown_dim], dtype=dtypes.int64)) next_state = cell_state elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) finished = math_ops.reduce_all(elements_finished) @@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase): self.assertEqual([dtypes.int32, dtypes.int64], [ta.dtype for ta in output_ta]) output = [ta.stack() for ta in output_ta] - output_vals = sess.run(output) + output_vals = sess.run(output, feed_dict={unknown_dim: 1}) self.assertAllEqual( np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0]) self.assertAllEqual( diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index ca72734707..3c3c18b1c9 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn, if emit_structure is not None: flat_emit_structure = nest.flatten(emit_structure) - flat_emit_size = [emit.get_shape() for emit in flat_emit_structure] + flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else + array_ops.shape(emit) for emit in flat_emit_structure] flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] else: emit_structure = cell.output_size -- cgit v1.2.3 From 0770393e95e689b6f89fb8e3a195381796766d4c Mon Sep 17 00:00:00 2001 From: Eric Liu Date: Wed, 7 Jun 2017 13:30:51 -0700 Subject: [Tensorboard] Add a trace viewer component to TensorBoard. We make the trace viewer a separate app; otherwise, there would be dependency conflicts (e.g. Polymer) between the trace viewer app and the tensorboard app. The trace viewer app would be served by a plugin, and Tensorboard dashboard will integrate trace viewer app using iframe in the future. This CL also added "mominify" support for link import HTML tags in the tensorboard home-grown java vulnizer; otherwise, the vulcanized trace viewer code would crash the java vulcanizer. For open-source build, we add a denpendency on the Catapult github repository (https://github.com/catapult-project/catapult/tree/master/tracing). We use a bazel genrule to vulcanize a trace viewer binary which is then used in the tf-trace-viewer component. PiperOrigin-RevId: 158309408 --- tensorflow/BUILD | 1 + tensorflow/tensorboard/BUILD | 1 + tensorflow/tensorboard/components/BUILD | 18 +++ .../tensorboard/components/tf_trace_viewer/BUILD | 30 +++++ .../components/tf_trace_viewer/data/BUILD | 17 +++ .../components/tf_trace_viewer/data/trace.json | 105 +++++++++++++++++ .../components/tf_trace_viewer/demo.html | 30 +++++ .../tf_trace_viewer/tf-trace-viewer.html | 127 +++++++++++++++++++++ .../tensorboard/components/trace_viewer.html | 28 +++++ tensorflow/workspace.bzl | 15 +++ 10 files changed, 372 insertions(+) create mode 100644 tensorflow/tensorboard/components/tf_trace_viewer/BUILD create mode 100644 tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD create mode 100644 tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json create mode 100644 tensorflow/tensorboard/components/tf_trace_viewer/demo.html create mode 100644 tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html create mode 100644 tensorflow/tensorboard/components/trace_viewer.html diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8b6224e165..1353f15ec4 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -372,6 +372,7 @@ filegroup( "//tensorflow/tensorboard/components/tf_storage/test:all_files", "//tensorflow/tensorboard/components/tf_tensorboard:all_files", "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_trace_viewer:all_files", "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", "//tensorflow/tensorboard/components/vz_line_chart:all_files", diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index caaf1769c0..23581badb6 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -30,6 +30,7 @@ filegroup( srcs = [ "TAG", "//tensorflow/tensorboard/components:index.html", + "//tensorflow/tensorboard/components:trace_viewer_index.html", ], ) diff --git a/tensorflow/tensorboard/components/BUILD b/tensorflow/tensorboard/components/BUILD index 6a0052b793..e287b2c918 100644 --- a/tensorflow/tensorboard/components/BUILD +++ b/tensorflow/tensorboard/components/BUILD @@ -22,6 +22,24 @@ tensorboard_html_binary( deps = [":tensorboard"], ) +ts_web_library( + name = "trace_viewer", + srcs = [ + "trace_viewer.html", + ], + path = "/", + deps = [ + "//tensorflow/tensorboard/components/tf_trace_viewer", + ], +) + +tensorboard_html_binary( + name = "trace_viewer_index", + input_path = "/trace_viewer.html", + output_path = "/trace_viewer_index.html", + deps = [":trace_viewer"], +) + filegroup( name = "all_files", srcs = glob(["**"]), diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/BUILD new file mode 100644 index 0000000000..943229fd8b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_trace_viewer/BUILD @@ -0,0 +1,30 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_trace_viewer", + srcs = [ + "tf-trace-viewer.html", + "@org_chromium_catapult_vulcanized_trace_viewer//:trace_viewer_full.html", + ], + path = "/tf-trace-viewer", +) + +ts_web_library( + name = "demo", + srcs = ["demo.html"], + path = "/tf-trace-viewer", + deps = [ + ":tf_trace_viewer", + "//tensorflow/tensorboard/components/tf_trace_viewer/data", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD new file mode 100644 index 0000000000..f72035d43a --- /dev/null +++ b/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD @@ -0,0 +1,17 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "data", + srcs = glob(["*.json"]), + path = "/tf-trace-viewer/data/plugin/profile", +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json b/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json new file mode 100644 index 0000000000..e1d57394e3 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json @@ -0,0 +1,105 @@ +{ + "traceEvents": [ + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", + "name": "counter", "args": {"value": 10}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B", + "name": "A long name that doesnt fit but is exceedingly informative", + "args": {"name_false": false, "value_true": true}}, + {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", + "name": "ProcessWideEvent1", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 827, "ph": "B", + "name": "Asub with a name that wont fit", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 828, "ph": "E", + "name": "Asub", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 829, "ph": "B", + "name": "Asub", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 15, "ts": 820, "ph": "X", + "name": "Long X type", "args": {}, "sf": 7, "esf": 8}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "E", + "name": "Asub", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", + "name": "X1", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", + "name": "X same ts and dur as X1", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "C", + "name": "counter", "args": {"value": 1}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 833, "ph": "E", + "name": "", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 835, "ph": "I", + "name": "ThreadLevelI1", "args": {}}, + + {"cat": "PERF", "ts": 880, "ph": "I", "s": "g", "name": "GlobalEvent1", + "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 837, "ph": "I", + "name": "ThreadLevelI2", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 839, "ph": "C", + "name": "counter", "args": {"value": 5}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 840, "ph": "B", + "name": "A not as long a name", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "E", + "name": "A not as long a name", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "C", + "name": "counter", "args": {"value": 1}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "C", + "name": "counter", "args": {"value": 10}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 850, "ph": "B", + "name": "B", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "E", + "name": "B", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 827, "ph": "B", + "name": "A", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 835, "ph": "I", + "name": "ThreadLevelImmediate Three", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 845, "ph": "I", + "name": "ThreadLevelImmediate4", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 854, "ph": "E", + "name": "A", "args": {}}, + + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", + "name": "B/E over X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 10, "ts": 860, "ph": "X", + "name": "X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", + "name": "B/E under X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", + "name": "B/E under X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", + "name": "B/E over X", "args": {}}, + + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 870, "ph": "P", + "name": "SampleA", "args": {}}, + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 875, "ph": "P", + "name": "SampleB", "args": {}}, + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 878, "ph": "P", + "name": "SampleC", "args": {}, "sf": 8}, + + {"cat": "__metadata", "pid": 22630, "tid": 22630, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadA"}}, + {"cat": "__metadata", "pid": 22630, "tid": 22631, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadB"}}, + {"cat": "__metadata", "pid": 22630, "tid": 22632, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadC"}} + ], + "stackFrames": { + "1": { + "category": "m1", + "name": "main" + }, + "7": { + "category": "m2", + "name": "frame7", + "parent": "1" + }, + "8": { + "category": "m2", + "name": "frame8", + "parent": "1" + } + } +} diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/demo.html b/tensorflow/tensorboard/components/tf_trace_viewer/demo.html new file mode 100644 index 0000000000..dd0029e967 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_trace_viewer/demo.html @@ -0,0 +1,30 @@ + + + + +Trace Viewer Demo + +
+ + +
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html b/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html new file mode 100644 index 0000000000..a7b0b2cd73 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html @@ -0,0 +1,127 @@ + + + + + + diff --git a/tensorflow/tensorboard/components/trace_viewer.html b/tensorflow/tensorboard/components/trace_viewer.html new file mode 100644 index 0000000000..c9bcdc9e20 --- /dev/null +++ b/tensorflow/tensorboard/components/trace_viewer.html @@ -0,0 +1,28 @@ + + + + +Trace Viewer + + + + + + + + diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index d7b80f38a8..a5e7588860 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -2693,3 +2693,18 @@ def tf_workspace(path_prefix="", tf_repo_name=""): path = "/test-fixture", exclude = ["test/**"], ) + + filegroup_external( + name = "org_chromium_catapult_vulcanized_trace_viewer", + licenses = ["notice"], # BSD-3-Clause + sha256_urls = { + "f0df289ba9d03d857ad1c2f5918861376b1510b71588ffc60eff5c7a7bfedb09": [ + "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", + "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", + ], + "9e99e79439ea5a1471bd4dd325bd6733e133bcb3da4df4b878ed6d2aec7c8d86": [ + "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html", + "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html" + ], + }, + ) -- cgit v1.2.3 From 599727c654aac53ee6f290b3d5e36c0e0852e951 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 7 Jun 2017 13:31:32 -0700 Subject: [XLA] Propagate debug option flags to hlo_test_base. Specific HLO tests have to replace the generic test_main target with a manual main() that invokes RUN_ALL_TESTS. To get access to a module with debug options set up, a new convenience method is created on HloTestBase. Initially algebraic_simplifier_test is modified as a canary; in a followup we'll convert all HLO tests to this approach. PiperOrigin-RevId: 158309488 --- tensorflow/compiler/xla/service/BUILD | 3 +- .../xla/service/algebraic_simplifier_test.cc | 95 +++++++++++++--------- tensorflow/compiler/xla/tests/BUILD | 1 + tensorflow/compiler/xla/tests/hlo_test_base.cc | 10 +++ tensorflow/compiler/xla/tests/hlo_test_base.h | 6 ++ 5 files changed, 76 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7cb3c95ffa..ecb0d2cb23 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -868,9 +868,10 @@ cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", - "//tensorflow/core:test_main", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e0915e3526..19583433db 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" namespace op = xla::testing::opcode_matchers; @@ -59,7 +61,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); @@ -82,7 +84,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); @@ -105,7 +107,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); @@ -127,7 +129,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); @@ -149,7 +151,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); @@ -171,7 +173,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); @@ -199,7 +201,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); @@ -225,7 +227,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -250,7 +252,7 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); @@ -279,7 +281,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -304,7 +306,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); @@ -329,7 +331,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); @@ -359,7 +361,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); @@ -382,7 +384,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); @@ -405,7 +407,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); @@ -434,7 +436,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -455,7 +457,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); @@ -476,7 +478,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -497,7 +499,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); @@ -527,7 +529,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT( @@ -558,7 +560,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -581,7 +583,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); // Set to different layouts. @@ -608,7 +610,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); // Set to same layouts. @@ -640,7 +642,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -686,7 +688,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -716,7 +718,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -744,7 +746,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); @@ -771,7 +773,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); @@ -797,7 +799,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -825,7 +827,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); @@ -850,7 +852,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); @@ -874,7 +876,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -897,7 +899,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -919,7 +921,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -942,7 +944,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -966,7 +968,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -992,7 +994,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), @@ -1697,7 +1699,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEmbeddedComputation(std::move(dot_computation)); module->AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -1707,3 +1709,20 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { } // namespace } // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 1971868a38..e60d38d0c6 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -93,6 +93,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 871fbeb0a8..fbbb101ce9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,6 +23,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -54,6 +55,8 @@ struct HloTestBase::EigenThreadPoolWrapper { HloTestBase::HloTestBase() : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { + // TODO(b/62411181): get rid of this flag entirely when the usual debug flags + // are piped to all HLO tests. test_hlo_dumper_ = [](const HloModule& module, const string& label) { legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags(); if (flags->xla_hlo_test_generate_hlo_graph) { @@ -73,6 +76,13 @@ HloTestBase::~HloTestBase() { } } +std::unique_ptr HloTestBase::CreateNewModule() { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} + StatusOr HloTestBase::Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index ca0b0b6928..83c877b393 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -44,6 +44,12 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override; + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. It's recommended to use this method to + // create all HloModules for tests. + std::unique_ptr CreateNewModule(); + // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, -- cgit v1.2.3 From 38249d6be21e77bbd0663b71af598c0bdb99d6dc Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 7 Jun 2017 13:37:02 -0700 Subject: Swap the order of NanTensorHook and custom hooks to ensure that when the training encounteres NaN's in the loss function, user-supplied hooks such as tf_debug.LocalCLIDebugHook can still be used to debug the root cause of the numeric issues. PiperOrigin-RevId: 158310249 --- .../learn/python/learn/estimators/estimator.py | 2 +- tensorflow/python/estimator/estimator.py | 2 +- tensorflow/python/estimator/estimator_test.py | 42 ++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 534aac644a..ac5ef565c8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -957,6 +957,7 @@ class BaseEstimator( self._check_inputs(features, labels) model_fn_ops = self._get_train_ops(features, labels) ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) + all_hooks.extend(hooks) all_hooks.extend([ basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), basic_session_run_hooks.LoggingTensorHook( @@ -966,7 +967,6 @@ class BaseEstimator( }, every_n_iter=100) ]) - all_hooks.extend(hooks) scaffold = model_fn_ops.scaffold or monitored_session.Scaffold() if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index f424598ccb..8e6edf6da7 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -589,6 +589,7 @@ class Estimator(object): estimator_spec = self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss) + all_hooks.extend(hooks) all_hooks.extend([ training.NanTensorHook(estimator_spec.loss), training.LoggingTensorHook( @@ -598,7 +599,6 @@ class Estimator(object): }, every_n_iter=100) ]) - all_hooks.extend(hooks) all_hooks.extend(estimator_spec.training_hooks) if not (estimator_spec.scaffold.saver or diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 4119a07bd8..b86afece43 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils @@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase): est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn) +class EstimatorHookOrderingTest(test.TestCase): + + def testCustomHooksAreCalledBeforeNanTensorHook(self): + + def nan_making_model_fn(mode, features, labels): + """A graph that generates NaN's for testing.""" + del features, labels + + global_step = variables.Variable( + 0, dtype=dtypes.int64, name='global_step') + inc_global_step = state_ops.assign_add(global_step, 1) + nan_const = constant_op.constant(np.nan, dtype=dtypes.float32) + loss = control_flow_ops.cond( + inc_global_step > 1, lambda: nan_const, lambda: 1.0) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=global_step.read_value(), + loss=loss, + train_op=inc_global_step) + + def empty_input_fn(): + return dict(), None + + class AfterRunCountingHook(session_run_hook.SessionRunHook): + """Hooks that counts the number of times after_run() is called.""" + + def __init__(self): + self.after_run_count = 0 + + def after_run(self, run_context, run_values): + del run_context, run_values + self.after_run_count += 1 + + test_hook = AfterRunCountingHook() + est = estimator.Estimator(model_fn=nan_making_model_fn) + with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError): + est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook]) + self.assertEqual(2, test_hook.after_run_count) + + class EstimatorIntegrationTest(test.TestCase): def test_complete_flow_with_a_simple_linear_model(self): -- cgit v1.2.3 From bff5e72da9f3df488c7d99149497ea41b6366944 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 13:40:05 -0700 Subject: Fix typo. PiperOrigin-RevId: 158310742 --- tensorflow/contrib/rnn/python/ops/gru_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index de57e7d81e..92beae35dd 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -98,7 +98,7 @@ class GRUBlockCell(rnn_cell_impl.RNNCell): r"""Block GRU cell implementation. The implementation is based on: http://arxiv.org/abs/1406.1078 - Computes the LSTM cell forward propagation for 1 time step. + Computes the GRU cell forward propagation for 1 time step. This kernel op implements the following mathematical equations: -- cgit v1.2.3 From 5d90bbaac9b046a34562fa799a2f88cbe6edc2ae Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Wed, 7 Jun 2017 14:26:23 -0700 Subject: [XLA] Disable constant_folding in test base, so that intended test code paths would not be elided by constant_folding pass. PiperOrigin-RevId: 158317641 --- tensorflow/compiler/xla/tests/client_library_test_base.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 03552d7bbf..b96bb8f846 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -48,6 +48,15 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) : client_(GetOrCreateLocalClientOrDie(platform)) { *(execution_options_.mutable_debug_options()) = legacy_flags::GetDebugOptionsFromFlags(); + + // Disabling constant_folding so that tests (usually written using Constants) + // will exercise the intended code paths, instead of being constant folded. + // + // TODO(b/38354253): Constant folding is currently disabled. Change tests to + // use Parameters instead of Constants, and re-enable constant folding by + // default. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "constant_folding"); } string ClientLibraryTestBase::TestName() const { -- cgit v1.2.3 From b5e8d308655a027e8c163c3fe3bd3445e09e9d23 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 7 Jun 2017 14:51:13 -0700 Subject: [TF:XLA] Refactor randomized tests to allow testing of larger inputs without running out of memory. PiperOrigin-RevId: 158321431 --- tensorflow/compiler/tests/randomized_tests.cc | 1327 ++++++++++++++----------- 1 file changed, 726 insertions(+), 601 deletions(-) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 2a71543f3f..50ac4a6c25 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -76,6 +76,7 @@ namespace { // Command line flags: see main() below. int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; +int64 tf_xla_max_tensor_size = 100000LL; string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; @@ -96,6 +97,11 @@ class OpTestBuilder { // Adds an input 'tensor'. OpTestBuilder& Input(const Tensor& tensor); + // Adds a random input tensor with 'type'. If 'dims' is not provided, + // RandomDims() is used. + OpTestBuilder& RandomInput(DataType type); + OpTestBuilder& RandomInput(DataType type, std::vector dims); + // Sets an attribute. template OpTestBuilder& Attr(StringPiece attr_name, T&& value); @@ -116,11 +122,19 @@ class OpTestBuilder { std::vector* inputs, std::vector* outputs) const; - const std::vector& inputs() const { return inputs_; } + struct InputDescription { + Tensor tensor; + + DataType type = DT_INVALID; + bool has_dims = false; + std::vector dims; + }; + + const std::vector& inputs() const { return inputs_; } private: NodeDef node_def_; - std::vector inputs_; + std::vector inputs_; }; OpTestBuilder::OpTestBuilder(const string& op_name) { @@ -129,7 +143,28 @@ OpTestBuilder::OpTestBuilder(const string& op_name) { OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) { VLOG(1) << "Adding input: " << tensor.DebugString(); - inputs_.push_back(tensor); + InputDescription input; + input.tensor = tensor; + inputs_.push_back(input); + return *this; +} + +OpTestBuilder& OpTestBuilder::RandomInput(DataType type) { + VLOG(1) << "Adding random input: " << type; + InputDescription input; + input.type = type; + inputs_.push_back(input); + return *this; +} + +OpTestBuilder& OpTestBuilder::RandomInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.dims = std::move(dims); + inputs_.push_back(input); return *this; } @@ -207,16 +242,30 @@ class OpTest : public ::testing::Test { public: OpTest(); - // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs; - // whichever happens first. - void Repeatedly(const std::function& fn); + enum TestResult { + // The test saw an unrecoverable error. Don't try any more runs. + kFatalError, + // The parameters of the test were invalid (e.g., the "golden" + // implementation failed, or the parameters are oversize). Reruns are ok. + kInvalid, + // The test ran successfully, and we have a verdict. Does *not* mean the + // test passed. + kOk, + }; + + // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure + // occurs; whichever happens first. Reruns if the TestResult is kInvalid. + void Repeatedly(const std::function& fn); // Select a random element from 'candidates'. template T Choose(gtl::ArraySlice candidates); static constexpr int kDefaultMaxRank = 5; - static constexpr int64 kDefaultMaxDimensionSize = 20LL; + static constexpr int64 kDefaultMaxDimensionSize = 256LL; + + // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. + bool TensorSizeIsOk(gtl::ArraySlice dims); // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); @@ -278,8 +327,9 @@ class OpTest : public ::testing::Test { // element-wise difference between x and y must no more than // atol + rtol * abs(x); or both elements may be NaN or infinity. For // non-floating-point tensors the element values must match exactly. - void ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, - double atol = 1e-2, double rtol = 1e-2); + TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, + double atol = 1e-2, + double rtol = 1e-2); protected: // Per-test state: @@ -315,10 +365,35 @@ OpTest::OpTest() { TF_CHECK_OK(session_->Create(def)); } -void OpTest::Repeatedly(const std::function& fn) { +void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; - for (int i = 0; !HasFailure() && i < max_repetitions; ++i) { - fn(); + int valid_test_runs = 0; + // We run up to 10 * max_repetitions times; the idea is that if we roll the + // dice enough times we will find some valid parameters. We want to put an + // upper limit on the number iterations just in case the probability of + // finding feasible parameters is very low. + for (int i = 0; !HasFailure() && i < max_repetitions * 10 && + valid_test_runs < max_repetitions; + ++i) { + TestResult result = fn(); + switch (result) { + case kOk: + ++valid_test_runs; + break; + + case kFatalError: + ASSERT_TRUE(false) << "Test had fatal failure"; + return; + + case kInvalid: + break; + } + } + if (!HasFailure()) { + EXPECT_GE(valid_test_runs, max_repetitions) + << "Not enough test instances passed; this means that either the " + "golden implementation is buggy or the operator harness is not " + "producing well-formed test cases with a high probability."; } } @@ -333,6 +408,14 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } +bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { + int64 size = 1LL; + for (int64 dim : dims) { + size *= dim; + } + return size < tf_xla_max_tensor_size; +} + std::vector OpTest::RandomDims(int min_rank, int max_rank, int64 min_size, int64 max_size) { CHECK_LE(0, min_rank); @@ -340,9 +423,13 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, std::uniform_int_distribution rank_distribution(min_rank, max_rank); int rank = rank_distribution(generator()); std::vector dims(rank); - std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() { - return RandomDim(min_size, max_size); - }); + // TODO(phawkins): too small a maximum tensor size could lead to an infinite + // loop here. + do { + std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() { + return RandomDim(min_size, max_size); + }); + } while (!TensorSizeIsOk(dims)); return dims; } @@ -606,53 +693,84 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, } } -void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, - double atol, double rtol) { +OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( + const OpTestBuilder& builder, double atol, double rtol) { + const std::vector& inputs = builder.inputs(); + std::vector input_tensors; + input_tensors.reserve(inputs.size()); + for (const OpTestBuilder::InputDescription& input : inputs) { + if (input.type == DT_INVALID) { + VLOG(1) << "Input: " << input.tensor.DebugString(); + input_tensors.push_back(input.tensor); + } else { + VLOG(1) << "Input: " << input.type << " " + << TensorShape(input.dims).DebugString(); + std::vector dims; + if (input.has_dims) { + dims = input.dims; + } else { + dims = RandomDims(); + } + if (!TensorSizeIsOk(dims)) { + VLOG(1) << "Ignoring oversize dims."; + return kInvalid; + } + input_tensors.push_back(RandomTensor(input.type, dims)); + } + } + string cpu_device = LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; - ASSERT_TRUE( - DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)); + if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) { + LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr; + return kFatalError; + } DeviceType test_device_type(parsed_name.type); ++num_tests_; GraphDef graph; std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; - TF_ASSERT_OK(builder.BuildGraph( + Status status = builder.BuildGraph( strings::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, - &expected_inputs, &expected_fetches)); + &expected_inputs, &expected_fetches); + if (!status.ok()) { + LOG(ERROR) << "Expected graph construction failed: " << status; + return kFatalError; + } NodeDef* node_def; - TF_ASSERT_OK(builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), - test_device, tf_xla_test_use_jit, &graph, - &node_def, &test_inputs, &test_fetches)); + status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + test_device, tf_xla_test_use_jit, &graph, + &node_def, &test_inputs, &test_fetches); + if (!status.ok()) { + LOG(ERROR) << "Test graph construction failed: " << status; + return kFatalError; + } // Check that there's a kernel corresponding to 'node_def' on the device under // test. - Status status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr); + status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr); if (!status.ok()) { VLOG(1) << "Skipping test because there is no corresponding registered " << "kernel on the test device: " << status; - return; + return kInvalid; } - TF_ASSERT_OK(session_->Extend(graph)); - - const std::vector& input_tensors = builder.inputs(); - if (VLOG_IS_ON(1)) { - for (const Tensor& input : input_tensors) { - VLOG(1) << "Input: " << input.DebugString(); - } + status = session_->Extend(graph); + if (!status.ok()) { + LOG(ERROR) << "Session::Extend() failed: " << status; + return kFatalError; } std::vector> expected_feeds(expected_inputs.size()); std::vector> test_feeds(test_inputs.size()); - ASSERT_EQ(input_tensors.size(), expected_inputs.size()); - ASSERT_EQ(input_tensors.size(), test_inputs.size()); + CHECK_EQ(input_tensors.size(), expected_inputs.size()); + CHECK_EQ(input_tensors.size(), test_inputs.size()); for (int i = 0; i < input_tensors.size(); ++i) { expected_feeds[i] = {expected_inputs[i], input_tensors[i]}; @@ -664,21 +782,27 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, Status s = session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs); if (!s.ok()) { - VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test"; - return; + VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test"; + return kInvalid; } for (const Tensor& expected : expected_outputs) { VLOG(1) << "Expected: " << expected.DebugString(); } VLOG(1) << "Running test graph"; - TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs)); + status = session_->Run(test_feeds, test_fetches, {}, &test_outputs); + if (!status.ok()) { + LOG(ERROR) << "Test graph failed: " << status; + return kFatalError; + } - ASSERT_EQ(expected_outputs.size(), test_outputs.size()); + CHECK_EQ(expected_outputs.size(), test_outputs.size()); for (int j = 0; s.ok() && j < test_outputs.size(); ++j) { s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol); } TF_EXPECT_OK(s); + + return kOk; } // Helper that converts 'values' to an int32 or int64 Tensor. @@ -698,8 +822,8 @@ Tensor AsIntTensor(DataType dtype, const std::vector& values) { TEST_F(OpTest, Abs) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Abs").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Abs").RandomInput(type).Attr("T", type)); }); } @@ -707,10 +831,10 @@ TEST_F(OpTest, Add) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -725,49 +849,50 @@ TEST_F(OpTest, AddN) { builder.Attr("T", type); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type, shape)); + builder.RandomInput(type, shape); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, All) { Repeatedly([this]() { - Tensor data = RandomTensor(DT_BOOL); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("All").Input(data).Input(indices).Attr("keep_dims", - keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("All") + .RandomInput(DT_BOOL, data_dims) + .Input(indices) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, Any) { Repeatedly([this]() { - Tensor data = RandomTensor(DT_BOOL); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Any").Input(data).Input(indices).Attr("keep_dims", - keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Any") + .RandomInput(DT_BOOL, data_dims) + .Input(indices) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, AvgPool) { Repeatedly([this]() { - WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); - - int kernel_rows = random_int(generator()), - kernel_cols = random_int(generator()); + std::vector dims = RandomDims(4, 4, 1); + int kernel_rows = + std::uniform_int_distribution(1, dims[1])(generator()); + int kernel_cols = + std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool") - .Input( - RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), - RandomDim(kernel_cols), RandomDim(1)})) + .RandomInput(DT_FLOAT, dims) .Attr("T", DT_FLOAT) .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) .Attr("strides", {1, stride_rows, stride_cols, 1}) @@ -781,23 +906,28 @@ TEST_F(OpTest, AvgPool) { TEST_F(OpTest, AvgPool3D) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); + std::vector dims = RandomDims(5, 5, 1); + std::vector input_dims, kernel_dims, stride_dims; for (int i = 0; i < 3; ++i) { - kernel_dims.push_back(random_int(generator())); - input_dims.push_back(RandomDim(kernel_dims.back())); + kernel_dims.push_back( + std::uniform_int_distribution(1, dims[i])(generator())); + input_dims.push_back(dims[i]); stride_dims.push_back(random_int(generator())); } + int64 batch = dims[3]; + int64 feature = dims[4]; string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3D") - .Input(RandomTensor(DT_FLOAT, ImageDims(FORMAT_NHWC, RandomDim(1), - RandomDim(1), input_dims))) + .RandomInput(DT_FLOAT, + ImageDims(FORMAT_NHWC, batch, feature, input_dims)) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, kernel_dims)) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, stride_dims)) .Attr("padding", padding) - .Attr("data_format", "NHWC")); + .Attr("data_format", "NDHWC")); }); // TODO(phawkins): test NCHW format (not supported by CPU) } @@ -810,15 +940,15 @@ TEST_F(OpTest, AvgPoolGrad) { AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPoolGrad") .Input(test::AsTensor(input_dims)) - .Input(RandomTensor(DT_FLOAT, output_dims)) + .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") - .Attr("data_format", "NHWC")); + .Attr("data_format", "NDHWC")); }); } @@ -830,15 +960,15 @@ TEST_F(OpTest, AvgPool3DGrad) { AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3DGrad") .Input(test::AsTensor(input_dims)) - .Input(RandomTensor(DT_FLOAT, output_dims)) + .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") - .Attr("data_format", "NHWC")); + .Attr("data_format", "NDHWC")); }); } @@ -850,32 +980,23 @@ TEST_F(OpTest, BatchMatMul) { std::vector x_dims(output_dims), y_dims(output_dims); x_dims[ndims - 1] = inner_dim; y_dims[ndims - 2] = inner_dim; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT)); - - std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_x", true)); - - std::swap(y_dims[ndims - 1], y_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_x", true) - .Attr("adj_y", true)); - - std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_y", true)); + + std::bernoulli_distribution random_bool; + bool adj_x = random_bool(generator()); + bool adj_y = random_bool(generator()); + if (adj_x) { + std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); + } + if (adj_y) { + std::swap(y_dims[ndims - 1], y_dims[ndims - 2]); + } + + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT) + .Attr("adj_x", adj_x) + .Attr("adj_y", adj_y)); }); } @@ -905,11 +1026,11 @@ TEST_F(OpTest, BatchToSpace) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") - .Input(RandomTensor(DT_FLOAT, input_dims)) - .Input(crops) - .Attr("T", DT_FLOAT) - .Attr("block_size", block_size)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") + .RandomInput(DT_FLOAT, input_dims) + .Input(crops) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); }); } @@ -942,9 +1063,9 @@ TEST_F(OpTest, BatchToSpaceND) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("BatchToSpaceND") - .Input(RandomTensor(DT_FLOAT, input_dims)) + .RandomInput(DT_FLOAT, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(crops) @@ -954,29 +1075,32 @@ TEST_F(OpTest, BatchToSpaceND) { TEST_F(OpTest, BiasAdd) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); - auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); + auto x_dims = RandomDims(2, kDefaultMaxRank); + auto y_dims = {x_dims[x_dims.size() - 1]}; // TODO(phawkins): test both data formats. - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAdd").Input(x).Input(y).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, BiasAddGrad) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT); // TODO(phawkins): test both data formats. - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddGrad").Input(x).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, BiasAddV1) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); - auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddV1").Input(x).Input(y).Attr("T", DT_FLOAT)); + auto x_dims = RandomDims(2, kDefaultMaxRank); + auto y_dims = {x_dims[x_dims.size() - 1]}; + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT)); }); } @@ -986,10 +1110,11 @@ TEST_F(OpTest, BroadcastGradientArgs) { // DataType type = Choose({DT_INT32, DT_INT64}); DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BroadcastGradientArgs") - .Input(AsIntTensor(type, dims.first)) - .Input(AsIntTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BroadcastGradientArgs") + .Input(AsIntTensor(type, dims.first)) + .Input(AsIntTensor(type, dims.second)) + .Attr("T", type)); }); } @@ -998,18 +1123,17 @@ TEST_F(OpTest, Cast) { DataType src_type, dst_type; src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") - .Input(RandomTensor(src_type)) - .Attr("SrcT", src_type) - .Attr("DstT", dst_type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .RandomInput(src_type) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type)); }); } TEST_F(OpTest, Ceil) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Ceil") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Ceil").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } @@ -1029,9 +1153,9 @@ TEST_F(OpTest, Concat) { for (int i = 0; i < n; ++i) { std::vector shape = dims; shape[concat_dim] = RandomDim(); - builder.Input(RandomTensor(type, shape)); + builder.RandomInput(type, shape); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } @@ -1051,7 +1175,7 @@ TEST_F(OpTest, ConcatOffset) { shape[concat_dim] = RandomDim(); builder.Input(test::AsTensor(shape)); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } @@ -1064,15 +1188,15 @@ TEST_F(OpTest, Conv2D) { int64 batch = RandomDim(); - Tensor data = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)); + std::vector data_dims = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); - Tensor kernel = RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1], - features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], + features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") - .Input(data) - .Input(kernel) + .RandomInput(DT_FLOAT, data_dims) + .RandomInput(DT_FLOAT, kernel_dims) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") @@ -1087,17 +1211,17 @@ TEST_F(OpTest, Conv2DBackpropFilter) { int features_in = random_int(generator()); int features_out = random_int(generator()); int32 batch = RandomDim(); - Tensor activations = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)); - Tensor backprop = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims)); + std::vector activations = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") - .Input(activations) + .RandomInput(DT_FLOAT, activations) .Input(kernel_shape) - .Input(backprop) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") @@ -1114,15 +1238,15 @@ TEST_F(OpTest, Conv2DBackpropInput) { int32 batch = RandomDim(); Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); - Tensor backprop = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims)); - Tensor kernel = RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1], - features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) - .Input(kernel) - .Input(backprop) + .RandomInput(DT_FLOAT, kernel) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") @@ -1136,17 +1260,15 @@ TEST_F(OpTest, Conv3D) { std::uniform_int_distribution random_int(1, 5); int features_in = random_int(generator()); int features_out = random_int(generator()); - Tensor data = - RandomTensor(DT_FLOAT, {RandomDim(), d.input_dims[0], d.input_dims[1], - d.input_dims[2], features_in}); - - Tensor kernel = - RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1], - d.kernel_dims[2], features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + std::vector data = {RandomDim(), d.input_dims[0], d.input_dims[1], + d.input_dims[2], features_in}; + + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + d.kernel_dims[2], features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3D") - .Input(data) - .Input(kernel) + .RandomInput(DT_FLOAT, data) + .RandomInput(DT_FLOAT, kernel) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); @@ -1160,18 +1282,18 @@ TEST_F(OpTest, Conv3DBackpropFilter) { int features_in = random_int(generator()); int features_out = random_int(generator()); int32 batch = RandomDim(1); - Tensor activations = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)); - Tensor backprop = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims)); + std::vector activations = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropFilterV2") - .Input(activations) + .RandomInput(DT_FLOAT, activations) .Input(kernel_shape) - .Input(backprop) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); @@ -1187,16 +1309,15 @@ TEST_F(OpTest, Conv3DBackpropInput) { int32 batch = RandomDim(1); Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); - Tensor backprop = RandomTensor( - DT_FLOAT, ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims)); - Tensor kernel = - RandomTensor(DT_FLOAT, {d.kernel_dims[0], d.kernel_dims[1], - d.kernel_dims[2], features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + d.kernel_dims[2], features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropInputV2") .Input(in_shape) - .Input(kernel) - .Input(backprop) + .RandomInput(DT_FLOAT, kernel) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); @@ -1206,9 +1327,15 @@ TEST_F(OpTest, Conv3DBackpropInput) { TEST_F(OpTest, Diag) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Diag") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + std::vector dims; + // Diag causes a quadratic blowup in output size. + int64 size; + do { + dims = RandomDims(1); + size = TensorShape(dims).num_elements(); + } while (size * size < tf_xla_max_tensor_size); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); } @@ -1220,9 +1347,9 @@ TEST_F(OpTest, DiagPart) { std::vector doubled_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), doubled_dims.begin()); std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart") - .Input(RandomTensor(type, doubled_dims)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart") + .RandomInput(type, doubled_dims) + .Attr("T", type)); }); } @@ -1230,10 +1357,10 @@ TEST_F(OpTest, Div) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1282,27 +1409,26 @@ TEST_F(OpTest, DynamicStitch) { std::vector dims(index_dims[i].begin(), index_dims[i].end()); std::copy(constant_dims.begin(), constant_dims.end(), std::back_inserter(dims)); - Tensor t = RandomTensor(type, dims); - builder.Input(t); + builder.RandomInput(type, dims); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, Elu) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Elu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, EluGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } @@ -1310,50 +1436,51 @@ TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Exp) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Exp").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ExpandDims) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor in = RandomTensor(type); + std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); - std::uniform_int_distribution d(-1 - in.dims(), in.dims()); + std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); dim.scalar()() = d(generator()); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("ExpandDims").Input(in).Input(dim).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims") + .RandomInput(type, in_dims) + .Input(dim) + .Attr("T", type)); }); } TEST_F(OpTest, Fill) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor scalar = RandomTensor(type, {}); std::vector dims = RandomDims(); std::vector shape(dims.begin(), dims.end()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Fill") - .Input(test::AsTensor(shape)) - .Input(scalar) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Fill") + .Input(test::AsTensor(shape)) + .RandomInput(type, {}) + .Attr("T", type)); }); } TEST_F(OpTest, Floor) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Floor") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Floor").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } @@ -1361,10 +1488,10 @@ TEST_F(OpTest, FloorDiv) { Repeatedly([this]() { DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1372,10 +1499,10 @@ TEST_F(OpTest, FloorMod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1383,10 +1510,10 @@ TEST_F(OpTest, Greater) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1394,18 +1521,10 @@ TEST_F(OpTest, GreaterEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); - }); -} - -TEST_F(OpTest, Reciprocal) { - Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reciprocal") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1413,9 +1532,9 @@ TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); // TODO(b/31644876): scalars currently crash. - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") + .RandomInput(type, RandomDims(1)) + .Attr("T", type)); }); } @@ -1423,10 +1542,10 @@ TEST_F(OpTest, Less) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1434,10 +1553,10 @@ TEST_F(OpTest, LessEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1449,10 +1568,10 @@ TEST_F(OpTest, LinSpace) { }; std::uniform_int_distribution distribution(-50, 50); DataType type = Choose({DT_INT32, DT_INT64}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LinSpace") - .Input(RandomTensor(DT_FLOAT, {})) - .Input(RandomTensor(DT_FLOAT, {})) + .RandomInput(DT_FLOAT, {}) + .RandomInput(DT_FLOAT, {}) .Input(ToScalar(type, distribution(generator()))) .Attr("T", DT_FLOAT) .Attr("Tidx", type)); @@ -1461,62 +1580,62 @@ TEST_F(OpTest, LinSpace) { TEST_F(OpTest, Log) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, LogicalAnd) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogicalAnd") - .Input(RandomTensor(DT_BOOL, dims.first)) - .Input(RandomTensor(DT_BOOL, dims.second))); + .RandomInput(DT_BOOL, dims.first) + .RandomInput(DT_BOOL, dims.second)); }); } TEST_F(OpTest, LogicalNot) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("LogicalNot").Input(RandomTensor(DT_BOOL))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogicalNot").RandomInput(DT_BOOL)); }); } TEST_F(OpTest, LogicalOr) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogicalOr") - .Input(RandomTensor(DT_BOOL, dims.first)) - .Input(RandomTensor(DT_BOOL, dims.second))); + .RandomInput(DT_BOOL, dims.first) + .RandomInput(DT_BOOL, dims.second)); }); } TEST_F(OpTest, LogSoftmax) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogSoftmax") - .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .RandomInput(DT_FLOAT, RandomDims(2, 2)) .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, LRN) { Repeatedly([this]() { - Tensor data; // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. - data = RandomTensor(DT_FLOAT, RandomDims(4, 4, 1, 8)); + std::vector data_dims = RandomDims(4, 4, 1, 8); // CuDNN requires depth_radius > 0. - std::uniform_int_distribution radius(1, data.dim_size(3)); + std::uniform_int_distribution radius(1, data_dims[3]); std::uniform_real_distribution coeff(0.01, 2.0); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRN") - .Input(data) - .Attr("T", DT_FLOAT) - .Attr("depth_radius", radius(generator())) - .Attr("bias", coeff(generator())) - .Attr("alpha", coeff(generator())) - .Attr("beta", coeff(generator()))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LRN") + .RandomInput(DT_FLOAT, data_dims) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); }); } @@ -1524,21 +1643,19 @@ TEST_F(OpTest, LRNGrad) { Repeatedly([this]() { // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. std::vector dims = RandomDims(4, 4, 1, 8); - Tensor input_grads = RandomTensor(DT_FLOAT, dims); - Tensor input_image = RandomTensor(DT_FLOAT, dims); - Tensor output_image = RandomTensor(DT_FLOAT, dims); // CuDNN requires depth_radius > 0. - std::uniform_int_distribution radius(1, input_grads.dim_size(3)); + std::uniform_int_distribution radius(1, dims[3]); std::uniform_real_distribution coeff(0.0, 2.0); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRNGrad") - .Input(input_grads) - .Input(input_image) - .Input(output_image) - .Attr("T", DT_FLOAT) - .Attr("depth_radius", radius(generator())) - .Attr("bias", coeff(generator())) - .Attr("alpha", coeff(generator())) - .Attr("beta", coeff(generator()))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LRNGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); }); } @@ -1548,59 +1665,57 @@ TEST_F(OpTest, MatMul) { int64 y = RandomDim(); int64 z = RandomDim(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("T", DT_FLOAT)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("T", DT_FLOAT) - .Attr("transpose_a", true)); + std::vector a_dims = {x, y}; + std::vector b_dims = {y, z}; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("T", DT_FLOAT) - .Attr("transpose_b", true)); + std::bernoulli_distribution random_bool; + bool transpose_a = random_bool(generator()); + bool transpose_b = random_bool(generator()); + if (transpose_a) { + std::swap(a_dims[0], a_dims[1]); + } + if (transpose_b) { + std::swap(b_dims[0], b_dims[1]); + } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("T", DT_FLOAT) - .Attr("transpose_a", true) - .Attr("transpose_b", true)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + .RandomInput(DT_FLOAT, a_dims) + .RandomInput(DT_FLOAT, b_dims) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b)); }); } TEST_F(OpTest, MatrixDiag) { Repeatedly([this]() { - DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") + .RandomInput(type, RandomDims(1)) + .Attr("T", type)); }); } TEST_F(OpTest, MatrixDiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") - .Input(RandomTensor(type, RandomDims(2))) - .Attr("T", type)); + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") + .RandomInput(type, RandomDims(2)) + .Attr("T", type)); }); } TEST_F(OpTest, Max) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Max").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Max") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1608,26 +1723,28 @@ TEST_F(OpTest, Maximum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, MaxPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); - int kernel_rows = random_int(generator()), - kernel_cols = random_int(generator()); + std::vector dims = RandomDims(4, 4, 1); + int kernel_rows = + std::uniform_int_distribution(1, dims[1])(generator()); + int kernel_cols = + std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); + string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool") - .Input( - RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), - RandomDim(kernel_cols), RandomDim(1)})) + .RandomInput(DT_FLOAT, dims) .Attr("T", DT_FLOAT) .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) .Attr("strides", {1, stride_rows, stride_cols, 1}) @@ -1640,28 +1757,32 @@ TEST_F(OpTest, MaxPool) { TEST_F(OpTest, MaxPool3D) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); - std::vector input_dims; - std::vector kernel_dims, stride_dims; - input_dims.push_back(RandomDim(1)); + std::vector dims = RandomDims(5, 5, 1); + + std::vector input_dims, kernel_dims, stride_dims; kernel_dims.push_back(1); stride_dims.push_back(1); for (int i = 0; i < 3; ++i) { - kernel_dims.push_back(random_int(generator())); - input_dims.push_back(RandomDim(kernel_dims.back())); + kernel_dims.push_back( + std::uniform_int_distribution(1, dims[i])(generator())); + input_dims.push_back(dims[i]); stride_dims.push_back(random_int(generator())); } - input_dims.push_back(RandomDim(1)); kernel_dims.push_back(1); stride_dims.push_back(1); + int64 batch = dims[3]; + int64 feature = dims[4]; string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MaxPool3D") - .Input(RandomTensor(DT_FLOAT, input_dims)) - .Attr("T", DT_FLOAT) - .Attr("ksize", kernel_dims) - .Attr("strides", stride_dims) - .Attr("padding", padding) - .Attr("data_format", "NHWC")); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("MaxPool3D") + .RandomInput(DT_FLOAT, + ImageDims(FORMAT_NHWC, batch, feature, input_dims)) + .Attr("T", DT_FLOAT) + .Attr("ksize", kernel_dims) + .Attr("strides", stride_dims) + .Attr("padding", padding) + .Attr("data_format", "NDHWC")); }); // TODO(phawkins): test NCHW format (not supported by CPU) } @@ -1671,24 +1792,28 @@ TEST_F(OpTest, Mean) { DataType type = Choose({DT_INT32, DT_FLOAT}); // TODO(phawkins): CPU and XLA differ output for reducing across a // size-0 dimension (nan vs 0). For now, require size >= 1. - Tensor data = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 1)); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(0, kDefaultMaxRank, 1); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Mean").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mean") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, Min) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Min").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Min") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1696,21 +1821,20 @@ TEST_F(OpTest, Minimum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Mod) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Mod") - .Input(RandomTensor(DT_INT32, dims.first)) - .Input(RandomTensor(DT_INT32, dims.second)) - .Attr("T", DT_INT32)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mod") + .RandomInput(DT_INT32, dims.first) + .RandomInput(DT_INT32, dims.second) + .Attr("T", DT_INT32)); }); } @@ -1718,18 +1842,18 @@ TEST_F(OpTest, Mul) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Neg) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Neg").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Neg").RandomInput(type).Attr("T", type)); }); } @@ -1737,10 +1861,10 @@ TEST_F(OpTest, NotEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1768,9 +1892,17 @@ TEST_F(OpTest, OneHot) { builder.Attr("axis", axis); builder.Input(indices); builder.Input(test::AsScalar(depth)); - builder.Input(RandomTensor(type, {})); - builder.Input(RandomTensor(type, {})); - ExpectTfAndXlaOutputsAreClose(builder); + builder.RandomInput(type, {}); + builder.RandomInput(type, {}); + return ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, OnesLike) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type)); }); } @@ -1789,9 +1921,9 @@ TEST_F(OpTest, Pack) { builder.Attr("N", n); builder.Attr("axis", axis); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type, dims)); + builder.RandomInput(type, dims); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } @@ -1799,23 +1931,26 @@ TEST_F(OpTest, Pack) { TEST_F(OpTest, Pad) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type); + std::vector t_dims = RandomDims(); // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. // DataType tpaddings = Choose({DT_INT32, DT_INT64}); DataType tpaddings = DT_INT32; std::vector paddings_vec; std::uniform_int_distribution distribution(0, 7); - for (int i = 0; i < t.dims(); ++i) { + for (int i = 0; i < t_dims.size(); ++i) { paddings_vec.push_back(distribution(generator())); paddings_vec.push_back(distribution(generator())); } Tensor paddings; - CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec), - TensorShape({t.dims(), 2}))); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Pad").Input(t).Input(paddings).Attr("T", type).Attr( - "Tpaddings", tpaddings)); + CHECK( + paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec), + TensorShape({static_cast(t_dims.size()), 2}))); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pad") + .RandomInput(type, t_dims) + .Input(paddings) + .Attr("T", type) + .Attr("Tpaddings", tpaddings)); }); } @@ -1824,23 +1959,24 @@ TEST_F(OpTest, Pow) { // nontermination. Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Pow") - .Input(RandomTensor(DT_FLOAT, dims.first)) - .Input(RandomTensor(DT_FLOAT, dims.second)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Prod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Prod").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Prod") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1855,7 +1991,7 @@ TEST_F(OpTest, Range) { }; std::uniform_int_distribution distribution(-50, 50); DataType tidx = Choose({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Range") .Input(ToScalar(tidx, distribution(generator()))) .Input(ToScalar(tidx, distribution(generator()))) @@ -1867,8 +2003,8 @@ TEST_F(OpTest, Range) { TEST_F(OpTest, Rank) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Rank").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rank").RandomInput(type).Attr("T", type)); }); } @@ -1876,46 +2012,51 @@ TEST_F(OpTest, RealDiv) { Repeatedly([this]() { DataType type = DT_FLOAT; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Reciprocal) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Relu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu6) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Relu6").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu6Grad) { Repeatedly([this]() { auto dims = RandomDims(1); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ReluGrad) { Repeatedly([this]() { auto dims = RandomDims(1); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } @@ -1937,10 +2078,9 @@ TEST_F(OpTest, Reshape) { } } } - Tensor data = RandomTensor(type, dims_before); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Reshape") - .Input(data) + .RandomInput(type, dims_before) .Input(test::AsTensor( std::vector(dims_after.begin(), dims_after.end()))) .Attr("T", type)); @@ -1952,56 +2092,54 @@ TEST_F(OpTest, Reverse) { std::vector dims = RandomDims(1); DataType type = Choose({DT_INT32, DT_FLOAT}); int64 rank = dims.size(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse") - .Input(RandomTensor(type, dims)) - .Input(RandomTensor(DT_BOOL, {rank})) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse") + .RandomInput(type, dims) + .RandomInput(DT_BOOL, {rank}) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ReverseV2) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2") - .Input(data) - .Input(indices) - .Attr("T", DT_FLOAT)); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Round) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Round") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Round").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Rsqrt) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, RsqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Shape) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Shape").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Shape").RandomInput(type).Attr("T", type)); }); } @@ -2013,72 +2151,72 @@ TEST_F(OpTest, ShapeN) { builder.Attr("T", type); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type)); + builder.RandomInput(type); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, Sigmoid) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sigmoid") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, SigmoidGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Sign) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sign").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sign").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Size) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Size").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Size").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Slice) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); + std::vector data_dims = RandomDims(); - std::vector begin(data.dims()), size(data.dims()); - for (int i = 0; i < data.dims(); ++i) { - begin[i] = std::uniform_int_distribution( - 0, data.dim_size(i))(generator()); + std::vector begin(data_dims.size()), size(data_dims.size()); + for (int i = 0; i < data_dims.size(); ++i) { + begin[i] = + std::uniform_int_distribution(0, data_dims[i])(generator()); size[i] = std::uniform_int_distribution( - -1, data.dim_size(i) - begin[i])(generator()); + -1, data_dims[i] - begin[i])(generator()); } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice") - .Input(data) - .Input(test::AsTensor(begin)) - .Input(test::AsTensor(size)) - .Attr("T", type) - .Attr("Index", DT_INT32)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Slice") + .RandomInput(type, data_dims) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(size)) + .Attr("T", type) + .Attr("Index", DT_INT32)); }); } TEST_F(OpTest, Softmax) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Softmax") - .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .RandomInput(DT_FLOAT, RandomDims(2, 2)) .Attr("T", DT_FLOAT)); }); } @@ -2086,28 +2224,28 @@ TEST_F(OpTest, Softmax) { TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) { Repeatedly([this]() { std::vector dims = RandomDims(2, 2, 1); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftmaxCrossEntropyWithLogits") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SoftmaxCrossEntropyWithLogits") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Softplus) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Softplus") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Softplus").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, SoftplusGrad) { Repeatedly([this]() { std::vector dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } @@ -2141,11 +2279,11 @@ TEST_F(OpTest, SpaceToBatch) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") - .Input(RandomTensor(DT_FLOAT, input_dims)) - .Input(paddings) - .Attr("T", DT_FLOAT) - .Attr("block_size", block_size)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") + .RandomInput(DT_FLOAT, input_dims) + .Input(paddings) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); }); } @@ -2182,9 +2320,9 @@ TEST_F(OpTest, SpaceToBatchND) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SpaceToBatchND") - .Input(RandomTensor(DT_FLOAT, input_dims)) + .RandomInput(DT_FLOAT, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(paddings) @@ -2198,33 +2336,26 @@ TEST_F(OpTest, SparseMatMul) { int64 y = RandomDim(); int64 z = RandomDim(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_a", true)); + std::vector a_dims = {x, y}; + std::vector b_dims = {y, z}; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_b", true)); + std::bernoulli_distribution random_bool; + bool transpose_a = random_bool(generator()); + bool transpose_b = random_bool(generator()); + if (transpose_a) { + std::swap(a_dims[0], a_dims[1]); + } + if (transpose_b) { + std::swap(b_dims[0], b_dims[1]); + } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_a", true) - .Attr("transpose_b", true)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .RandomInput(DT_FLOAT, a_dims) + .RandomInput(DT_FLOAT, b_dims) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b)); }); } @@ -2240,9 +2371,9 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { std::uniform_int_distribution(0, num_classes - 1)(generator()); } - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits") - .Input(RandomTensor(DT_FLOAT, dims)) + .RandomInput(DT_FLOAT, dims) .Input(test::AsTensor(indices)) .Attr("T", DT_FLOAT) .Attr("Tlabels", DT_INT32)); @@ -2260,56 +2391,54 @@ TEST_F(OpTest, Split) { // Ensure 'dim' is evenly divisible by 'n'. dims[dim] /= n; dims[dim] *= n; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") - .Input(test::AsScalar(dim)) - .Input(RandomTensor(type, dims)) - .Attr("T", type) - .Attr("num_split", n)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") + .Input(test::AsScalar(dim)) + .RandomInput(type, dims) + .Attr("T", type) + .Attr("num_split", n)); }); } TEST_F(OpTest, Sqrt) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, SquaredDifference) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("SquaredDifference") - .Input(RandomTensor(DT_FLOAT, dims.first)) - .Input(RandomTensor(DT_FLOAT, dims.second)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SquaredDifference") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Square) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Square").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Square").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Squeeze) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 0, 5)); + std::vector t_dims = RandomDims(0, kDefaultMaxRank, 0, 5); std::bernoulli_distribution random_bool; std::vector squeeze_dims; - for (int i = 0; i < t.dims(); ++i) { - if (t.dim_size(i) == 1 && random_bool(generator())) { + for (int i = 0; i < t_dims.size(); ++i) { + if (t_dims[i] == 1 && random_bool(generator())) { squeeze_dims.push_back(i); } } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze") - .Input(t) - .Attr("squeeze_dims", squeeze_dims) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze") + .RandomInput(type, t_dims) + .Attr("squeeze_dims", squeeze_dims) + .Attr("T", type)); }); } @@ -2317,58 +2446,59 @@ TEST_F(OpTest, Sub) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Sum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sum").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sum") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); - - std::vector begin(data.dims()), end(data.dims()); - std::vector strides(data.dims()); - for (int i = 0; i < data.dims(); ++i) { + std::vector data_dims = RandomDims(); + std::vector begin(data_dims.size()), end(data_dims.size()); + std::vector strides(data_dims.size()); + for (int i = 0; i < data_dims.size(); ++i) { begin[i] = std::uniform_int_distribution( - -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + -2 * data_dims[i], 2 * data_dims[i])(generator()); end[i] = std::uniform_int_distribution( - -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + -2 * data_dims[i], 2 * data_dims[i])(generator()); // TODO(b/31360685): support strides other than 1 or -1 strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1; } - int64 max_bitmask = (1LL << data.dims()) - 1; + int64 max_bitmask = (1LL << data_dims.size()) - 1; std::uniform_int_distribution bitmask_distribution(0, max_bitmask); int64 begin_mask = bitmask_distribution(generator()); int64 end_mask = bitmask_distribution(generator()); // Create a ellipsis bitmask with at most one 1 bit set. int64 ellipsis_mask = 0; - if (data.dims() > 0 && std::bernoulli_distribution()(generator())) { - int ellipsis_pos = - std::uniform_int_distribution(0, data.dims() - 1)(generator()); + if (!data_dims.empty() && std::bernoulli_distribution()(generator())) { + int ellipsis_pos = std::uniform_int_distribution( + 0, data_dims.size() - 1)(generator()); ellipsis_mask = 1LL << ellipsis_pos; } int64 new_axis_mask = bitmask_distribution(generator()); int64 shrink_axis_mask = bitmask_distribution(generator()); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSlice") - .Input(data) + .RandomInput(type, data_dims) .Input(test::AsTensor(begin)) .Input(test::AsTensor(end)) .Input(test::AsTensor(strides)) @@ -2418,13 +2548,13 @@ TEST_F(OpTest, StridedSliceGrad) { // TODO(phawkins): use shape inference for the forward op to compute the // gradient shape for the backward op. At present, there is a low // probability of the golden op succeeding. - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSliceGrad") .Input(test::AsTensor(dims)) .Input(test::AsTensor(begin)) .Input(test::AsTensor(end)) .Input(test::AsTensor(strides)) - .Input(RandomTensor(type, RandomDims(1))) + .RandomInput(type, RandomDims(1)) .Attr("T", type) .Attr("Index", DT_INT64) .Attr("begin_mask", begin_mask) @@ -2437,48 +2567,48 @@ TEST_F(OpTest, StridedSliceGrad) { TEST_F(OpTest, Tanh) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tanh") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, TanhGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Tile) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type, RandomDims(1)); - std::vector multiples(t.dims()); - for (int i = 0; i < t.dims(); ++i) { + std::vector t_dims = RandomDims(1); + std::vector multiples(t_dims.size()); + for (int i = 0; i < t_dims.size(); ++i) { multiples[i] = std::uniform_int_distribution(1, 3)(generator()); } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tile") - .Input(t) - .Input(test::AsTensor(multiples)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Tile") + .RandomInput(type, t_dims) + .Input(test::AsTensor(multiples)) + .Attr("T", type)); }); } TEST_F(OpTest, Transpose) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); - std::vector perm(data.dims()); + std::vector data_dims = RandomDims(); + std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); std::shuffle(perm.begin(), perm.end(), generator()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") - .Input(data) - .Input(test::AsTensor(perm)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") + .RandomInput(type, data_dims) + .Input(test::AsTensor(perm)) + .Attr("T", type)); }); } @@ -2486,10 +2616,10 @@ TEST_F(OpTest, TruncateDiv) { Repeatedly([this]() { DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -2497,26 +2627,18 @@ TEST_F(OpTest, TruncateMod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, ZerosLike) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("ZerosLike").Input(RandomTensor(type)).Attr("T", type)); - }); -} - -TEST_F(OpTest, OnesLike) { - Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type)); }); } @@ -2535,6 +2657,9 @@ int main(int argc, char** argv) { tensorflow::Flag("tf_xla_test_repetitions", &tensorflow::tf_xla_test_repetitions, "Number of repetitions for each test."), + tensorflow::Flag("tf_xla_max_tensor_size", + &tensorflow::tf_xla_max_tensor_size, + "Maximum number of elements for random input tensors."), tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, "Tensorflow device type to use for test"), tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, -- cgit v1.2.3 From c19e6cac0413b0b93d5a15f9d4dc7c861aa1c734 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 7 Jun 2017 14:55:31 -0700 Subject: [TF:XLA] Initial implementation of TensorArray ops. The XLA implementation of TensorArrays is more restrictive than regular TensorArrays: * XLA TensorArrays must have dynamic_size=False. * all elements in an XLA TensorArray must have the same shape. * writes always add their values to any existing values; neither reads nor writes ever issue errors. Out-of-bounds writes currently wrap. Refactor Variable handling in the TF/XLA bridge. Use a XlaVariable* to refer to variables inside compilation rather than a numerical ID. Allow for variables that don't correspond to variables known to the user. Also use XlaVariable to handle TensorArrays. PiperOrigin-RevId: 158322041 --- tensorflow/compiler/tests/BUILD | 19 + tensorflow/compiler/tests/tensor_array_ops_test.py | 1018 ++++++++++++++++++++ tensorflow/compiler/tests/xla_test.py | 16 +- tensorflow/compiler/tf2xla/const_analysis.cc | 2 + tensorflow/compiler/tf2xla/kernels/BUILD | 1 + tensorflow/compiler/tf2xla/kernels/arg_op.cc | 13 +- .../compiler/tf2xla/kernels/tensor_array_ops.cc | 538 +++++++++++ .../compiler/tf2xla/xla_compilation_device.cc | 2 - .../compiler/tf2xla/xla_compilation_device.h | 39 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 39 +- tensorflow/compiler/tf2xla/xla_compiler.h | 4 + tensorflow/compiler/tf2xla/xla_context.cc | 29 +- tensorflow/compiler/tf2xla/xla_context.h | 44 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 41 +- tensorflow/compiler/tf2xla/xla_op_kernel.h | 11 +- tensorflow/core/ops/data_flow_ops.cc | 2 +- 16 files changed, 1710 insertions(+), 108 deletions(-) create mode 100644 tensorflow/compiler/tests/tensor_array_ops_test.py create mode 100644 tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 19f7ff8354..d18e51e32c 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -346,6 +346,25 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "tensor_array_ops_test", + size = "small", + srcs = ["tensor_array_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_grad", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py new file mode 100644 index 0000000000..27a2977305 --- /dev/null +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -0,0 +1,1018 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for XLA TensorArray Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def _make_converter(dtype): + def _converter(x): + return np.asarray(x).astype(dtype.as_numpy_dtype) + return _converter + + +class TensorArrayTest(xla_test.XLATestCase): + + def testTensorArrayWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [[1.0, 3.0]]) + w2 = w1.write(2, [[7.0, -8.5]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual([[4.0, 5.0]], d0) + self.assertAllEqual([[1.0, 3.0]], d1) + self.assertAllEqual([[7.0, -8.5]], d2) + + def _testTensorArrayWritePack(self, tf_dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + w0 = ta.write(0, convert([[4.0, 5.0]])) + w1 = w0.write(1, convert([[6.0, 7.0]])) + w2 = w1.write(2, convert([[8.0, 9.0]])) + + c0 = w2.stack() + + self.assertAllEqual( + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + + def testTensorArrayWritePack(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWritePack(dtype) + + def testEmptyTensorArrayPack(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + empty_element = np.zeros((0, 1), dtype=np.float32) + w0 = ta.write(0, empty_element) + w1 = w0.write(1, empty_element) + w2 = w1.write(2, empty_element) + + c0 = w2.stack() + + self.assertAllEqual([3, 0, 1], c0.eval().shape) + + def _testTensorArrayWriteConcat(self, tf_dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]])) + w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]])) + w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]])) + + c0 = w2.concat() + + self.assertAllEqual( + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], + [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + + def testTensorArrayWriteConcat(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWriteConcat(dtype) + + def _testTensorArrayUnpackRead(self, tf_dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + # Unpack a vector into scalars + w0 = ta.unstack(convert([1.0, 2.0, 3.0])) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert(1.0), d0) + self.assertAllEqual(convert(2.0), d1) + self.assertAllEqual(convert(3.0), d2) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + # Unpack a matrix into vectors + w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) + r0 = w1.read(0) + r1 = w1.read(1) + r2 = w1.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([1.0, 1.1]), d0) + self.assertAllEqual(convert([2.0, 2.1]), d1) + self.assertAllEqual(convert([3.0, 3.1]), d2) + + # Reset ta because we're going to change the shape, else shape + # inference will throw an error. + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + # Try unpacking an empty matrix, which should not cause an error. + w2 = ta.unstack(convert([[], [], []])) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([]), d0) + self.assertAllEqual(convert([]), d1) + self.assertAllEqual(convert([]), d2) + + def _testTensorArrayUnpackReadMaybeLegacy(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayUnpackRead(dtype) + + def testTensorArrayUnpackRead(self): + self._testTensorArrayUnpackReadMaybeLegacy() + + def _testTensorArraySplitRead(self, tf_dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + # Split an empty vector + lengths = constant_op.constant([0, 0, 0]) + w0 = ta.split(convert([]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([]), d0) + self.assertAllEqual(convert([]), d1) + self.assertAllEqual(convert([]), d2) + + # Split a vector + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([1.0]), d0) + self.assertAllEqual(convert([2.0]), d1) + self.assertAllEqual(convert([3.0]), d2) + + # Split a matrix + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split( + convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([[1.0, 101.0]]), d0) + self.assertAllEqual(convert([[2.0, 201.0]]), d1) + self.assertAllEqual(convert([[3.0, 301.0]]), d2) + + def testTensorArraySplitRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArraySplitRead(dtype) + + def testTensorGradArrayWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) + + g_ta = w2.grad("grad") + + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + g_r0 = g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2]) + self.assertAllEqual([[4.0]], d0) + self.assertAllEqual([[1.0]], d1) + self.assertAllEqual([[-3.0]], d2) + self.assertAllEqual([[5.0]], g_d0) + self.assertAllEqual([[2.0]], g_d1) + self.assertAllEqual([[-2.0]], g_d2) + + def testTensorGradArrayDynamicWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) + + g_ta = w2.grad("grad") # Get gradient array here so we know the shape + + s = w2.size() + g_s = g_ta.size() + + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + g_r0 = g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run( + [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s]) + self.assertAllEqual([[4.0]], d0) + self.assertAllEqual([[1.0]], d1) + self.assertAllEqual([[-3.0]], d2) + self.assertAllEqual([[5.0]], g_d0) + self.assertAllEqual([[2.0]], g_d1) + self.assertAllEqual([[-2.0]], g_d2) + self.assertAllEqual(3, vs) + self.assertAllEqual(3, g_vs) + + def testTensorGradAccessTwiceReceiveSameObject(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3, + element_shape=[1, 2]) + g_ta_0 = ta.grad("grad") + g_ta_1 = ta.grad("grad") + + with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]): + # Write with one gradient handle, read with another copy of it + r1_0 = g_ta_1.read(0) + + t_g_ta_0, t_g_ta_1, d_r1_0 = session.run( + [g_ta_0.handle.op, g_ta_1.handle.op, r1_0]) + self.assertAllEqual(t_g_ta_0, t_g_ta_1) + self.assertAllEqual([[4.0, 5.0]], d_r1_0) + + def testTensorArrayWriteWrongIndexOrDataTypeFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + # Test writing the wrong datatype + with self.assertRaisesOpError( + "TensorArray dtype is float but op has dtype int32"): + ta.write(-1, np.int32(7)).flow.eval() + + def testTensorArrayReadWrongIndexOrDataTypeFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + w0 = ta.write(0, [[4.0, 5.0]]) + + # Test reading wrong datatype + r0_bad = gen_data_flow_ops._tensor_array_read_v3( + handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) + with self.assertRaisesOpError( + "TensorArray dtype is float but Op requested dtype double."): + r0_bad.eval() + + # Test reading from a different index than the one we wrote to + w0.read(1) + + def testTensorArraySplitIncompatibleShapesFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + + with self.assertRaisesOpError( + r"value is not 1D"): + lengths = array_ops.placeholder(dtypes.int64) + ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) + + with self.assertRaisesOpError( + r"lengths must be equal: 1 vs. 2"): + ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow.eval() + + with self.assertRaisesOpError( + r"value must have rank >= 1"): + ta.split(1.0, [1]).flow.eval() + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + infer_shape=False) + + with self.assertRaisesOpError( + r"TensorArray's size is not equal to the size of lengths " + r"\(1 vs. 2\)"): + ta.split([1.0], [1]).flow.eval() + + def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) + + c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype) + + w0 = ta.write(2, c(3.0)) + w1 = w0.write(2, c(4.0)) + + ta_grad = w1.grad("grad") + + w0_grad = ta_grad.write(2, c(3.0)) + w1_grad = w0_grad.write(2, c(4.0)) + w2_grad = w1_grad.write(2, c(5.0)) + + # Assert that aggregation works correctly + self.assertAllEqual(c(12.00), w2_grad.read(2).eval()) + + # Using differing shapes causes an exception + wb0_grad = ta_grad.write(1, c(1.0)) + wb1_grad = wb0_grad.write(1, c([1.0])) + + with self.assertRaisesOpError( + r"Mismatched TensorArray sizes"): + wb1_grad.flow.eval() + + def testTensorArrayWriteGradientAddMultipleAdds(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWriteGradientAddMultipleAdds(dtype) + + def testMultiTensorArray(self): + with self.test_session(), self.test_scope(): + h1 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="foo") + w1 = h1.write(0, 4.0) + r1 = w1.read(0) + + h2 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="bar") + + w2 = h2.write(0, 5.0) + r2 = w2.read(0) + r = r1 + r2 + self.assertAllClose(9.0, r.eval()) + + def _testTensorArrayGradientWriteReadType(self, dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.as_dtype(dtype), + tensor_array_name="foo", + size=3, + infer_shape=False) + + c = lambda x: np.array(x, dtype=dtype) + + value_0 = constant_op.constant(c([[4.0, 5.0]])) + value_1 = constant_op.constant(c([[3.0, 3.5]])) + + w0 = ta.write(0, value_0) + w1 = w0.write(1, value_1) + r0 = w1.read(0) + r1 = w1.read(1) + r0_2 = w1.read(0) + + # Test individual components' gradients + grad_just_r0 = gradients_impl.gradients( + ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])]) + grad_just_r0_vals = session.run(grad_just_r0) + self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0]) + + grad_r0_r0_2 = gradients_impl.gradients( + ys=[r0, r0_2], + xs=[value_0], + grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])]) + grad_r0_r0_2_vals = session.run(grad_r0_r0_2) + self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0]) + + grad_just_r1 = gradients_impl.gradients( + ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])]) + grad_just_r1_vals = session.run(grad_just_r1) + self.assertAllEqual(c([[-2.0, -4.0]]), grad_just_r1_vals[0]) + + # Test combined gradients + grad = gradients_impl.gradients( + ys=[r0, r0_2, r1], + xs=[value_0, value_1], + grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c([[-2.0, -10.0]])]) + grad_vals = session.run(grad) + self.assertEqual(len(grad_vals), 2) + self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0]) + self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) + + def testTensorArrayGradientWriteRead(self): + for dtype in self.numeric_types: + self._testTensorArrayGradientWriteReadType(dtype) + + def _testTensorArrayGradientWritePackConcatAndRead(self): + with self.test_session() as sess, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + + value_0 = constant_op.constant([-1.0, 1.0]) + value_1 = constant_op.constant([-10.0, 10.0]) + + w0 = ta.write(0, value_0) + w1 = w0.write(1, value_1) + p0 = w1.stack() + r0 = w1.read(0) + s0 = w1.concat() + + # Test gradient accumulation between read(0), pack(), and concat() + with ops.control_dependencies([p0, r0, s0]): + grad_r = gradients_impl.gradients( + ys=[p0, r0, s0], + xs=[value_0, value_1], + grad_ys=[ + [[2.0, 3.0], [4.0, 5.0]], # stack gradient + [-0.5, 1.5], # read(0) gradient + [20.0, 30.0, 40.0, 50.0], # concat gradient + ]) + grad_vals = sess.run(grad_r) # 2 + 2 entries + + self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) + self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) + + def testTensorArrayGradientWritePackConcatAndRead(self): + self._testTensorArrayGradientWritePackConcatAndRead() + + def testTensorArrayReadTwice(self): + with self.test_session(), self.test_scope(): + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + ta_readtwice = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + w_readtwice = ta_readtwice.unstack(value) + r0_readtwice = w_readtwice.read(0) + with ops.control_dependencies([r0_readtwice]): + r1_readtwice = w_readtwice.read(0) + + self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + + def _testTensorArrayGradientUnpackRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + w = ta.unstack(value) + r0 = w.read(0) + r0_1 = w.read(0) + r1 = w.read(1) + + # Test combined gradients + aggregation of read(0) + grad = gradients_impl.gradients( + ys=[r0, r0_1, r1], + xs=[value], + grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]]) + grad_vals = session.run(grad) + + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0]) + + def testTensorArrayGradientUnpackRead(self): + self._testTensorArrayGradientUnpackRead() + + def testTensorArrayGradientSplitConcat(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=2) + + value = constant_op.constant( + [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0], [1000.0, -1000.0]]) + + w = ta.split(value, [2, 2]) + r = w.concat() + + # Test combined gradients + grad = gradients_impl.gradients( + ys=[r], + xs=[value], + grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], + [2000.0, -2000.0]]]) + grad_vals = session.run(grad) + + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], + [2000.0, -2000.0]], + grad_vals[0]) + + # TODO(phawkins): implement TensorArrayClose + # def testCloseTensorArray(self): + # with self.test_session() as session, self.test_scope(): + # ta = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, tensor_array_name="foo", size=3) + # c1 = ta.close() + # session.run(c1) + + def testSizeTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + s = ta.size() + self.assertAllEqual(3, s.eval()) + + # TODO(phawkins): implement TensorArrayClose + # def testWriteCloseTensorArray(self): + # with self.test_session(), self.test_scope(): + # ta = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, + # tensor_array_name="foo", + # size=3, + # infer_shape=False) + # w0 = ta.write(0, [[4.0, 5.0]]) + # w1 = w0.write(1, [3.0]) + # w1.close().run() # Expected to run without problems + + # TODO(phawkins): implement while loops. + # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): + # np_dtype = dtype.as_numpy_dtype + # with self.test_session() as session, self.test_scope(): + # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) + # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) + # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) + # ta = tensor_array_ops.TensorArray( + # dtype=dtype, + # tensor_array_name="foo", + # size=0 if dynamic_size else 3, + # dynamic_size=dynamic_size) + # time_0 = array_ops.identity(0) + + # def body(time, ta_t, state): + # sliced = array_ops.slice( + # v0, begin=array_ops.stack([time, 0]), size=[1, -1]) + # sliced = array_ops.squeeze(sliced) + # out = sliced + var + state + # state += sliced + # ta_t = ta_t.write(time, out) + # return (time + 1, ta_t, state) + + # (unused_0, h_final, unused_2) = control_flow_ops.while_loop( + # cond=lambda time, unused_1, unused_2: time < 3, + # body=body, + # loop_vars=(time_0, ta, state0), + # shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(), + # tensor_shape.unknown_shape()), + # parallel_iterations=3) + # vout = h_final.stack() + + # grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5) + # v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0] + # state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0] + # var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0] + + # variables.global_variables_initializer().run() + # state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = ( + # session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]) + # ) + # just_v0_grad_t, = session.run([v0_grad]) + + # # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ] + # # vout = [ v0[0] + var + state[0] | + # # v0[1] + var + state[1] | + # # v0[2] + var + state[2] ] + # # = [ v0[0] + var + state0 | + # # v0[1] + var + state0 + v0[0] | + # # v0[2] + var + state0 + v0[0] + v0[1] ] + # # + # # d(vout[0])/d(v0) = [1 | 0 | 0 ] + # # d(vout[1])/d(v0) = [1 | 1 | 0 ] + # # d(vout[2])/d(v0) = [1 | 1 | 1 ] + # # d(vout)/d(var) = [1 | 1 | 1] + # # d(vout)/d(state0) = [ 1 | 1 | 1 ] + + # state_per_time = np.array( + # [state0_t, state0_t + v0_t[0, :], + # state0_t + v0_t[0, :] + v0_t[1, :]]) + + # # Compare forward prop + # self.assertAllClose(v0_t + var_t + state_per_time, vout_t) + + # # Compare backward prop + # expected_v0_grad_t = np.array([ + # grad_val[0, :] + grad_val[1, :] + grad_val[2, :], + # grad_val[1, :] + grad_val[2, :], grad_val[2, :] + # ]) + + # self.assertAllEqual(expected_v0_grad_t, v0_grad_t) + # self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t) + # self.assertAllClose(grad_val.sum(axis=0), var_grad_t) + # self.assertAllClose(grad_val.sum(axis=0), state0_grad_t) + + # def testWhileLoopWritePackGradients(self): + # self._testWhileLoopWritePackGradients( + # dynamic_size=False, dtype=dtypes.float32) + # # TODO(ebrevdo): re-enable when While supports non-float32 gradients. + # # self._testWhileLoopWritePackGradients( + # # dynamic_size=False, dtype=tf.int64) + + # def testWhileLoopDynamicWritePackGradients(self): + # self._testWhileLoopWritePackGradients( + # dynamic_size=True, dtype=dtypes.float32) + + # def testGradSerialTwoLoops(self): + # with self.test_session(), self.test_scope(): + # num_steps = 100 + # acc = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, + # size=num_steps, + # clear_after_read=False, + # element_shape=tensor_shape.scalar()) + # i = constant_op.constant(0, name="i") + # x = constant_op.constant(2.0, name="x") + + # c = lambda i, acc: i < 5 + + # def b(i, acc): + # x1 = control_flow_ops.cond( + # math_ops.equal(i, 0), lambda: x, + # lambda: math_ops.multiply(acc.read(i - 1), 2.0)) + # return i + 1, acc.write(i, x1) + + # i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc]) + + # z = constant_op.constant(0.0) + + # def fn(i, acc): + # return i + 1, acc.write(i, z) + + # _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn, + # [i1, acc1]) + + # r = acc2.stack() + # grad = gradients_impl.gradients(r, [x])[0] + # self.assertAllClose(31.0, grad.eval()) + + def testSumOfTwoReadVariablesWithoutRepeatGrad(self): + with self.test_session() as session, self.test_scope(): + a = array_ops.identity( + np.arange( + 3 * 5, dtype=np.float32).reshape(3, 5) + 1) + b = array_ops.identity( + np.arange( + 3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5) + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + ta = ta.write(0, a, name="write_a") + ta = ta.write(1, b, name="write_b") + c = ( + ta.read( + 0, name="read_a_0") + # a + b + ta.read( + 1, name="read_b_0")) + g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1) + grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1 + grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1 + + # Test gradients calculated individually + grad_a_t, = session.run([grad_a]) + self.assertAllEqual(grad_a_t, g0) + + grad_b_t, = session.run([grad_b]) + self.assertAllEqual(grad_b_t, g0) + + # Test gradients calculated jointly + joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b]) + self.assertAllEqual(joint_grad_a_t, g0) + self.assertAllEqual(joint_grad_b_t, g0) + + def testWriteShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c0 = constant_op.constant([4.0, 5.0]) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = constant_op.constant([6.0, 7.0]) + w1 = w0.write(1, c1) + r0 = w1.read(0) + r1 = w1.read(1) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) + self.assertAllEqual(c1.get_shape(), r1.get_shape()) + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c2 = constant_op.constant([4.0, 5.0, 6.0]) + with self.assertRaises(ValueError): + w0.write(0, c2) + + def testPartlyUnknownShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=6) + + c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list()) + + c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) + w1 = w0.write(1, c1) + r1 = w1.read(0) + self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list()) + + # Writing less specific shape (doesn't change type.) + c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None]) + w2 = w1.write(2, c2) + r2 = w2.read(0) + self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list()) + + # Writing more specific shape in one dimension and less specific in + # another. + c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None]) + w3 = w2.write(3, c3) + r3 = w3.read(0) + self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list()) + + # Writing partly defined shape using TensorArray.scatter. + c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3]) + w4 = w3.scatter([4, 5], c4) + r4 = w4.read(0) + self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list()) + + # Writing fully defined shape using TensorArray.split. + c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3]) + w5 = w4.split(c5, constant_op.constant([5, 5])) + r5 = w5.read(0) + self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) + + def _testUnpackShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=0, + infer_shape=True) + value = constant_op.constant( + [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]]) + w0 = ta.unstack(value) + r0 = w0.read(0) + self.assertAllEqual((2,), r0.get_shape()) + + c1 = constant_op.constant([4.0, 5.0]) + w1 = w0.write(3, c1) + r1 = w1.read(0) + self.assertAllEqual(c1.get_shape(), r1.get_shape()) + + c2 = constant_op.constant([4.0, 5.0, 6.0]) + with self.assertRaises(ValueError): + w1.write(4, c2) + + def testUnpackShape(self): + self._testUnpackShape() + + def testSplitShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=0, + infer_shape=True) + value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]) + w0 = ta.split(value, [1, 1, 1]) + r0 = w0.read(0) + self.assertAllEqual((1, 2), r0.get_shape()) + + ta1 = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo1", + size=0, + infer_shape=True) + w0 = ta1.split(value, [1, 2]) + r0 = w0.read(0) + self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) + + def testWriteUnknownShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=True) + c0 = array_ops.placeholder(dtypes.float32) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) + + def _testGradientWhenNotAllComponentsRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + x = constant_op.constant([2.0, 3.0]) + w = ta.unstack(x) + r0 = w.read(0) + # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). + grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0]) + grad_r0_vals = session.run(grad_r0)[0] + self.assertAllEqual(grad_r0_vals, [1.0, 0.0]) + + def testGradientWhenNotAllComponentsRead(self): + self._testGradientWhenNotAllComponentsRead() + + def _testTensorArrayEvalEmpty(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=False) + with self.assertRaisesOpError( + "TensorArray has size zero, but element shape is not fully " + "defined. Currently only static shapes are supported when packing " + "zero-size TensorArrays."): + ta.stack().eval() + + def testTensorArrayEvalEmpty(self): + self._testTensorArrayEvalEmpty() + + def _testTensorArrayEvalEmptyWithDefault(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=True) + self.assertEqual(0, ta.size().eval()) + ta = ta.unstack(array_ops.zeros([0, 3, 5])) + packed = ta.stack() + self.assertAllEqual([0, 3, 5], packed.eval().shape) + # Concatenating zero tensors along their first dimension gives a + # first dimension of zero + self.assertAllEqual([0, 5], ta.concat().eval().shape) + + def testTensorArrayEvalEmptyWithDefault(self): + self._testTensorArrayEvalEmptyWithDefault() + + def testTensorArrayScatterReadAndGradients(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + w = ta.scatter(indices, value) + r0 = w.read(1) + r1 = w.read(8) + + # Test combined gradients + aggregation of read(0) + grad = gradients_impl.gradients( + ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) + read_vals, grad_vals = session.run([[r0, r1], grad]) + + self.assertEqual(len(read_vals), 2) + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([1.0, -1.0], read_vals[0]) + self.assertAllEqual([10.0, -10.0], read_vals[1]) + self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) + + def testTensorArrayWriteGatherAndGradients(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=10) + + values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)]) + indices = constant_op.constant([1, 8]) + + w = ta.unstack(values) + g = w.gather(indices) + + # Test combined gradients + aggregation of read(0) + grad = gradients_impl.gradients( + ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]]) + g_vals, grad_vals = session.run([[g], grad]) + + # Gradients for 8 of the 10 unread components are zero. + expected_grad = np.zeros((10, 2)) + expected_grad[1] = [2.0, 3.0] + expected_grad[8] = [4.0, 5.0] + + self.assertEqual(len(g_vals), 1) + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0]) + self.assertAllEqual(expected_grad, grad_vals[0]) + + def testTensorArrayIdentity(self): + with self.test_session() as session, self.test_scope(): + ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, + infer_shape=False) + ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4, + infer_shape=True) + + ta0 = ta0.write(0, 0.) + ta1 = ta1.write(0, 1) + + v0 = resource_variable_ops.ResourceVariable(0) + v1 = resource_variable_ops.ResourceVariable(0) + + with ops.control_dependencies([v0.assign_add(1)]): + ta0 = ta0.identity() + + with ops.control_dependencies([v1.assign_add(1)]): + ta1 = ta1.identity() + + read0 = ta0.read(0) + read1 = ta1.read(0) + + size0 = ta0.size() + size1 = ta1.size() + + # Tests correct properties on new TensorArrays. + self.assertEqual(dtypes.float32, ta0.dtype) + self.assertEqual(dtypes.int32, ta1.dtype) + self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) + self.assertEqual(tensor_shape.scalar(), read1.get_shape()) + + variables.global_variables_initializer().run() + + read0_v, read1_v, size0_v, size1_v = session.run( + (read0, read1, size0, size1)) + + # Tests that the control dependencies was added and executed. + self.assertEqual(1, v0.eval()) + self.assertEqual(1, v1.eval()) + + # Tests correct TensorArray. + self.assertEqual(read0_v, 0) + self.assertEqual(read1_v, 1) + self.assertEqual(size0_v, 2) + self.assertEqual(size1_v, 4) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index f7fe186cf8..79549644ea 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -54,16 +54,20 @@ class XLATestCase(test.TestCase): self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') self.all_tf_types = [ - dtypes.DType(types_pb2.DataType.Value(name)) + dtypes.as_dtype(types_pb2.DataType.Value(name)) for name in FLAGS.types.split(',') ] - self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] - self.int_types = [ - dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer + self.int_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_integer ] - self.float_types = [ - dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating + self.float_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_floating ] + self.numeric_tf_types = self.int_tf_types + self.float_tf_types + + self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] + self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] + self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] self.numeric_types = self.int_types + self.float_types # Parse the manifest file, if any, into a regex identifying tests to diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index c4cbaebb25..36a6c90af4 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -89,6 +89,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"StridedSliceGrad", "end"}, {"StridedSliceGrad", "strides"}, {"Sum", "reduction_indices"}, + {"TensorArrayV3", "size"}, + {"TensorArraySplitV3", "lengths"}, {"Tile", "multiples"}, {"Transpose", "perm"}}; diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 81b065689d..a434c74680 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -55,6 +55,7 @@ tf_kernel_library( "spacetobatch_op.cc", "split_op.cc", "strided_slice_op.cc", + "tensor_array_ops.cc", "tile_ops.cc", "training_ops.cc", "transpose_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index d6897d6e33..620fc84437 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -49,14 +49,15 @@ class ArgOp : public XlaOpKernel { return; } - XlaContext& tc = XlaContext::Get(ctx); - const XlaContext::Argument& arg = tc.args()[index_]; + XlaContext& xc = XlaContext::Get(ctx); + const XlaContext::Argument& arg = xc.args()[index_]; if (arg.is_variable) { - // We use the argument position of the variable input as a unique ID. // TODO(phawkins): this code assumes that variables do not alias. - OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type, - arg.value.handle)); - ctx->SetVariableOutput(0, index_); + XlaVariable* var; + OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type, + arg.value.handle, &var)); + var->tensor_array_size = arg.tensor_array_size; + ctx->SetVariableOutput(0, var); } else if (arg.value.is_constant) { ctx->SetConstantOutput(0, arg.value.constant_value); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc new file mode 100644 index 0000000000..de542d55e8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -0,0 +1,538 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorArray operators. + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// Since the element shape is not always provided to the TensorArrayV3 operator, +// we must support lazily initialization of the TensorArray at the time of the +// first write. +// If a TensorArray `var` has not been initialized, constructs storage for the +// TensorArray with elements of `elem_shape`. For both initialized and +// uninitialized TensorArrays, checks that the tensor has a type compatible with +// 'dtype' and shape compatible with 'elem_shape'. +Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, + XlaVariable* var, DataType dtype, + const TensorShape& elem_shape) { + if (var->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(var->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + TF_RET_CHECK(var->tensor_array_size >= 0) + << var->name << " size " << var->tensor_array_size; + TensorShape ta_shape; + ta_shape.AddDim(var->tensor_array_size); + ta_shape.AppendShape(elem_shape); + + if (var->value.handle() == 0) { + // TensorArray has not been initialized. + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); + var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + } else { + // Checks the elem_shape matches the TensorArray shape. + auto shape_or_status = builder->GetShape(var->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + if (ta_shape != shape) { + return errors::InvalidArgument( + "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", + shape.DebugString()); + } + } + return Status::OK(); +} + +// Pads 'x' with 'count' zero indices. 'x' must have 1 element. +xla::ComputationDataHandle PadIndexWithZeros( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + int count) { + xla::ComputationDataHandle zero = builder->ConstantR1({0}); + std::vector xs(count + 1, zero); + xs[0] = builder->Reshape(x, {1}); + return builder->ConcatInDim(xs, 0); +} + +// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the +// relevant slice of 'operand'. +xla::ComputationDataHandle DynamicAddSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, + const xla::ComputationDataHandle& update, + const gtl::ArraySlice& update_dims, + const xla::ComputationDataHandle& start_indices) { + xla::ComputationDataHandle current = + builder->DynamicSlice(operand, start_indices, update_dims); + xla::ComputationDataHandle sum = builder->Add(current, update); + return builder->DynamicUpdateSlice(operand, sum, start_indices); +} + +class TensorArrayOp : public XlaOpKernel { + public: + explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + bool dynamic_size; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size)); + OP_REQUIRES( + ctx, !dynamic_size, + errors::Unimplemented( + "TensorArrays with dynamic size are not supported by XLA.")); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + int64 size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + OP_REQUIRES(ctx, size >= 0, + errors::InvalidArgument("TensorArray size must be >= 0")); + + xla::ComputationBuilder* b = ctx->builder(); + b->set_die_immediately_on_error(true); + + // Initializes the TensorArray value if we know the element shape. + // Otherwise, defer initialization to the first write. + xla::ComputationDataHandle value; + if (element_shape_.IsFullyDefined()) { + TensorShape shape; + CHECK(element_shape_.AsTensorShape(&shape)); + TensorShape ta_shape; + ta_shape.AddDim(size); + ta_shape.AppendShape(shape); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); + value = b->Broadcast(zero, ta_shape.dim_sizes()); + } + + XlaContext& xc = XlaContext::Get(ctx); + XlaVariable* var; + string name = strings::StrCat("TensorArray: ", tensor_array_name_); + OP_REQUIRES_OK(ctx, + xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + var->tensor_array_size = size; + ctx->SetVariableOutput(0, var); + ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + } + + private: + PartialTensorShape element_shape_; + DataType dtype_; + string tensor_array_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); + +class TensorArrayWriteOp : public XlaOpKernel { + public: + explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape elem_shape = ctx->InputShape(2); + + // Initializes the TensorArray, if the element shape was not known at + // construction time. + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + + xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle index = ctx->Input(1); + xla::ComputationDataHandle value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = b->Reshape(value, slice_shape.dim_sizes()); + + xla::ComputationDataHandle written = + DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp); + +class TensorArrayReadOp : public XlaOpKernel { + public: + explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(ta_type), + " but Op requested dtype ", DataTypeString(dtype_), ".")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + + auto slice_shape = ta_shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::ComputationDataHandle read = + b->DynamicSlice(ta, start_indices, slice_shape); + + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + ctx->SetOutput(0, b->Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp); + +class TensorArrayGatherOp : public XlaOpKernel { + public: + explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument("TensorArray type mismatch")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, indices_shape.dims() >= 1, + errors::InvalidArgument("indices must be rank 1")); + const int num_indices = indices_shape.dim_size(0); + auto indices = ctx->Input(1); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + + // For each index in `indices`, add the corresponding slice to `slices`. + std::vector slices(num_indices); + for (int i = 0; i < num_indices; ++i) { + // Slices the i-th index out of `indices`, and pads it with zeros in the + // minor dimensions to form an index into the TensorArray storage. + auto index = b->Slice(indices, {i}, {i + 1}); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + + auto slice_shape = ta_shape.dim_sizes(); + slice_shape[0] = 1LL; + + slices[i] = b->DynamicSlice(ta, start_indices, slice_shape); + } + + xla::ComputationDataHandle gather; + if (slices.empty()) { + auto shape = ta_shape.dim_sizes(); + shape[0] = 0; + gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape); + } else { + gather = b->ConcatInDim(slices, 0); + } + ctx->SetOutput(0, gather); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp); + +class TensorArrayScatterOp : public XlaOpKernel { + public: + explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + const TensorShape value_shape = ctx->InputShape(2); + + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + TensorShape elem_shape = value_shape; + elem_shape.RemoveDim(0); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, indices_shape.dims() >= 1, + errors::InvalidArgument("indices must be rank 1")); + const int num_indices = indices_shape.dim_size(0); + const xla::ComputationDataHandle indices = ctx->Input(1); + + xla::ComputationDataHandle ta = var->value; + const xla::ComputationDataHandle value = ctx->Input(2); + + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}); + auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp); + +class TensorArrayConcatOp : public XlaOpKernel { + public: + explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument("TensorArray type mismatch")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + + auto ta_dims = ta_shape.dim_sizes(); + std::vector shape(ta_dims.begin() + 1, ta_dims.end()); + shape[0] *= ta_shape.dim_size(0); + ctx->SetOutput(0, b->Reshape(ta, shape)); + + Tensor lengths(DT_INT64, {ta_dims[0]}); + auto lengths_vec = lengths.vec(); + for (int i = 0; i < ta_dims[0]; ++i) { + lengths_vec(i) = ta_dims[1]; + } + ctx->SetConstantOutput(1, lengths); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp); + +class TensorArraySplitOp : public XlaOpKernel { + public: + explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + std::vector lengths; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths)); + + int64 length = 0; + if (!lengths.empty()) { + length = lengths[0]; + for (int i = 1; i < lengths.size(); ++i) { + OP_REQUIRES(ctx, lengths[i] == length, + errors::InvalidArgument("lengths must be equal: ", length, + " vs. ", lengths[i])); + } + } + + TensorShape value_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, value_shape.dims() >= 1, + errors::InvalidArgument("value must have rank >= 1, got ", + value_shape.DebugString())); + TensorShape elem_shape = value_shape; + elem_shape.set_dim(0, length); + + xla::ComputationBuilder* b = ctx->builder(); + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + xla::ComputationDataHandle ta = var->value; + + TensorShape ta_shape; + ta_shape.AddDim(var->tensor_array_size); + ta_shape.AppendShape(elem_shape); + + OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", var->tensor_array_size, ")")); + + const xla::ComputationDataHandle value = ctx->Input(1); + + OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), + errors::InvalidArgument("mismatched element count ", + value_shape.DebugString(), " vs. ", + ta_shape.DebugString())); + + ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); +}; + +REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); + +class TensorArraySizeOp : public XlaOpKernel { + public: + explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + Tensor size_tensor(DT_INT32, {}); + size_tensor.scalar()() = static_cast(var->tensor_array_size); + ctx->SetConstantOutput(0, size_tensor); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp); +}; + +REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp); + +class TensorArrayGradOp : public XlaOpKernel { + public: + explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + // Finds or looks up the corresponding gradient TensorArray, which stores + // gradients computed during backpropagation. + XlaVariable*& gradient = var->tensor_array_gradient[source_]; + if (!gradient) { + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle value = + b->Broadcast(zero, ta_shape.dim_sizes()); + + XlaContext& xc = XlaContext::Get(ctx); + string name = strings::StrCat("TensorArrayGrad: ", var->name); + OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, + value, &gradient)); + gradient->tensor_array_size = var->tensor_array_size; + } + + ctx->SetVariableOutput(0, gradient); + ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + } + + private: + string source_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 362a101895..1d0098591e 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -119,6 +119,4 @@ void XlaExpression::set_constant_value(Tensor value) { constant_value_ = std::move(value); } -void XlaExpression::set_variable_id(int id) { variable_id_ = id; } - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 1ee96e5e6c..75630bee39 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -64,6 +64,39 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; +struct XlaVariable { + // If this variable is visible externally, what was its argument number? + int arg_num = -1; + + // A descriptive name for the variable, used in error messages. + string name; + + // Current type and value of the variable. Uninitialized variables are + // represented by a default (zero) handle and type DT_INVALID. + // While the type of a variable is notionally fixed during execution, when + // a variable is first initialized we do not yet know its type, so we keep + // track of its type dynamically. + DataType type = DT_INVALID; + xla::ComputationDataHandle value; + + // Value of the variable at computation entry. Used to detect which + // variables have new values that need to be written back. + xla::ComputationDataHandle initial_value; + + // We treat TensorArrays as a Variable with some extra metadata. + + // 'tensor_array_size' stores the expected size of the TensorArray. We need + // to store this since sometimes TensorArrays must be initialized lazily since + // we do not know the element shape at construction time. + int64 tensor_array_size = -1; + + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes + // to an XlaVariable containing the gradient TensorArrays. We store a pointer + // here since there should only be one gradient TensorArray per 'source' + // string, irrespective of the number of calls to TensorArrayGrad. + std::unordered_map tensor_array_gradient; +}; + // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor // matches the shape of the subcomputation in the ComputationDataHandle. Each @@ -82,8 +115,8 @@ class XlaExpression { bool has_constant_value() const { return has_constant_value_; } const Tensor& constant_value() const { return constant_value_; } - void set_variable_id(int id); - int variable_id() const { return variable_id_; } + void set_variable(XlaVariable* variable) { variable_ = variable; } + XlaVariable* variable() const { return variable_; } private: // The XLA handle of the expression's computation. @@ -95,7 +128,7 @@ class XlaExpression { bool has_constant_value_ = false; Tensor constant_value_; - int variable_id_ = -1; + XlaVariable* variable_ = nullptr; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 820e8dd56f..580ce3d802 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -59,8 +59,9 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, type, shape, name) != - std::tie(other.kind, other.type, other.shape, other.name)) { + if (std::tie(kind, type, shape, name, tensor_array_size) != + std::tie(other.kind, other.type, other.shape, other.name, + other.tensor_array_size)) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -264,8 +265,9 @@ Status BuildArguments(const std::vector& args, switch (args[i].kind) { case XlaCompiler::Argument::kVariable: variables.push_back(i); - context_arg.value.is_constant = false; context_arg.is_variable = true; + context_arg.value.is_constant = false; + context_arg.tensor_array_size = args[i].tensor_array_size; break; case XlaCompiler::Argument::kParameter: parameters.push_back(i); @@ -274,6 +276,7 @@ Status BuildArguments(const std::vector& args, case XlaCompiler::Argument::kUninitializedVariable: context_arg.is_variable = true; context_arg.value.is_constant = true; + context_arg.tensor_array_size = args[i].tensor_array_size; break; case XlaCompiler::Argument::kConstant: context_arg.value.is_constant = true; @@ -337,7 +340,7 @@ Status BuildArguments(const std::vector& args, // type of the final output. Status BuildComputation( const std::vector& retvals, - const std::unordered_map& variable_map, + const std::vector>& variables, bool has_side_effects, bool return_updated_values_for_all_variables, xla::ComputationBuilder* builder, xla::Computation* computation, int* num_nonconst_outputs, @@ -352,27 +355,27 @@ Status BuildComputation( *num_nonconst_outputs = elems.size(); // Add return values for variables whose values have changed. - std::vector> variables; - variables.reserve(variable_map.size()); - for (const auto& entry : variable_map) { - variables.emplace_back(entry.first, &entry.second); + std::vector arg_vars; + arg_vars.reserve(variables.size()); + for (const auto& var : variables) { + if (var->arg_num >= 0) { + arg_vars.push_back(var.get()); + } } - std::sort(variables.begin(), variables.end(), - [](const std::pair& a, - const std::pair& b) { - return a.first < b.first; + std::sort(arg_vars.begin(), arg_vars.end(), + [](const XlaVariable* a, const XlaVariable* b) { + return a->arg_num < b->arg_num; }); - for (const auto& entry : variables) { - bool modified = - entry.second->value.handle() != entry.second->initial_value.handle(); + for (const XlaVariable* var : arg_vars) { + bool modified = var->value.handle() != var->initial_value.handle(); if (return_updated_values_for_all_variables || modified) { variable_updates->emplace_back(); XlaCompiler::VariableUpdate& update = variable_updates->back(); - update.input_index = entry.first; - update.type = entry.second->type; + update.input_index = var->arg_num; + update.type = var->type; update.modified = modified; - elems.push_back(entry.second->value); + elems.push_back(var->value); } } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 15f723ad78..1314305532 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -114,6 +114,10 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; + // For a kVariable or kUninitializedVariable corresponding to a TensorArray, + // what is the tensor array's declared size? + int64 tensor_array_size = -1; + bool operator==(const Argument& other) const; }; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 3592680303..4440b53069 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -53,6 +54,10 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context"; return *context; } +/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) { + return Get(ctx->op_kernel_context()); +} + void XlaContext::set_args(std::vector args) { args_ = std::move(args); } @@ -124,29 +129,19 @@ void XlaContext::AddSideEffects() { xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateVariable(int variable_id, string name, DataType type, - const xla::ComputationDataHandle& handle) { - auto result = variables_.emplace(variable_id, Variable()); - if (!result.second) { - return errors::InvalidArgument("Duplicate ID ", variable_id, - " for variable ", name); - } - Variable& var = result.first->second; +Status XlaContext::CreateVariable(int arg_num, string name, DataType type, + const xla::ComputationDataHandle& handle, + XlaVariable** variable) { + variables_.emplace_back(new XlaVariable); + *variable = variables_.back().get(); + XlaVariable& var = **variable; + var.arg_num = arg_num; var.name = std::move(name); var.type = type; var.initial_value = var.value = handle; return Status::OK(); } -Status XlaContext::GetVariable(int variable_id, Variable** variable) { - auto it = variables_.find(variable_id); - if (it == variables_.end()) { - return errors::InvalidArgument("Unknown variable ID ", variable_id); - } - *variable = &it->second; - return Status::OK(); -} - const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 657ead5391..3978baaf63 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,6 +30,8 @@ limitations under the License. namespace tensorflow { +class XlaOpKernelContext; + // The XlaContext is the data structure that holds the state of an XLA // compilation, that is accessible from OpKernelContexts when compiling a // subgraph of Ops using XLA. @@ -55,16 +56,16 @@ class XlaContext : public ResourceBase { string name; // Is this a variable? - bool is_variable; + bool is_variable = false; HandleOrConstant value; + + int64 tensor_array_size = -1; }; // Retrieves the XlaContext of the current compilation. static XlaContext& Get(const OpKernelContext* ctx); - static XlaContext& Get(const XlaOpKernelContext* ctx) { - return Get(ctx->op_kernel_context()); - } + static XlaContext& Get(const XlaOpKernelContext* ctx); // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, @@ -105,33 +106,16 @@ class XlaContext : public ResourceBase { bool has_side_effects() const { return has_side_effects_; } - struct Variable { - // A descriptive name for the variable, used in error messages. - string name; - - // Current type and value of the variable. Uninitialized variables are - // represented by a default (zero) handle and type DT_INVALID. - // While the type of a variable is notionally fixed during execution, when - // a variable is first initialized we do not yet know its type, so we keep - // track of its type dynamically. - DataType type = DT_INVALID; - xla::ComputationDataHandle value; - - // Value of the variable at computation entry. Used to detect which - // variables have new values that need to be written back. - xla::ComputationDataHandle initial_value; - }; - // Creates a variable with variable `variable_id` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. // Fails if the variable already exists. - Status CreateVariable(int variable_id, string name, DataType type, - const xla::ComputationDataHandle& handle); + Status CreateVariable(int arg_num, string name, DataType type, + const xla::ComputationDataHandle& handle, + XlaVariable** variable); - // Retrieves variable `variable_id`. Fails if the variable does not exist. - Status GetVariable(int variable_id, Variable** variable); - - const std::unordered_map& variables() { return variables_; } + const std::vector>& variables() { + return variables_; + } // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -182,8 +166,8 @@ class XlaContext : public ResourceBase { // Does the computation have side effects, i.e., Send() calls? bool has_side_effects_ = false; - // Map from variable ID to the current value of each variable. - std::unordered_map variables_; + // Holds ownership of variables. The variables are not ordered. + std::vector> variables_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 4de69ee43c..3272b1efa1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -38,7 +38,8 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().handle() != 0 || expression->variable_id() >= 0); + CHECK(expression->handle().handle() != 0 || + expression->variable() != nullptr); VLOG(1) << "Fetched T" << expression->handle().handle(); return expression; } @@ -251,11 +252,8 @@ Status XlaOpKernelContext::ReadVariableInput( int index, xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - int variable_id = expression->variable_id(); - - XlaContext::Variable* variable; - XlaContext& context = XlaContext::Get(this); - TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable)); + XlaVariable* variable = expression->variable(); + TF_RET_CHECK(variable != nullptr); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -267,11 +265,8 @@ Status XlaOpKernelContext::ReadVariableInput( string XlaOpKernelContext::VariableDebugString(int index) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - int variable_id = expression->variable_id(); - - XlaContext::Variable* variable; - XlaContext& context = XlaContext::Get(this); - if (!context.GetVariable(variable_id, &variable).ok()) { + XlaVariable* variable = expression->variable(); + if (!variable) { return ""; } return variable->name; @@ -281,11 +276,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - int variable_id = expression->variable_id(); - - XlaContext::Variable* variable; - XlaContext& context = XlaContext::Get(this); - TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable)); + XlaVariable* variable = expression->variable(); + TF_RET_CHECK(variable != nullptr); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -345,14 +337,22 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } -void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) { +void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) { Tensor* output = nullptr; // The shape of the output tensor is the shape of the variable resource // (i.e., a scalar), not the shape of the variable's value. OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape(), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_variable_id(variable_id); + expression->set_variable(variable); +} + +Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) { + const XlaExpression* expression = + CastExpressionFromTensor(context_->input(index)); + TF_RET_CHECK(expression->variable() != nullptr); + *variable = expression->variable(); + return Status::OK(); } Status XlaOpKernelContext::AssignVariable( @@ -362,9 +362,8 @@ Status XlaOpKernelContext::AssignVariable( const XlaExpression* expression = CastExpressionFromTensor(context_->input(index)); - XlaContext& context = XlaContext::Get(this); - XlaContext::Variable* variable; - TF_RETURN_IF_ERROR(context.GetVariable(expression->variable_id(), &variable)); + XlaVariable* variable = expression->variable(); + TF_RET_CHECK(variable != nullptr); if (!((variable->type == DT_INVALID && type != DT_INVALID) || (variable->type == type))) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 0a8a928418..a25774c3a6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -157,15 +157,18 @@ class XlaOpKernelContext { // 'index'. Status ReadVariableInput(int index, xla::ComputationDataHandle* value); - // Sets output 'index' to be a reference to variable 'variable_id'. Used - // to propagate resource variables through the compilation. - void SetVariableOutput(int index, int variable_id); - // Assigns the value `handle` to the variable referenced by input // `variable_index`. Marks the operator as having side effects. Status AssignVariable(int variable_index, DataType type, const xla::ComputationDataHandle& handle); + // Sets '*variable' to the variable associated with input `index`. + Status GetVariableInput(int index, XlaVariable** variable); + + // Sets output 'index' to be a reference to variable 'variable'. Used + // to propagate resource variables through the compilation. + void SetVariableOutput(int index, XlaVariable* variable); + // Returns a human-readable debug string describing 'variable_index'. string VariableDebugString(int variable_index); diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f007581e8d..97d0800d12 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1221,7 +1221,7 @@ of the forward TensorArray is known when this operation is called. TensorArray gradient calls use an accumulator TensorArray object. If multiple gradients are calculated and run in the same session, the multiple -gradient nodes may accidentally flow throuth the same accumulator TensorArray. +gradient nodes may accidentally flow through the same accumulator TensorArray. This double counts and generally breaks the TensorArray gradient flow. The solution is to identify which gradient call this particular -- cgit v1.2.3 From 0a9d2dac0844d1bfe11c8a21d3b2598793564b95 Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Wed, 7 Jun 2017 15:00:19 -0700 Subject: Add a util function in virtual placer to return canonicalized device string, which can be used to fix the node's device field before passing them to the maxcut algorithm. PiperOrigin-RevId: 158322753 --- tensorflow/core/grappler/costs/virtual_placer.cc | 24 ++++++++++++++-------- tensorflow/core/grappler/costs/virtual_placer.h | 5 +++++ .../core/grappler/costs/virtual_placer_test.cc | 14 +++++++++++++ 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc index ff6eff0249..e06774fc41 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.cc +++ b/tensorflow/core/grappler/costs/virtual_placer.cc @@ -36,11 +36,20 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) : has_gpu_(false) { } const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const { + string device = get_canonical_device_name(node); + if (device.empty()) { + return unknown_device_; + } + auto it = devices_.find(device); + DCHECK(it != devices_.end()); + return it->second; +} + +string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const { string device; if (!node.device().empty()) { - auto it = devices_.find(node.device()); - if (it != devices_.end()) { - return it->second; + if (devices_.find(node.device()) != devices_.end()) { + return node.device(); } DeviceNameUtils::ParsedName parsed_name; bool parsed = DeviceNameUtils::ParseFullName(node.device(), &parsed_name); @@ -57,7 +66,7 @@ const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const { } } if (!parsed) { - return unknown_device_; + return ""; } else { device = strings::StrCat( "/job:", parsed_name.job, "/replica:", parsed_name.replica, @@ -71,11 +80,10 @@ const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const { device = "/job:localhost/replica:0/task:0/cpu:0"; } } - auto it = devices_.find(device); - if (it == devices_.end()) { - return unknown_device_; + if (devices_.find(device) == devices_.end()) { + return ""; } - return it->second; + return device; } } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h index 40cd64e37c..85bd502c67 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.h +++ b/tensorflow/core/grappler/costs/virtual_placer.h @@ -33,6 +33,11 @@ class VirtualPlacer { const DeviceProperties& get_device(const NodeDef& node) const; + // Returns canonical device name that has a corresponding device in the + // cluster; returns empty string if no device found or the node.device() can + // not be parsed. + string get_canonical_device_name(const NodeDef& node) const; + private: std::unordered_map devices_; bool has_gpu_; diff --git a/tensorflow/core/grappler/costs/virtual_placer_test.cc b/tensorflow/core/grappler/costs/virtual_placer_test.cc index 037c52713d..bc8d0e38ba 100644 --- a/tensorflow/core/grappler/costs/virtual_placer_test.cc +++ b/tensorflow/core/grappler/costs/virtual_placer_test.cc @@ -37,12 +37,18 @@ TEST(VirtualPlacerTest, LocalDevices) { NodeDef node; node.set_op("Conv2D"); EXPECT_EQ("GPU", placer.get_device(node).type()); + EXPECT_EQ("/job:localhost/replica:0/task:0/gpu:0", + placer.get_canonical_device_name(node)); node.set_device("CPU"); EXPECT_EQ("CPU", placer.get_device(node).type()); + EXPECT_EQ("/job:localhost/replica:0/task:0/cpu:0", + placer.get_canonical_device_name(node)); node.set_device("GPU:0"); EXPECT_EQ("GPU", placer.get_device(node).type()); + EXPECT_EQ("/job:localhost/replica:0/task:0/gpu:0", + placer.get_canonical_device_name(node)); } TEST(VirtualPlacerTest, RemoteDevices) { @@ -60,24 +66,32 @@ TEST(VirtualPlacerTest, RemoteDevices) { node.set_op("Conv2D"); // There is no local device available EXPECT_EQ("UNKNOWN", placer.get_device(node).type()); + EXPECT_EQ("", placer.get_canonical_device_name(node)); node.set_device("/job:my_job/replica:0/task:0/cpu:0"); EXPECT_EQ("CPU", placer.get_device(node).type()); + EXPECT_EQ("/job:my_job/replica:0/task:0/cpu:0", + placer.get_canonical_device_name(node)); node.set_device("/job:my_job/replica:0/task:0/gpu:0"); EXPECT_EQ("GPU", placer.get_device(node).type()); + EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0", + placer.get_canonical_device_name(node)); // There is no local CPU available node.set_device("CPU"); EXPECT_EQ("UNKNOWN", placer.get_device(node).type()); + EXPECT_EQ("", placer.get_canonical_device_name(node)); node.set_device("GPU:0"); // There is no local GPU available EXPECT_EQ("UNKNOWN", placer.get_device(node).type()); + EXPECT_EQ("", placer.get_canonical_device_name(node)); // This isn't a valid name node.set_device("/job:my_job/replica:0/task:0"); EXPECT_EQ("UNKNOWN", placer.get_device(node).type()); + EXPECT_EQ("", placer.get_canonical_device_name(node)); } } // end namespace grappler -- cgit v1.2.3 From 4e529f0f163e2a7b42dbcac78b41de64ad7170b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 15:17:52 -0700 Subject: Update ops-related pbtxt files. PiperOrigin-RevId: 158325293 --- tensorflow/core/ops/ops.pbtxt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 31d47c1ab3..01f79c4260 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -26621,7 +26621,7 @@ op { description: "The gradient source string, used to decide which gradient TensorArray\nto return." } summary: "Creates a TensorArray for storing the gradients of values in the given handle." - description: "If the given TensorArray gradient already exists, returns a reference to it.\n\nLocks the size of the original TensorArray by disabling its dynamic size flag.\n\n**A note about the input flow_in:**\n\nThe handle flow_in forces the execution of the gradient lookup to occur\nonly after certain other operations have occurred. For example, when\nthe forward TensorArray is dynamically sized, writes to this TensorArray\nmay resize the object. The gradient TensorArray is statically sized based\non the size of the forward TensorArray when this operation executes.\nFurthermore, the size of the forward TensorArray is frozen by this call.\nAs a result, the flow is used to ensure that the call to generate the gradient\nTensorArray only happens after all writes are executed.\n\nIn the case of dynamically sized TensorArrays, gradient computation should\nonly be performed on read operations that have themselves been chained via\nflow to occur only after all writes have executed. That way the final size\nof the forward TensorArray is known when this operation is called.\n\n**A note about the source attribute:**\n\nTensorArray gradient calls use an accumulator TensorArray object. If\nmultiple gradients are calculated and run in the same session, the multiple\ngradient nodes may accidentally flow throuth the same accumulator TensorArray.\nThis double counts and generally breaks the TensorArray gradient flow.\n\nThe solution is to identify which gradient call this particular\nTensorArray gradient is being called in. This is performed by identifying\na unique string (e.g. \"gradients\", \"gradients_1\", ...) from the input\ngradient Tensor\'s name. This string is used as a suffix when creating\nthe TensorArray gradient object here (the attribute `source`).\n\nThe attribute `source` is added as a suffix to the forward TensorArray\'s\nname when performing the creation / lookup, so that each separate gradient\ncalculation gets its own TensorArray accumulator." + description: "If the given TensorArray gradient already exists, returns a reference to it.\n\nLocks the size of the original TensorArray by disabling its dynamic size flag.\n\n**A note about the input flow_in:**\n\nThe handle flow_in forces the execution of the gradient lookup to occur\nonly after certain other operations have occurred. For example, when\nthe forward TensorArray is dynamically sized, writes to this TensorArray\nmay resize the object. The gradient TensorArray is statically sized based\non the size of the forward TensorArray when this operation executes.\nFurthermore, the size of the forward TensorArray is frozen by this call.\nAs a result, the flow is used to ensure that the call to generate the gradient\nTensorArray only happens after all writes are executed.\n\nIn the case of dynamically sized TensorArrays, gradient computation should\nonly be performed on read operations that have themselves been chained via\nflow to occur only after all writes have executed. That way the final size\nof the forward TensorArray is known when this operation is called.\n\n**A note about the source attribute:**\n\nTensorArray gradient calls use an accumulator TensorArray object. If\nmultiple gradients are calculated and run in the same session, the multiple\ngradient nodes may accidentally flow through the same accumulator TensorArray.\nThis double counts and generally breaks the TensorArray gradient flow.\n\nThe solution is to identify which gradient call this particular\nTensorArray gradient is being called in. This is performed by identifying\na unique string (e.g. \"gradients\", \"gradients_1\", ...) from the input\ngradient Tensor\'s name. This string is used as a suffix when creating\nthe TensorArray gradient object here (the attribute `source`).\n\nThe attribute `source` is added as a suffix to the forward TensorArray\'s\nname when performing the creation / lookup, so that each separate gradient\ncalculation gets its own TensorArray accumulator." is_stateful: true } op { -- cgit v1.2.3 From 379aa9911f97b4bb2b493e31ad7227e63907508b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Jun 2017 15:22:24 -0700 Subject: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 158325855 --- tensorflow/go/op/wrappers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c4af3a60a8..885aff8bca 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -3797,7 +3797,7 @@ func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value t // // TensorArray gradient calls use an accumulator TensorArray object. If // multiple gradients are calculated and run in the same session, the multiple -// gradient nodes may accidentally flow throuth the same accumulator TensorArray. +// gradient nodes may accidentally flow through the same accumulator TensorArray. // This double counts and generally breaks the TensorArray gradient flow. // // The solution is to identify which gradient call this particular -- cgit v1.2.3 From b94540e6f7ea130674b8122ec192c3d9a07a6752 Mon Sep 17 00:00:00 2001 From: Toby Boyd Date: Wed, 7 Jun 2017 15:28:20 -0700 Subject: tf.layers.conv2d use_bias=True to use nn.bias_add PiperOrigin-RevId: 158326493 --- tensorflow/python/layers/convolutional.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index f026e5ac45..fdf1b134b9 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -159,14 +159,12 @@ class _Conv(base.Layer): if self.bias is not None: if self.data_format == 'channels_first': - # bias_add only supports NHWC. - # TODO(fchollet): remove this when `bias_add` is feature-complete. if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. bias = array_ops.reshape(self.bias, (1, self.filters, 1)) outputs += bias if self.rank == 2: - bias = array_ops.reshape(self.bias, (1, self.filters, 1, 1)) - outputs += bias + outputs = nn.bias_add(outputs, self.bias, data_format='NCHW') if self.rank == 3: # As of Mar 2017, direct addition is significantly slower than # bias_add when computing gradients. To use bias_add, we collapse Z -- cgit v1.2.3 From beeaade460a125975b6fe34d23ff0465183f8b4a Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Wed, 7 Jun 2017 15:56:12 -0700 Subject: Resubmit a reverted change. Original description: [XLA] Enable HloEvaluator for constant folding, also merged a few operations from hlo_constant_folding to hlo_evaluator. Additionally: - In ShapeUtil::ForEachIndex: * fix a bug where visitor is called when the shape has zero elements (e.g., F32{1,0}) * added test case for ForEachIndex. - In HloEvaluator: * Instead of copying and caching a Constant instruction, return the literal directly if the instruction is constant. * Fix an issue where TUPLE and OPAQUE primitives are not keyed in the templated typed_visitor. * Use (fixed) LiteralUtil::Populate to populate resulting literal, fixes the preexisting bug in the evaluator where R0 and shape with zero size dimensions are not handled. * Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the operand's type. * Refactor IsFinite to be top level since it is only applicable to floats and the return type is always boolean. * Change from std::remainder to std::fmod for kRemainder to be compliant with existing XLA behavior. * Change from std::max and std::min to std::fmax and std::fmin to handle NaNs. * Minor comments fix. PiperOrigin-RevId: 158330052 --- tensorflow/compiler/xla/literal_util.h | 1 + tensorflow/compiler/xla/literal_util_test.cc | 2 + tensorflow/compiler/xla/service/BUILD | 4 + .../compiler/xla/service/hlo_constant_folding.cc | 237 ++-------- tensorflow/compiler/xla/service/hlo_evaluator.cc | 498 +++++++++++++++------ tensorflow/compiler/xla/service/hlo_evaluator.h | 46 +- .../compiler/xla/service/hlo_evaluator_test.cc | 58 ++- tensorflow/compiler/xla/shape_util.cc | 46 +- tensorflow/compiler/xla/shape_util_test.cc | 28 ++ tensorflow/compiler/xla/tests/literal_test_util.h | 1 + 10 files changed, 555 insertions(+), 366 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 5c180a497d..9a426ad195 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -1740,6 +1740,7 @@ Status Literal::Populate( stride_config.dimensions, stride_config.step, init_function); } else { + // For scalars. data.at(0) = generator({}); } return Status::OK(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index aaab36dc8c..50ea286b53 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -806,7 +806,9 @@ TEST_F(LiteralUtilTest, Populate) { std::vector layout; } populate_data[] = { {{}, {}}, + {{0}, {0}}, {{16}, {0}}, + {{2, 0}, {1, 0}}, {{4, 16}, {1, 0}}, {{21, 12}, {0, 1}}, {{6, 11, 17}, {2, 0, 1}}, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ecb0d2cb23..71629763da 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -106,10 +106,12 @@ cc_test( ":hlo_evaluator", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], @@ -1492,7 +1494,9 @@ cc_library( hdrs = ["hlo_constant_folding.h"], deps = [ ":hlo", + ":hlo_evaluator", ":hlo_pass", + ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index cb0a99d773..762ceebf39 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -24,230 +24,57 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace { - -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -} // namespace - -// ConstantFolderVisitor traverses the HLO computation and reduces certain -// constant graph sections, to literals. -class ConstantFolderVisitor : public DfsHloVisitorWithDefault { - public: - // Default visitor action is to do nothing and return OK. - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - - Status HandleReshape(HloInstruction* reshape) override; - - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - - Status HandleTranspose(HloInstruction* transpose) override; - - // Returns whether a constant folding operation has occurred. - const bool changed() const { return changed_; } - - // Runs the visitor on a computation and returns whether any changes were - // performed. - static StatusOr Run(HloComputation* computation); - - private: - ConstantFolderVisitor() = default; - - // Replaces the existing HLO instruction old_instruction, with a literal, - // and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceWithConstant(HloInstruction* old_instruction, - std::unique_ptr literal) { - TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( - old_instruction, HloInstruction::CreateConstant(std::move(literal)))); - changed_ = true; - return Status::OK(); - } - - // Whether any constant folding operations have occurred. - bool changed_ = false; -}; - -StatusOr ConstantFolderVisitor::Run(HloComputation* computation) { - ConstantFolderVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - return visitor.changed(); -} StatusOr HloConstantFolding::Run(HloModule* module) { + auto evaluator = MakeUnique(); + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& comp : module->computations()) { - TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get())); - changed = changed || result; - } - XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); - return changed; -} - -Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) { - if (reshape->operand(0)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( - auto reshaped_literal, - LiteralUtil::Reshape(reshape->operand(0)->literal(), - AsInt64Slice(reshape->shape().dimensions()))); - return ReplaceWithConstant(reshape, std::move(reshaped_literal)); - } - return Status::OK(); -} -Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) { - if (transpose->operand(0)->opcode() == HloOpcode::kConstant) { - auto transposed_literal = LiteralUtil::Transpose( - transpose->operand(0)->literal(), transpose->dimensions()); - return ReplaceWithConstant(transpose, std::move(transposed_literal)); - } - return Status::OK(); -} + for (auto& computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + // Skip dead code. + if (instruction->user_count() == 0 && + computation->root_instruction() != instruction) { + continue; + } + // Skip Constant and Parameter operation. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant) { + continue; + } + // Skip instructions with non-constant operands. + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + continue; + } -Status ConstantFolderVisitor::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { - if (operands[0]->opcode() == HloOpcode::kConstant) { - // If all the operands of a concatenate are constant, fold them into a - // single constant tensor. - // The result concatenate dimension is going to be the sum of all the - // concatenate dimensions of the arrays taking part of the operation. - int64 concat_dim = concatenate->dimensions()[0]; - const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); - int64 rank = ShapeUtil::Rank(reference_shape); - std::vector concat_dimensions(reference_shape.dimensions().begin(), - reference_shape.dimensions().end()); - if (concat_dim < 0) { - concat_dim += rank; - } - for (int64 i = 1; i < operands.size(); ++i) { - const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); - if (operands[i]->opcode() != HloOpcode::kConstant) { - return Status::OK(); + std::unique_ptr result = evaluator->TryEvaluate(instruction); + // Currently we skip unimplemented operations. + // TODO(b/35975797): Fold constant computations for more operations. + if (result == nullptr) { + VLOG(2) << "Constant folding failed for instruction: " + << instruction->ToString(); + continue; } - // Accumulate the concat dimension from all tensors taking part to the - // operation. - concat_dimensions[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - auto literal = LiteralUtil::CreateFromDimensions( - reference_shape.element_type(), concat_dimensions); - std::vector source_indices(rank, 0); - std::vector dest_indices(concat_dimensions.size(), 0); - for (auto operand : operands) { - const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), source_indices, literal.get(), dest_indices, - AsInt64Slice(operand_shape.dimensions()))); - dest_indices[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateConstant(std::move(result)))); + changed = true; } - return ReplaceWithConstant(concatenate, std::move(literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - std::vector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), slice->slice_starts(), literal.get(), dest_indices, - AsInt64Slice(shape.dimensions()))); - TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); - return ReplaceWithConstant(convert, std::move(new_constant)); - } - return Status::OK(); + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e0447d69aa..3e7f5b1f3d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -46,6 +46,89 @@ limitations under the License. namespace xla { +namespace { + +template +StatusOr> Compare(const Shape& shape, HloOpcode opcode, + const Literal& lhs_literal, + const Literal& rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el != rhs_el; + }; + break; + case HloOpcode::kGe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el >= rhs_el; + }; + break; + case HloOpcode::kGt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el > rhs_el; + }; + break; + case HloOpcode::kLe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el <= rhs_el; + }; + break; + case HloOpcode::kLt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el < rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + auto result = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); + + return std::move(result); +} + +template +StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = LiteralUtil::CreateFromShape(shape); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op( + LiteralUtil::Get(operand_literal, multi_index)); + })); + return std::move(result); +} + +} // namespace + template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: @@ -68,7 +151,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return elem_operand; })); return Status::OK(); - }; + } template < typename NativeT, @@ -79,7 +162,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::abs(elem_operand); })); return Status::OK(); - }; + } Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { return HandleAbs(abs, operand); @@ -101,6 +184,45 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + template + std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { + DCHECK_EQ(src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative::type>( + src_literal); + } + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override { + auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + + switch (operand->shape().element_type()) { +#define CONVERT_IF_TYPES_MATCH(src_type) \ + case (src_type): \ + parent_->evaluated_[convert] = LiteralUtil::Convert< \ + typename primitive_util::PrimitiveTypeToNative::type, \ + ReturnT>(operand_literal); \ + break; + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + LOG(FATAL) << "unimplemented operand type for HandleCovert: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); + } + Status HandleExp(HloInstruction* exp, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { @@ -117,15 +239,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleIsFinite(HloInstruction* is_finite, - HloInstruction* operand) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite], - ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) { - return std::isfinite(elem_operand); - })); - return Status::OK(); - }; - Status HandleLog(HloInstruction* log, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ReturnT elem_operand) { @@ -209,77 +322,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) override { - std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el == rhs_el; - }; - break; - case HloOpcode::kNe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el != rhs_el; - }; - break; - case HloOpcode::kGe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el >= rhs_el; - }; - break; - case HloOpcode::kGt: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el > rhs_el; - }; - break; - case HloOpcode::kLe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el <= rhs_el; - }; - break; - case HloOpcode::kLt: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el < rhs_el; - }; - break; - default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); - } - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Compare operation with mismatched dimensions, likely due to " - "broadcasting is unsupported."); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = LiteralUtil::CreateFromShape(compare->shape()); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - parent_->evaluated_[compare] = std::move(result); - - return Status::OK(); - }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, HloInstruction* rhs) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { - return std::max(lhs, rhs); + return std::fmax(lhs, rhs); })); return Status::OK(); }; @@ -289,7 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::min(lhs_el, rhs_el); + return std::fmin(lhs_el, rhs_el); })); return Status::OK(); }; @@ -309,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::remainder(lhs_el, rhs_el); + return std::fmod(lhs_el, rhs_el); })); return Status::OK(); }; @@ -338,7 +386,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { HloInstruction* arg, HloInstruction* max) override { std::function clamp_op = [](ReturnT low, ReturnT high, ReturnT value) { - return std::max(low, std::min(value, high)); + return std::fmax(low, std::fmin(value, high)); }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], ElementWiseTernaryOp(clamp, std::move(clamp_op))); @@ -370,32 +418,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - auto result = LiteralUtil::CreateFromShape(shape); - - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - unary_op(LiteralUtil::Get(operand_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - return std::move(result); - }; + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + return ElementWiseUnaryOpImpl(instruction, unary_op, + operand_literal); + } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, @@ -420,16 +447,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); return std::move(result); - }; + } template StatusOr> ElementWiseTernaryOp( @@ -459,17 +484,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - ternary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op( + LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index), + LiteralUtil::Get(ehs_literal, multi_index)); + })); return std::move(result); - }; + } HloEvaluator* parent_; }; @@ -493,6 +518,12 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: OPAQUE."); + }); } StatusOr> HloEvaluator::Evaluate( @@ -502,15 +533,15 @@ StatusOr> HloEvaluator::Evaluate( evaluated_.clear(); TF_RETURN_IF_ERROR(computation->Accept(this)); - return std::move(FindOrDie(evaluated_, computation->root_instruction())); + return MakeUnique( + GetEvaluatedLiteralFor(computation->root_instruction())); } StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice operands) { - DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - Shape shape = instruction->shape(); - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_ = operands; evaluated_.clear(); @@ -525,13 +556,34 @@ StatusOr> HloEvaluator::Evaluate( TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); - } else if (operand->opcode() == HloOpcode::kConstant) { - evaluated_[operand] = MakeUnique(operand->literal()); } } TF_RETURN_IF_ERROR(instruction->Visit(this)); - return std::move(FindOrDie(evaluated_, instruction)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction) { + TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); + TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + arg_literals_.clear(); + evaluated_.clear(); + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +std::unique_ptr HloEvaluator::TryEvaluate( + HloInstruction* instruction) { + auto result_or = Evaluate(instruction); + if (!result_or.ok()) { + VLOG(1) << "TryEvaluate failed:" << result_or.status(); + return nullptr; + } + + return result_or.ConsumeValueOrDie(); } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { @@ -548,9 +600,191 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { Status HloEvaluator::HandleConstant(HloInstruction* constant, const Literal& literal) { VLOG(2) << "HandleConstant: " << constant->ToString(); - DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + return Status::OK(); +} + +Status HloEvaluator::HandleReshape(HloInstruction* reshape) { + TF_ASSIGN_OR_RETURN( + evaluated_[reshape], + LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), + AsInt64Slice(reshape->shape().dimensions()))); + return Status::OK(); +} + +Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { + evaluated_[transpose] = LiteralUtil::Transpose( + GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + return Status::OK(); +} + +Status HloEvaluator::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + // The result concatenate dimension is going to be the sum of all concatenate + // dimensions of the operands taking part of the operation. + const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + const int64 rank = ShapeUtil::Rank(reference_shape); + const int64 concat_dim = concatenate->dimensions()[0]; + CHECK_GE(concat_dim, 0); + CHECK_LT(concat_dim, rank); + + DimensionVector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + // Accumulate the concat dimension from all tensors taking part to the + // operation. + concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto result_literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + DimensionVector source_indices(rank, 0); + DimensionVector dest_indices(concat_dimensions.size(), 0); + + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), + dest_indices, AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + evaluated_[concatenate] = std::move(result_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + if (!ShapeUtil::ElementIsFloating(operand->shape())) { + return InvalidArgument( + "expected element type in shape to be float for IsFinite op, got: %s", + PrimitiveType_Name(operand->shape().element_type()).c_str()); + } + + switch (operand->shape().element_type()) { + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](float elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](double elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "unknown/unhandled primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s", + ShapeUtil::HumanString(compare->shape()).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); + + const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs); + + // Note here we switch on the operand's type. + switch (lhs->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U16: + return Unimplemented("unhandled primitive type: U16."); + case U32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S16: + return Unimplemented("unhandled primitive type: S16."); + case S32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + default: + LOG(FATAL) << "unknown primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + const Shape& shape = slice->shape(); + auto literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + + DimensionVector dest_indices(slice->slice_starts().size(), 0); + + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), + dest_indices, AsInt64Slice(shape.dimensions()))); - evaluated_[constant] = MakeUnique(literal); + evaluated_[slice] = std::move(literal); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 50cb32eb85..e6798a35a0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. // Precondition: - // 1. argument literals are corresponds to the input instruction's - // parameters in their post-orderring. + // 1. argument literals correspond to the input instruction's parameters in + // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. // TODO(b/35950897): implement more ops other than element-wise ops. StatusOr> Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice arg_literals); + // Evaluates a single HLO instruction with constant operands. + // Returns the evaluated result as literal if successful. + // Precondition: + // 1. all operands of the input instruction are constants. + // 2. the instruction is not a Parameter operation. + StatusOr> Evaluate(HloInstruction* instruction); + + // Same as Evaluate, except returning nullptr on error. + std::unique_ptr TryEvaluate(HloInstruction* instruction); + protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. One - // exception to this is HandleCompare, where the resulting literal type is + // literal type of each evaluated Handle* method of a TypedVisitor. + // There are however a few notable exceptions to this is rule, notably: + // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. - // Note the forward declaration here is necessary to enable TypedVisitor to - // access parent members. + // These operations are handled outside of the parent HloEvaluator handlers + // instead of from within TypedVisitor. template class TypedVisitor; @@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); } + // Operations that are type-agnostic. + // Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) override; + + Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) override; + private: // Returns the already-evaluated literal result for the instruction. + // A Constant instruction is considered evaluated and its literal will be + // returned directly without looking up the cache. // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + if (hlo->IsConstant()) { + return hlo->literal(); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 443e5ad4f4..b26ece28b7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -23,8 +23,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -143,7 +145,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) { // element-wise abs op with 1 operand. TEST_F(HloEvaluatorTest, DoesAbs) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get()); @@ -154,7 +156,29 @@ TEST_F(HloEvaluatorTest, DoesAbs) { auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); -} + + // For R0 literal. + const Shape& r0 = ShapeUtil::MakeShape(F32, {}); + operand = LiteralUtil::CreateR0(-1.0f); + c1 = HloInstruction::CreateConstant(std::move(operand)); + instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); + result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); + expected = LiteralUtil::CreateR0(1.0f); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + // For R1 literal with dimension of size 0. + Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); + operand = LiteralUtil::CreateR1({}); + c1 = HloInstruction::CreateConstant(std::move(operand)); + instruction = + HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get()); + + result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); + expected = LiteralUtil::CreateR1({}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} // namespace // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. @@ -187,5 +211,35 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); } +// Verifies Reshape operation is correctly evaluated. +TEST_F(HloEvaluatorTest, DoesReshape) { + HloComputation::Builder builder( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; + builder.AddInstruction( + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + LiteralUtil::EachCell( + *result, [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_TRUE(value == + LiteralUtil::Get(*literal_clone, rindexes)); + }); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index da2c075c8c..ee49a9ae5f 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -787,27 +787,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, // and unmodified_dim_pair have size >1. Otherwise, returns true and appends // the degerenate input/output dimensions in the gap to // deleted_indices/inserted_indices respectively. - auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices, - &inserted_indices]( - std::pair prior_unmodified_dim_pair, - std::pair unmodified_dim_pair) { - for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; - modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) { - if (shape_pre.dimensions(modified_input_dim) > 1) { - return false; - } - deleted_indices.push_back(modified_input_dim); - } - for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; - modified_output_dim < unmodified_dim_pair.second; - ++modified_output_dim) { - if (shape_post.dimensions(modified_output_dim) > 1) { - return false; - } - inserted_indices.push_back(modified_output_dim); - } - return true; - }; + auto check_modified_dims = + [&shape_pre, &shape_post, &deleted_indices, &inserted_indices]( + std::pair prior_unmodified_dim_pair, + std::pair unmodified_dim_pair) { + for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; + modified_input_dim < unmodified_dim_pair.first; + ++modified_input_dim) { + if (shape_pre.dimensions(modified_input_dim) > 1) { + return false; + } + deleted_indices.push_back(modified_input_dim); + } + for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; + modified_output_dim < unmodified_dim_pair.second; + ++modified_output_dim) { + if (shape_post.dimensions(modified_output_dim) > 1) { + return false; + } + inserted_indices.push_back(modified_output_dim); + } + return true; + }; std::vector> unmodified_dims = DimensionsUnmodifiedByReshape(shape_pre, shape_post); @@ -1220,6 +1221,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, tensorflow::gtl::ArraySlice count, tensorflow::gtl::ArraySlice incr, const IndexVisitorFunction& visitor_function) { + if (ShapeUtil::HasZeroElements(shape)) { + return; + } DCHECK_EQ(Rank(shape), base.size()); DCHECK_EQ(incr.size(), base.size()); DCHECK_EQ(count.size(), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 8ac2e8345b..69ef6175cc 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -467,6 +467,34 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } +TEST(ShapeUtilTest, ForEachIndex) { + struct ShapeDimensionAndNumberInvocations { + std::vector dimensions; + int invocations; + } test_data[] = { + {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0}, + {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610}, + }; + + for (const auto& data : test_data) { + Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); + // Increments at every invocation. + int invocations = 0; + auto increment_func = [&invocations](const std::vector& indexes) { + invocations++; + return true; + }; + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + + ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step, + increment_func); + + EXPECT_EQ(invocations, data.invocations); + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 4f98083033..a8b07a2c5d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -- cgit v1.2.3 From 2741561c8eea7748ba04f6b47076bdfd22e3a915 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Wed, 7 Jun 2017 16:31:35 -0700 Subject: Fix up vz_projector script structure We now make sure scripts and HTML imports are declared in the correct places. In the future, pedantically listing script tags should not be necessary. PiperOrigin-RevId: 158334306 --- .../tensorboard/components/tf_tensorboard/tf-tensorboard.html | 1 - tensorflow/tensorboard/components/vz_projector/bundle.html | 8 -------- .../tensorboard/components/vz_projector/vz-projector-app.html | 4 ++-- .../components/vz_projector/vz-projector-bookmark-panel.html | 2 ++ .../components/vz_projector/vz-projector-data-panel.html | 3 +++ .../tensorboard/components/vz_projector/vz-projector-input.html | 4 +++- .../components/vz_projector/vz-projector-inspector-panel.html | 5 +++-- .../tensorboard/components/vz_projector/vz-projector-legend.html | 4 +++- .../components/vz_projector/vz-projector-metadata-card.html | 2 ++ .../components/vz_projector/vz-projector-projections-panel.html | 2 ++ tensorflow/tensorboard/components/vz_projector/vz-projector.html | 3 +++ 11 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html index 00a30686f6..926c476731 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html @@ -36,7 +36,6 @@ limitations under the License. - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html index 607d446789..d8dfd6e978 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html @@ -27,8 +27,10 @@ limitations under the License. + + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html index e77694426e..0d7bf7cdda 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html @@ -20,6 +20,7 @@ limitations under the License. + - \ No newline at end of file + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html index 7554c322ce..412bcbb480 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html @@ -17,9 +17,9 @@ limitations under the License. - - + + @@ -237,4 +237,5 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html index 3fc5f4db15..4b98d8bded 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html @@ -17,6 +17,7 @@ limitations under the License. + - \ No newline at end of file + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html index ebdcd72c77..4231a61ff3 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html @@ -18,6 +18,7 @@ limitations under the License. + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html index cddcb2b7d0..b82f3f520b 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html @@ -30,6 +30,7 @@ limitations under the License. + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html index d4be2f26a5..438ea9f4e9 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.html @@ -32,6 +32,7 @@ limitations under the License. + @@ -40,6 +41,7 @@ limitations under the License. + + -- cgit v1.2.3 From 187404eac001d67e8f0b3a7aeac2fc2b2ba65e0a Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 7 Jun 2017 16:51:01 -0700 Subject: Setup the env to since ops such as MatchFileOp rely on it. PiperOrigin-RevId: 158336344 --- tensorflow/core/grappler/optimizers/constant_folding.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 8bf6a081e3..33b0e16093 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -53,7 +54,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { class DeviceSimple : public DeviceBase { public: - DeviceSimple() : DeviceBase(nullptr) { + DeviceSimple() : DeviceBase(Env::Default()) { eigen_worker_threads_.num_threads = 1; eigen_worker_threads_.workers = new thread::ThreadPool( Env::Default(), "constant_folding", eigen_worker_threads_.num_threads); -- cgit v1.2.3 From 892293d98764c8623f7b7388517794dab07d8e62 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 7 Jun 2017 17:04:33 -0700 Subject: Set a default for datasets end_of_sequence. While all datasets carefully set the end_of_sequence to true at the appropriate time, some datasets might forget to set it to false in the normal case. In order to avoid potential undefined behavior, we set the end_of_sequence variable to be false by default. PiperOrigin-RevId: 158337799 --- tensorflow/core/kernels/iterator_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index c6e6634b1e..0a82ff227e 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -307,7 +307,7 @@ class IteratorGetNextOp : public AsyncOpKernel { core::ScopedUnref unref_iterator(iterator); std::vector components; - bool end_of_sequence; + bool end_of_sequence = false; IteratorContext::Params params; params.env = ctx->env(); -- cgit v1.2.3 From 845539f98302a2a42ba595ee1f197959064c3d1c Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Wed, 7 Jun 2017 17:27:03 -0700 Subject: Add evaluation test for linear classifier (n==2 or n >2). PiperOrigin-RevId: 158340296 --- tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/canned/linear_test.py | 380 ++------------ .../estimator/canned/linear_testing_utils.py | 552 +++++++++++++++++++++ 3 files changed, 587 insertions(+), 346 deletions(-) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 1d3f8de20a..5ab92d3352 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -459,6 +459,7 @@ py_library( deps = [ ":estimator", ":export_export", + ":linear", ":metric_keys", ":numpy_io", ":pandas_io", diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py index e1daaf51b3..1db3dfbf3a 100644 --- a/tensorflow/python/estimator/canned/linear_test.py +++ b/tensorflow/python/estimator/canned/linear_test.py @@ -18,30 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math -import shutil -import tempfile - -import numpy as np - from tensorflow.python.estimator.canned import linear from tensorflow.python.estimator.canned import linear_testing_utils -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column as feature_column_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import optimizer def _linear_regressor_fn(*args, **kwargs): return linear.LinearRegressor(*args, **kwargs) +def _linear_classifier_fn(*args, **kwargs): + return linear.LinearClassifier(*args, **kwargs) + + +# Tests for Linear Regressor. + + class LinearRegressorPartitionerTest( linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase): @@ -87,347 +79,43 @@ class LinearRegressorTrainingTest( self, _linear_regressor_fn) -class _BaseLinearClassiferTrainingTest(object): - - def __init__(self, n_classes): - self._n_classes = n_classes - self._logits_dimensions = ( - self._n_classes if self._n_classes > 2 else 1) - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - shutil.rmtree(self._model_dir) - - def _mock_optimizer(self, expected_loss=None): - expected_var_names = [ - '%s/part_0:0' % linear_testing_utils.AGE_WEIGHT_NAME, - '%s/part_0:0' % linear_testing_utils.BIAS_NAME - ] - - def _minimize(loss, global_step): - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual( - expected_var_names, - [var.name for var in trainable_vars]) - - # Verify loss. We can't check the value directly, so we add an assert op. - self.assertEquals(0, loss.shape.ndims) - if expected_loss is None: - return state_ops.assign_add(global_step, 1).op - assert_loss = linear_testing_utils.assert_close( - math_ops.to_float(expected_loss, name='expected'), - loss, - name='assert_loss') - with ops.control_dependencies((assert_loss,)): - return state_ops.assign_add(global_step, 1).op - - mock_optimizer = test.mock.NonCallableMock( - spec=optimizer.Optimizer, - wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) - mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) - - # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. - # So, return mock_optimizer itself for deepcopy. - mock_optimizer.__deepcopy__ = lambda _: mock_optimizer - return mock_optimizer - - def _assert_checkpoint( - self, expected_global_step, expected_age_weight=None, expected_bias=None): - logits_dimension = self._logits_dimensions - - shapes = { - name: shape for (name, shape) in - checkpoint_utils.list_variables(self._model_dir) - } - - self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) - self.assertEqual( - expected_global_step, - checkpoint_utils.load_variable( - self._model_dir, ops.GraphKeys.GLOBAL_STEP)) - - self.assertEqual([1, logits_dimension], - shapes[linear_testing_utils.AGE_WEIGHT_NAME]) - if expected_age_weight is not None: - self.assertAllEqual(expected_age_weight, - checkpoint_utils.load_variable( - self._model_dir, - linear_testing_utils.AGE_WEIGHT_NAME)) - - self.assertEqual([logits_dimension], shapes[linear_testing_utils.BIAS_NAME]) - if expected_bias is not None: - self.assertAllEqual(expected_bias, - checkpoint_utils.load_variable( - self._model_dir, linear_testing_utils.BIAS_NAME)) - - def testFromScratchWithDefaultOptimizer(self): - n_classes = self._n_classes - label = 0 - age = 17 - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - model_dir=self._model_dir) - - # Train for a few steps, and validate final checkpoint. - num_steps = 10 - est.train( - input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self._assert_checkpoint(num_steps) - - def testTrainWithTwoDimsLabel(self): - n_classes = self._n_classes - batch_size = 20 - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - model_dir=self._model_dir) - data_rank_1 = np.array([0, 1]) - data_rank_2 = np.array([[0], [1]]) - self.assertEqual((2,), data_rank_1.shape) - self.assertEqual((2, 1), data_rank_2.shape) - - train_input_fn = numpy_io.numpy_input_fn( - x={'age': data_rank_1}, - y=data_rank_2, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - est.train(train_input_fn, steps=200) - self._assert_checkpoint(200) - - def testTrainWithOneDimLabel(self): - n_classes = self._n_classes - batch_size = 20 - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - model_dir=self._model_dir) - data_rank_1 = np.array([0, 1]) - self.assertEqual((2,), data_rank_1.shape) - - train_input_fn = numpy_io.numpy_input_fn( - x={'age': data_rank_1}, - y=data_rank_1, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - est.train(train_input_fn, steps=200) - self._assert_checkpoint(200) - - def testTrainWithTwoDimsWeight(self): - n_classes = self._n_classes - batch_size = 20 - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - weight_feature_key='w', - n_classes=n_classes, - model_dir=self._model_dir) - data_rank_1 = np.array([0, 1]) - data_rank_2 = np.array([[0], [1]]) - self.assertEqual((2,), data_rank_1.shape) - self.assertEqual((2, 1), data_rank_2.shape) - - train_input_fn = numpy_io.numpy_input_fn( - x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1, - batch_size=batch_size, num_epochs=None, - shuffle=True) - est.train(train_input_fn, steps=200) - self._assert_checkpoint(200) - - def testTrainWithOneDimWeight(self): - n_classes = self._n_classes - batch_size = 20 - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - weight_feature_key='w', - n_classes=n_classes, - model_dir=self._model_dir) - data_rank_1 = np.array([0, 1]) - self.assertEqual((2,), data_rank_1.shape) - - train_input_fn = numpy_io.numpy_input_fn( - x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1, - batch_size=batch_size, num_epochs=None, - shuffle=True) - est.train(train_input_fn, steps=200) - self._assert_checkpoint(200) - - def testFromScratch(self): - n_classes = self._n_classes - label = 1 - age = 17 - # For binary classifer: - # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are - # all zero initially) and label = 1 so, - # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 - # For multi class classifer: - # loss = cross_entropy(logits, label) where logits are all 0s (weights are - # all zero initially) and label = 1 so, - # loss = 1 * -log ( 1.0 / n_classes ) - # For this particular test case, as logits are same, the formular - # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases. - mock_optimizer = self._mock_optimizer( - expected_loss=-1 * math.log(1.0/n_classes)) - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - optimizer=mock_optimizer, - model_dir=self._model_dir) - self.assertEqual(0, mock_optimizer.minimize.call_count) - - # Train for a few steps, and validate optimizer and final checkpoint. - num_steps = 10 - est.train( - input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assert_checkpoint( - expected_global_step=num_steps, - expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes], - expected_bias=[0.] if n_classes == 2 else [.0] * n_classes) - - def testFromCheckpoint(self): - # Create initial checkpoint. - n_classes = self._n_classes - label = 1 - age = 17 - # For binary case, the expected weight has shape (1,1). For multi class - # case, the shape is (1, n_classes). In order to test the weights, set - # weights as 2.0 * range(n_classes). - age_weight = [[2.0]] if n_classes == 2 else ( - np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), - (1, n_classes))) - bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes - initial_global_step = 100 - with ops.Graph().as_default(): - variables.Variable(age_weight, name=linear_testing_utils.AGE_WEIGHT_NAME) - variables.Variable(bias, name=linear_testing_utils.BIAS_NAME) - variables.Variable( - initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, - dtype=dtypes.int64) - linear_testing_utils.save_variables_to_ckpt(self._model_dir) - - # For binary classifer: - # logits = age * age_weight + bias = 17 * 2. - 35. = -1. - # loss = sigmoid_cross_entropy(logits, label) - # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 - # For multi class classifer: - # loss = cross_entropy(logits, label) - # where logits = 17 * age_weight + bias and label = 1 - # so, loss = 1 * -log ( soft_max(logits)[1] ) - if n_classes == 2: - expected_loss = 1.3133 - else: - logits = age_weight * age + bias - logits_exp = np.exp(logits) - softmax = logits_exp / logits_exp.sum() - expected_loss = -1 * math.log(softmax[0, label]) - - mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - optimizer=mock_optimizer, - model_dir=self._model_dir) - self.assertEqual(0, mock_optimizer.minimize.call_count) - - # Train for a few steps, and validate optimizer and final checkpoint. - num_steps = 10 - est.train( - input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assert_checkpoint( - expected_global_step=initial_global_step + num_steps, - expected_age_weight=age_weight, - expected_bias=bias) - - def testFromCheckpointMultiBatch(self): - # Create initial checkpoint. - n_classes = self._n_classes - label = [1, 0] - age = [17, 18.5] - # For binary case, the expected weight has shape (1,1). For multi class - # case, the shape is (1, n_classes). In order to test the weights, set - # weights as 2.0 * range(n_classes). - age_weight = [[2.0]] if n_classes == 2 else ( - np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), - (1, n_classes))) - bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes - initial_global_step = 100 - with ops.Graph().as_default(): - variables.Variable(age_weight, name=linear_testing_utils.AGE_WEIGHT_NAME) - variables.Variable(bias, name=linear_testing_utils.BIAS_NAME) - variables.Variable( - initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, - dtype=dtypes.int64) - linear_testing_utils.save_variables_to_ckpt(self._model_dir) - - # For binary classifer: - # logits = age * age_weight + bias - # logits[0] = 17 * 2. - 35. = -1. - # logits[1] = 18.5 * 2. - 35. = 2. - # loss = sigmoid_cross_entropy(logits, label) - # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 - # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 - # For multi class classifer: - # loss = cross_entropy(logits, label) - # where logits = [17, 18.5] * age_weight + bias and label = [1, 0] - # so, loss = 1 * -log ( soft_max(logits)[label] ) - if n_classes == 2: - expected_loss = (1.3133 + 2.1269) - else: - logits = age_weight * np.reshape(age, (2, 1)) + bias - logits_exp = np.exp(logits) - softmax_row_0 = logits_exp[0] / logits_exp[0].sum() - softmax_row_1 = logits_exp[1] / logits_exp[1].sum() - expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) - expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) - expected_loss = expected_loss_0 + expected_loss_1 - - mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) - - est = linear.LinearClassifier( - feature_columns=(feature_column_lib.numeric_column('age'),), - n_classes=n_classes, - optimizer=mock_optimizer, - model_dir=self._model_dir) - self.assertEqual(0, mock_optimizer.minimize.call_count) - - # Train for a few steps, and validate optimizer and final checkpoint. - num_steps = 10 - est.train( - input_fn=lambda: ({'age': (age)}, (label)), - steps=num_steps) - self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assert_checkpoint( - expected_global_step=initial_global_step + num_steps, - expected_age_weight=age_weight, - expected_bias=bias) +# Tests for Linear Classifer. class LinearClassiferWithBinaryClassesTrainingTest( - _BaseLinearClassiferTrainingTest, test.TestCase): + linear_testing_utils.BaseLinearClassiferTrainingTest, test.TestCase): - def __init__(self, methodName='runTest'): + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name test.TestCase.__init__(self, methodName) - _BaseLinearClassiferTrainingTest.__init__(self, n_classes=2) + linear_testing_utils.BaseLinearClassiferTrainingTest.__init__( + self, n_classes=2) class LinearClassiferWithMultiClassesTrainingTest( - _BaseLinearClassiferTrainingTest, test.TestCase): + linear_testing_utils.BaseLinearClassiferTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearClassiferTrainingTest.__init__( + self, n_classes=4) + + +class LinearClassiferWithBinaryClassesEvaluationTest( + linear_testing_utils.BaseLinearClassiferEvaluationTest, test.TestCase): - def __init__(self, methodName='runTest'): + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearClassiferEvaluationTest.__init__( + self, n_classes=2, linear_classifer_fn=_linear_classifier_fn) + + +class LinearClassiferWithMultiClassesEvaluationTest( + linear_testing_utils.BaseLinearClassiferEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name test.TestCase.__init__(self, methodName) - _BaseLinearClassiferTrainingTest.__init__(self, n_classes=4) + linear_testing_utils.BaseLinearClassiferEvaluationTest.__init__( + self, n_classes=4, linear_classifer_fn=_linear_classifier_fn) if __name__ == '__main__': diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index c13902d389..569630addd 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import os import shutil import tempfile @@ -30,6 +31,7 @@ from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import estimator from tensorflow.python.estimator import run_config +from tensorflow.python.estimator.canned import linear from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io @@ -112,6 +114,14 @@ def queue_parsed_features(feature_map): return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))} +def sorted_key_dict(unsorted_dict): + return {k: unsorted_dict[k] for k in sorted(unsorted_dict)} + + +def sigmoid(x): + return 1 / (1 + np.exp(-1.0 * x)) + + class CheckPartitionerVarHook(session_run_hook.SessionRunHook): """A `SessionRunHook` to check a paritioned variable.""" @@ -856,3 +866,545 @@ class BaseLinearRegressorTrainingTest(object): expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) + + +class BaseLinearClassiferTrainingTest(object): + + def __init__(self, n_classes): + self._n_classes = n_classes + self._logits_dimensions = ( + self._n_classes if self._n_classes > 2 else 1) + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % AGE_WEIGHT_NAME, + '%s/part_0:0' % BIAS_NAME + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint( + self, expected_global_step, expected_age_weight=None, expected_bias=None): + logits_dimension = self._logits_dimensions + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([1, logits_dimension], + shapes[AGE_WEIGHT_NAME]) + if expected_age_weight is not None: + self.assertAllEqual(expected_age_weight, + checkpoint_utils.load_variable( + self._model_dir, + AGE_WEIGHT_NAME)) + + self.assertEqual([logits_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertAllEqual(expected_bias, + checkpoint_utils.load_variable( + self._model_dir, BIAS_NAME)) + + def testFromScratchWithDefaultOptimizer(self): + n_classes = self._n_classes + label = 0 + age = 17 + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self._assert_checkpoint(num_steps) + + def testTrainWithTwoDimsLabel(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_2, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithOneDimLabel(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithTwoDimsWeight(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithOneDimWeight(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testFromScratch(self): + n_classes = self._n_classes + label = 1 + age = 17 + # For binary classifer: + # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 + # For multi class classifer: + # loss = cross_entropy(logits, label) where logits are all 0s (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( 1.0 / n_classes ) + # For this particular test case, as logits are same, the formular + # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases. + mock_optimizer = self._mock_optimizer( + expected_loss=-1 * math.log(1.0/n_classes)) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=num_steps, + expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes], + expected_bias=[0.] if n_classes == 2 else [.0] * n_classes) + + def testFromCheckpoint(self): + # Create initial checkpoint. + n_classes = self._n_classes + label = 1 + age = 17 + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=AGE_WEIGHT_NAME) + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # For binary classifer: + # logits = age * age_weight + bias = 17 * 2. - 35. = -1. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = 17 * age_weight + bias and label = 1 + # so, loss = 1 * -log ( soft_max(logits)[1] ) + if n_classes == 2: + expected_loss = 1.3133 + else: + logits = age_weight * age + bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[0, label]) + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + + def testFromCheckpointMultiBatch(self): + # Create initial checkpoint. + n_classes = self._n_classes + label = [1, 0] + age = [17, 18.5] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=AGE_WEIGHT_NAME) + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # For binary classifer: + # logits = age * age_weight + bias + # logits[0] = 17 * 2. - 35. = -1. + # logits[1] = 18.5 * 2. - 35. = 2. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 + # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = [17, 18.5] * age_weight + bias and label = [1, 0] + # so, loss = 1 * -log ( soft_max(logits)[label] ) + if n_classes == 2: + expected_loss = (1.3133 + 2.1269) + else: + logits = age_weight * np.reshape(age, (2, 1)) + bias + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': (age)}, (label)), + steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + + +class BaseLinearClassiferEvaluationTest(object): + + def __init__(self, n_classes, linear_classifer_fn): + self._linear_classifer_fn = linear_classifer_fn + self._n_classes = n_classes + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def test_evaluation_for_simple_data(self): + n_classes = self._n_classes + label = 1 + age = 1. + + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[-11.0]] if n_classes == 2 else ( + np.reshape(-11.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-30.0] if n_classes == 2 else [-30.0] * n_classes + + with ops.Graph().as_default(): + variables.Variable(age_weight, name=AGE_WEIGHT_NAME) + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = self._linear_classifer_fn( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1) + + if n_classes == 2: + # Binary classes: loss = sum(corss_entropy(41)) = 41. + expected_metrics = { + metric_keys.MetricKeys.LOSS: 41., + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: 41., + metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PREDICTION_MEAN: 0., + metric_keys.MetricKeys.LABEL_MEAN: 1., + metric_keys.MetricKeys.ACCURACY_BASELINE: 1, + metric_keys.MetricKeys.AUC: 0., + metric_keys.MetricKeys.AUC_PR: 1., + } + else: + # Multi classes: loss = 1 * -log ( soft_max(logits)[label] ) + logits = age_weight * age + bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[0, label]) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss, + metric_keys.MetricKeys.ACCURACY: 0., + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + n_classes = self._n_classes + label = [1, 0] + age = [17., 18.] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=AGE_WEIGHT_NAME) + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = self._linear_classifer_fn( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': (age)}, (label)), steps=1) + + if n_classes == 2: + # Logits are (-1., 1.) labels are (1, 0). + # Loss is + # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133 + # loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133 + expected_loss = 1.3133 * 2 + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PREDICTION_MEAN: 0.5, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + metric_keys.MetricKeys.AUC: 0., + metric_keys.MetricKeys.AUC_PR: 0.25, + } + else: + # Multi classes: loss = 1 * -log ( soft_max(logits)[label] ) + logits = age_weight * np.reshape(age, (2, 1)) + bias + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0., + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + + n_classes = self._n_classes + label = [1, 0] + age = [17., 18.] + weights = [1., 2.] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=AGE_WEIGHT_NAME) + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = self._linear_classifer_fn( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + weight_feature_key='w', + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1) + + if n_classes == 2: + # Logits are (-1., 1.) labels are (1, 0). + # Loss is + # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133 + # loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133 + # weights = [1., 2.] + expected_loss = 1.3133 * (1. + 2.) + loss_mean = expected_loss / (1.0 + 2.0) + label_mean = np.average(label, weights=weights) + logits = [-1, 1] + logistics = sigmoid(np.array(logits)) + predictions_mean = np.average(logistics, weights=weights) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: loss_mean, + metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean, + metric_keys.MetricKeys.LABEL_MEAN: label_mean, + metric_keys.MetricKeys.ACCURACY_BASELINE: ( + max(label_mean, 1-label_mean)), + metric_keys.MetricKeys.AUC: 0., + metric_keys.MetricKeys.AUC_PR: 0.1668, + } + else: + # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] ) + logits = age_weight * np.reshape(age, (2, 1)) + bias + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + loss_mean = np.average([expected_loss_0, expected_loss_1], + weights=weights) + expected_loss = loss_mean * np.sum(weights) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: loss_mean, + metric_keys.MetricKeys.ACCURACY: 0., + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) -- cgit v1.2.3 From cd5ac40b31afaec237aaee35007f2dc846caf811 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 7 Jun 2017 20:13:37 -0700 Subject: [XLA] Update LLVM to upstream revision r304927. Add LLVM build rules for the LLVM AMDGPU backend, commented out by default. Fixes issue #10437. PiperOrigin-RevId: 158351480 --- tensorflow/workspace.bzl | 8 +- third_party/llvm/llvm.BUILD | 239 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 233 insertions(+), 14 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index a5e7588860..57a096d993 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -496,11 +496,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b.tar.gz", + "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", ], - sha256 = "42c57d798a037d9dea692ce1da8ff4d24966ab5a40494015b374341e43411a37", - strip_prefix = "llvm-c978c0ff91f7c4ea58cfbd8f378e51c6af2c2b4b", + sha256 = "72e34e2411a06d4200a2688ee83832805fbef23a12ea481f31c2b8866fde007a", + strip_prefix = "llvm-e156d99231a7735d06a97b5b83de70bf4ce4f034", build_file = str(Label("//third_party/llvm:llvm.BUILD")), repository = tf_repo_name, ) diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 2b52a991c4..32266997a7 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -24,23 +24,20 @@ llvm_host_triple = "x86_64-unknown-linux_gnu" llvm_targets = [ "AArch64", + # Uncomment to enable the AMDGPU backend. + # TODO(phawkins): use a configure-time test. + # "AMDGPU", "ARM", "NVPTX", "PowerPC", "X86", ] -llvm_target_asm_parsers = [ - "AArch64", - "ARM", - "NVPTX", - "PowerPC", - "X86", -] +llvm_target_asm_parsers = llvm_targets -llvm_target_asm_printers = llvm_target_asm_parsers +llvm_target_asm_printers = llvm_targets -llvm_target_disassemblers = llvm_target_asm_parsers +llvm_target_disassemblers = llvm_targets # TODO(phawkins): the set of CMake variables was hardcoded for expediency. # However, we should really detect many of these via configure-time tests. @@ -352,6 +349,26 @@ llvm_target_list = [ ("-gen-searchable-tables", "lib/Target/AArch64/AArch64GenSystemOperands.inc"), ], }, + { + "name": "AMDGPU", + "lower_name": "amdgpu", + "short_name": "AMDGPU", + "tbl_outs": [ + ("-gen-register-bank", "lib/Target/AMDGPU/AMDGPUGenRegisterBank.inc"), + ("-gen-register-info", "lib/Target/AMDGPU/AMDGPUGenRegisterInfo.inc"), + ("-gen-instr-info", "lib/Target/AMDGPU/AMDGPUGenInstrInfo.inc"), + ("-gen-dag-isel", "lib/Target/AMDGPU/AMDGPUGenDAGISel.inc"), + ("-gen-callingconv", "lib/Target/AMDGPU/AMDGPUGenCallingConv.inc"), + ("-gen-subtarget", "lib/Target/AMDGPU/AMDGPUGenSubtargetInfo.inc"), + ("-gen-tgt-intrinsic", "lib/Target/AMDGPU/AMDGPUGenIntrinsics.inc"), + ("-gen-emitter", "lib/Target/AMDGPU/AMDGPUGenMCCodeEmitter.inc"), + ("-gen-dfa-packetizer", "lib/Target/AMDGPU/AMDGPUGenDFAPacketizer.inc"), + ("-gen-asm-writer", "lib/Target/AMDGPU/AMDGPUGenAsmWriter.inc"), + ("-gen-asm-matcher", "lib/Target/AMDGPU/AMDGPUGenAsmMatcher.inc"), + ("-gen-disassembler", "lib/Target/AMDGPU/AMDGPUGenDisassemblerTables.inc"), + ("-gen-pseudo-lowering", "lib/Target/AMDGPU/AMDGPUGenMCPseudoLowering.inc"), + ], + }, { "name": "ARM", "lower_name": "arm", @@ -436,7 +453,6 @@ llvm_target_list = [ "include/llvm/IR/Intrinsics*.td", "include/llvm/TableGen/*.td", "include/llvm/Target/*.td", - "include/llvm/Target/GlobalISel/*.td", ]), ) for target in llvm_target_list @@ -648,6 +664,7 @@ cc_library( "include/llvm/Analysis/*.inc", ]), deps = [ + ":binary_format", ":config", ":core", ":object", @@ -656,6 +673,184 @@ cc_library( ], ) +cc_library( + name = "amdgpu_desc", + srcs = glob([ + "lib/Target/AMDGPU/MCTargetDesc/*.c", + "lib/Target/AMDGPU/MCTargetDesc/*.cpp", + "lib/Target/AMDGPU/MCTargetDesc/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/MCTargetDesc/*.h", + "include/llvm/Target/AMDGPU/MCTargetDesc/*.def", + "include/llvm/Target/AMDGPU/MCTargetDesc/*.inc", + "lib/Target/AMDGPU/MCTargetDesc/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_asm_printer", + ":amdgpu_info", + ":amdgpu_utils", + ":config", + ":core", + ":mc", + ":support", + ], +) + +cc_library( + name = "amdgpu_disassembler", + srcs = glob([ + "lib/Target/AMDGPU/Disassembler/*.c", + "lib/Target/AMDGPU/Disassembler/*.cpp", + "lib/Target/AMDGPU/Disassembler/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/Disassembler/*.h", + "include/llvm/Target/AMDGPU/Disassembler/*.def", + "include/llvm/Target/AMDGPU/Disassembler/*.inc", + "lib/Target/AMDGPU/Disassembler/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_desc", + ":amdgpu_info", + ":amdgpu_utils", + ":config", + ":mc", + ":mc_disassembler", + ":support", + ], +) + +cc_library( + name = "amdgpu_info", + srcs = glob([ + "lib/Target/AMDGPU/TargetInfo/*.c", + "lib/Target/AMDGPU/TargetInfo/*.cpp", + "lib/Target/AMDGPU/TargetInfo/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/TargetInfo/*.h", + "include/llvm/Target/AMDGPU/TargetInfo/*.def", + "include/llvm/Target/AMDGPU/TargetInfo/*.inc", + "lib/Target/AMDGPU/TargetInfo/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_target_gen", + ":config", + ":core", + ":support", + ], +) + +cc_library( + name = "amdgpu_utils", + srcs = glob([ + "lib/Target/AMDGPU/Utils/*.c", + "lib/Target/AMDGPU/Utils/*.cpp", + "lib/Target/AMDGPU/Utils/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/Utils/*.h", + "include/llvm/Target/AMDGPU/Utils/*.def", + "include/llvm/Target/AMDGPU/Utils/*.inc", + "lib/Target/AMDGPU/Utils/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_target_gen", + ":config", + ":core", + ":mc", + ":support", + ], +) + +cc_library( + name = "amdgpu_asm_parser", + srcs = glob([ + "lib/Target/AMDGPU/AsmParser/*.c", + "lib/Target/AMDGPU/AsmParser/*.cpp", + "lib/Target/AMDGPU/AsmParser/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/AsmParser/*.h", + "include/llvm/Target/AMDGPU/AsmParser/*.def", + "include/llvm/Target/AMDGPU/AsmParser/*.inc", + "lib/Target/AMDGPU/AsmParser/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_desc", + ":amdgpu_info", + ":amdgpu_utils", + ":config", + ":mc", + ":mc_parser", + ":support", + ], +) + +cc_library( + name = "amdgpu_asm_printer", + srcs = glob([ + "lib/Target/AMDGPU/InstPrinter/*.c", + "lib/Target/AMDGPU/InstPrinter/*.cpp", + "lib/Target/AMDGPU/InstPrinter/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/InstPrinter/*.h", + "include/llvm/Target/AMDGPU/InstPrinter/*.def", + "include/llvm/Target/AMDGPU/InstPrinter/*.inc", + "lib/Target/AMDGPU/InstPrinter/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_utils", + ":config", + ":mc", + ":support", + ], +) + +cc_library( + name = "amdgpu_code_gen", + srcs = glob([ + "lib/Target/AMDGPU/*.c", + "lib/Target/AMDGPU/*.cpp", + "lib/Target/AMDGPU/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/AMDGPU/*.h", + "include/llvm/Target/AMDGPU/*.def", + "include/llvm/Target/AMDGPU/*.inc", + "lib/Target/AMDGPU/*.h", + ]), + copts = ["-Iexternal/llvm/lib/Target/AMDGPU"], + deps = [ + ":amdgpu_asm_printer", + ":amdgpu_desc", + ":amdgpu_info", + ":amdgpu_utils", + ":analysis", + ":asm_printer", + ":code_gen", + ":config", + ":core", + ":global_i_sel", + ":ipo", + ":mc", + ":scalar", + ":selection_dag", + ":support", + ":target", + ":transform_utils", + ":vectorize", + ], +) + cc_library( name = "arm_asm_parser", srcs = glob([ @@ -824,6 +1019,7 @@ cc_library( "include/llvm/AsmParser/*.inc", ]), deps = [ + ":binary_format", ":config", ":core", ":support", @@ -842,9 +1038,11 @@ cc_library( "include/llvm/CodeGen/AsmPrinter/*.h", "include/llvm/CodeGen/AsmPrinter/*.def", "include/llvm/CodeGen/AsmPrinter/*.inc", + "lib/CodeGen/AsmPrinter/*.def", ]), deps = [ ":analysis", + ":binary_format", ":code_gen", ":config", ":core", @@ -857,6 +1055,25 @@ cc_library( ], ) +cc_library( + name = "binary_format", + srcs = glob([ + "lib/BinaryFormat/*.c", + "lib/BinaryFormat/*.cpp", + "lib/BinaryFormat/*.inc", + "lib/BinaryFormat/*.h", + ]), + hdrs = glob([ + "include/llvm/BinaryFormat/*.h", + "include/llvm/BinaryFormat/*.def", + "include/llvm/BinaryFormat/*.inc", + ]), + deps = [ + ":config", + ":support", + ], +) + cc_library( name = "bit_reader", srcs = glob([ @@ -956,6 +1173,7 @@ cc_library( deps = [ ":attributes_compat_gen", ":attributes_gen", + ":binary_format", ":config", ":intrinsics_gen", ":support", @@ -1376,6 +1594,7 @@ cc_library( "include/llvm/Object/*.inc", ]), deps = [ + ":binary_format", ":bit_reader", ":config", ":core", -- cgit v1.2.3 From 3a2971bd8e2277bd6a32bd222852952b57b11fc4 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 7 Jun 2017 22:38:41 -0700 Subject: Adds the base for ClusterResolvers, a new way of communicating with and retrieving cluster information for running distributed TensorFlow. Implementations of this class would eventually allow users to simply point TensorFlow at a cluster management endpoint, and TensorFlow will automatically retrieve the host names/IPs and port numbers of TensorFlow workers from the cluster management service. PiperOrigin-RevId: 158358761 --- tensorflow/BUILD | 1 + tensorflow/contrib/BUILD | 1 + tensorflow/contrib/cluster_resolver/BUILD | 47 ++++ tensorflow/contrib/cluster_resolver/README.md | 5 + .../cluster_resolver/python/training/__init__.py | 23 ++ .../python/training/cluster_resolver.py | 171 +++++++++++++++ .../python/training/cluster_resolver_test.py | 238 +++++++++++++++++++++ 7 files changed, 486 insertions(+) create mode 100644 tensorflow/contrib/cluster_resolver/BUILD create mode 100644 tensorflow/contrib/cluster_resolver/README.md create mode 100644 tensorflow/contrib/cluster_resolver/python/training/__init__.py create mode 100644 tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py create mode 100644 tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 1353f15ec4..53214088bb 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -223,6 +223,7 @@ filegroup( "//tensorflow/contrib/boosted_trees/resources:all_files", "//tensorflow/contrib/cloud:all_files", "//tensorflow/contrib/cloud/kernels:all_files", + "//tensorflow/contrib/cluster_resolver:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index b4ff1da30f..2f4d500507 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -16,6 +16,7 @@ py_library( "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD new file mode 100644 index 0000000000..34cdb2a132 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -0,0 +1,47 @@ +# Description: Operations defined for Cluster Resolvers + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +package( + default_visibility = [ + "//tensorflow:__subpackages__", + ], +) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) + +py_library( + name = "cluster_resolver_py", + srcs = [ + "python/training/__init__.py", + "python/training/cluster_resolver.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + ], +) + +tf_py_test( + name = "cluster_resolver_py_test", + srcs = ["python/training/cluster_resolver_test.py"], + additional_deps = [ + ":cluster_resolver_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + main = "python/training/cluster_resolver_test.py", +) diff --git a/tensorflow/contrib/cluster_resolver/README.md b/tensorflow/contrib/cluster_resolver/README.md new file mode 100644 index 0000000000..6fe6871eb4 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/README.md @@ -0,0 +1,5 @@ +# Cluster Resolvers + +Cluster Resolvers are a new way of specifying cluster information for distributed execution. Built on top of existing `ClusterSpec` framework, Cluster Resolvers allow users to simply specify a configuration and a cluster management service and a `ClusterResolver` will automatically fetch the relevant information from the service and populate `ClusterSpec`s. + +`ClusterResolvers` are designed to work well with `ManagedTrainingSession` and `ClusterSpec` propagation so that distributed training sessions remain robust in the face of node and network failures. diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py new file mode 100644 index 0000000000..3520467bc6 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library Imports for Cluster Resolvers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py new file mode 100644 index 0000000000..87da24f22d --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from tensorflow.python.training.server_lib import ClusterSpec + + +class ClusterResolver(object): + """Abstract class for all implementations of ClusterResolvers. + + This defines the skeleton for all implementations of ClusterResolvers. + ClusterResolvers are a way for TensorFlow to communicate with various cluster + management systems (e.g. GCE, AWS, etc...). + + By letting TensorFlow communicate with these systems, we will be able to + automatically discover and resolve IP addresses for various TensorFlow + workers. This will eventually allow us to automatically recover from + underlying machine failures and scale TensorFlow worker clusters up and down. + """ + + @abc.abstractmethod + def cluster_spec(self): + """Retrieve the current state of the cluster and returns a ClusterSpec. + + Returns: + A ClusterSpec representing the state of the cluster at the moment this + function is called. + + Implementors of this function must take care in ensuring that the + ClusterSpec returned is up-to-date at the time of calling this function. + This usually means retrieving the information from the underlying cluster + management system every time this function is invoked and reconstructing + a cluster_spec, rather than attempting to cache anything. + """ + raise NotImplementedError( + 'cluster_spec is not implemented for {}.'.format(self)) + + +class SimpleClusterResolver(ClusterResolver): + """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" + + def __init__(self, cluster_spec): + """Creates a SimpleClusterResolver from a ClusterSpec.""" + super(SimpleClusterResolver, self).__init__() + + if not isinstance(cluster_spec, ClusterSpec): + raise TypeError('cluster_spec must be a ClusterSpec.') + self._cluster_spec = cluster_spec + + def cluster_spec(self): + """Returns the ClusterSpec passed into the constructor.""" + return self._cluster_spec + + +class UnionClusterResolver(ClusterResolver): + """Performs a union on underlying ClusterResolvers. + + This class performs a union given two or more existing ClusterResolvers. It + merges the underlying ClusterResolvers, and returns one unified ClusterSpec + when as_cluster_spec is called. The details of the merge function is + documented in the as_cluster_spec function. + """ + + def __init__(self, *args): + """Initializes a UnionClusterResolver with other ClusterResolvers. + + Args: + *args: `ClusterResolver` objects to be unionized. + + Raises: + TypeError: If any argument is not a subclass of `ClusterResolvers`. + """ + super(UnionClusterResolver, self).__init__() + + for cluster_resolver in args: + if not isinstance(cluster_resolver, ClusterResolver): + raise TypeError('All arguments must be a sub-class of ' + '`ClusterResolver.`') + self._cluster_resolvers = args + + def cluster_spec(self): + """Returns a union of all the ClusterSpecs from the ClusterResolvers. + + Returns: + A ClusterSpec containing host information merged from all the underlying + ClusterResolvers. + + Raises: + KeyError: If there are conflicting keys detected when merging two or + more dictionaries, this exception is raised. + + Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the + same job name, we will merge the list/dict of workers. + + If *all* underlying ClusterSpecs expose the set of workers as lists, we will + concatenate the lists of workers, starting with the list of workers from + the first ClusterResolver passed into the constructor. + + If *any* of the ClusterSpecs expose the set of workers as a dict, we will + treat all the sets of workers as dicts (even if they are returned as lists) + and will only merge them into a dict if there is no conflicting keys. If + there is a conflicting key, we will raise a `KeyError`. + """ + + merged_cluster = {} + + # We figure out whether it is all lists for a particular job, or whether + # there are dicts inside. + for cluster_resolver in self._cluster_resolvers: + cluster_spec = cluster_resolver.cluster_spec() + cluster_dict = cluster_spec.as_dict() + + for job_name, tasks in cluster_dict.items(): + if job_name in merged_cluster: + # If we see a dict, then we write a dict out regardless. + if isinstance(tasks, dict): + merged_cluster[job_name] = {} + else: + # We take whichever type is present. + if isinstance(tasks, list): + merged_cluster[job_name] = [] + else: + merged_cluster[job_name] = {} + + # We then do the merge as appropriate in merged_cluster[job]. + for cluster_resolver in self._cluster_resolvers: + cluster_spec = cluster_resolver.cluster_spec() + cluster_dict = cluster_spec.as_dict() + + for job_name, tasks in cluster_dict.items(): + if isinstance(merged_cluster[job_name], list): + # We all have lists, we can just concatenate and be done. + merged_cluster[job_name].extend(tasks) + else: + if isinstance(tasks, list): + # We convert to a dictionary if the type is a list. + task_dict = dict(zip(range(0, len(tasks)), tasks)) + else: + # We can simply make a copy (for update) and be done. + task_dict = tasks.copy() + + # We detect if there are duplicates, and raise an error if so. + task_keys = set(task_dict) + merged_keys = set(merged_cluster[job_name].keys()) + intersected_keys = task_keys.intersection(merged_keys) + if intersected_keys: + raise KeyError('Duplicate keys detected when merging two ' + 'ClusterSpecs: %s' % repr(intersected_keys)) + + # We do the merge after all the processing. + merged_cluster[job_name].update(task_dict) + + return ClusterSpec(merged_cluster) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py new file mode 100644 index 0000000000..dbfb77723c --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -0,0 +1,238 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cluster Resolvers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + + +class UnionClusterResolverTest(test.TestCase): + # TODO(frankchn): Transform to parameterized test after it is included in the + # TF open source codebase. + + def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): + self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) + self.assertProtoEquals( + expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals( + expected_proto, + server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals( + expected_proto, + server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + + def testSingleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + simple_resolver = SimpleClusterResolver(base_cluster_spec) + union_resolver = UnionClusterResolver(simple_resolver) + + expected_proto = """ + job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } + tasks { key: 1 value: 'ps1:2222' } } + job { name: 'worker' tasks { key: 0 value: 'worker0:2222' } + tasks { key: 1 value: 'worker1:2222' } + tasks { key: 2 value: 'worker2:2222' } } + """ + actual_cluster_spec = union_resolver.cluster_spec() + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testTwoNonOverlappingJobMergedClusterResolver(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "ps": [ + "ps0:2222", + "ps1:2222" + ] + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": [ + "worker0:2222", + "worker1:2222", + "worker2:2222" + ] + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + cluster_spec = union_cluster.cluster_spec() + + expected_proto = """ + job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } + tasks { key: 1 value: 'ps1:2222' } } + job { name: 'worker' tasks { key: 0 value: 'worker0:2222' } + tasks { key: 1 value: 'worker1:2222' } + tasks { key: 2 value: 'worker2:2222' } } + """ + self._verifyClusterSpecEquality(cluster_spec, expected_proto) + + def testOverlappingJobMergedClusterResolver(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "worker": [ + "worker4:2222", + "worker5:2222" + ] + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": [ + "worker0:2222", + "worker1:2222", + "worker2:2222" + ] + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + cluster_spec = union_cluster.cluster_spec() + + expected_proto = """ + job { name: 'worker' tasks { key: 0 value: 'worker4:2222' } + tasks { key: 1 value: 'worker5:2222' } + tasks { key: 2 value: 'worker0:2222' } + tasks { key: 3 value: 'worker1:2222' } + tasks { key: 4 value: 'worker2:2222' } } + """ + self._verifyClusterSpecEquality(cluster_spec, expected_proto) + + def testOverlappingSparseJobMergedClusterResolverThrowError(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "worker": { + 7: "worker4:2222", + 9: "worker5:2222" + } + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": { + 3: "worker0:2222", + 6: "worker1:2222", + 7: "worker2:2222" + } + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + self.assertRaises(KeyError, union_cluster.cluster_spec) + + def testOverlappingDictAndListThrowError(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "worker": [ + "worker4:2222", + "worker5:2222" + ] + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": { + 1: "worker0:2222", + 2: "worker1:2222", + 3: "worker2:2222" + } + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + self.assertRaises(KeyError, union_cluster.cluster_spec) + + def testOverlappingJobNonOverlappingKey(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "worker": { + 5: "worker4:2222", + 9: "worker5:2222" + } + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": { + 3: "worker0:2222", + 6: "worker1:2222", + 7: "worker2:2222" + } + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + cluster_spec = union_cluster.cluster_spec() + + expected_proto = """ + job { name: 'worker' tasks { key: 3 value: 'worker0:2222' } + tasks { key: 5 value: 'worker4:2222' } + tasks { key: 6 value: 'worker1:2222' } + tasks { key: 7 value: 'worker2:2222' } + tasks { key: 9 value: 'worker5:2222' }} + """ + self._verifyClusterSpecEquality(cluster_spec, expected_proto) + + def testMixedModeNonOverlappingKey(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "worker": [ + "worker4:2222", + "worker5:2222" + ] + }) + cluster_spec_2 = server_lib.ClusterSpec({ + "worker": { + 3: "worker0:2222", + 6: "worker1:2222", + 7: "worker2:2222" + } + }) + cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) + cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) + + union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) + cluster_spec = union_cluster.cluster_spec() + + expected_proto = """ + job { name: 'worker' tasks { key: 0 value: 'worker4:2222' } + tasks { key: 1 value: 'worker5:2222' } + tasks { key: 3 value: 'worker0:2222' } + tasks { key: 6 value: 'worker1:2222' } + tasks { key: 7 value: 'worker2:2222' }} + """ + self._verifyClusterSpecEquality(cluster_spec, expected_proto) + + def testRetainSparseJobWithNoMerging(self): + base_cluster_spec = server_lib.ClusterSpec({ + "worker": { + 1: "worker0:2222", + 3: "worker1:2222", + 5: "worker2:2222" + } + }) + + base_cluster_resolver = SimpleClusterResolver(base_cluster_spec) + union_cluster = UnionClusterResolver(base_cluster_resolver) + cluster_spec = union_cluster.cluster_spec() + + expected_proto = """ + job { name: 'worker' tasks { key: 1 value: 'worker0:2222' } + tasks { key: 3 value: 'worker1:2222' } + tasks { key: 5 value: 'worker2:2222' } } + """ + self._verifyClusterSpecEquality(cluster_spec, expected_proto) + + +if __name__ == "__main__": + test.main() -- cgit v1.2.3 From 1d68f729b9f62b50a407cd6a99bfa57be494f260 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 8 Jun 2017 07:25:20 -0700 Subject: Remove unneeded BUILD dependency PiperOrigin-RevId: 158391996 --- tensorflow/cc/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 71f375d048..fbc96685c8 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -454,7 +454,6 @@ cc_library( ":client_session", ":ops", ":scope", - "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:lib_internal", "//tensorflow/core:tensorflow", @@ -479,7 +478,7 @@ cc_binary( ], deps = [ ":cc_ops", - "//tensorflow/core:all_kernels", + "//tensorflow/core:all_kernels", # buildcleaner: keep "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", -- cgit v1.2.3 From 50d80ddf926423c16864f886a4fd2297d7725da1 Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Thu, 8 Jun 2017 11:58:21 -0700 Subject: Fix fft_ops_test.py for CPU --- tensorflow/python/kernel_tests/fft_ops_test.py | 61 +++++++++++++------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py index 1296dfdda2..546e7a296d 100644 --- a/tensorflow/python/kernel_tests/fft_ops_test.py +++ b/tensorflow/python/kernel_tests/fft_ops_test.py @@ -298,36 +298,37 @@ class RFFTOpsTest(BaseFFTOpsTest): use_placeholder=True) def testFftLength(self): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - for size in (5, 6): - inner_dim = size // 2 + 1 - r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( - (size,) * dims) - c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), - 10).reshape((size,) * (dims - 1) + (inner_dim,)) - - # Test truncation (FFT size < dimensions). - fft_length = (size - 2,) * rank - self._CompareForward(r2c.astype(np.float32), rank, fft_length) - self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) - - # Confirm it works with unknown shapes as well. - self._CompareForward(r2c.astype(np.float32), rank, fft_length, - use_placeholder=True) - self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, - use_placeholder=True) - - # Test padding (FFT size > dimensions). - fft_length = (size + 2,) * rank - self._CompareForward(r2c.astype(np.float32), rank, fft_length) - self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) - - # Confirm it works with unknown shapes as well. - self._CompareForward(r2c.astype(np.float32), rank, fft_length, - use_placeholder=True) - self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, - use_placeholder=True) + if test.is_gpu_available(cuda_only=True): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + for size in (5, 6): + inner_dim = size // 2 + 1 + r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( + (size,) * dims) + c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), + 10).reshape((size,) * (dims - 1) + (inner_dim,)) + + # Test truncation (FFT size < dimensions). + fft_length = (size - 2,) * rank + self._CompareForward(r2c.astype(np.float32), rank, fft_length) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) + + # Confirm it works with unknown shapes as well. + self._CompareForward(r2c.astype(np.float32), rank, fft_length, + use_placeholder=True) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, + use_placeholder=True) + + # Test padding (FFT size > dimensions). + fft_length = (size + 2,) * rank + self._CompareForward(r2c.astype(np.float32), rank, fft_length) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length) + + # Confirm it works with unknown shapes as well. + self._CompareForward(r2c.astype(np.float32), rank, fft_length, + use_placeholder=True) + self._CompareBackward(c2r.astype(np.complex64), rank, fft_length, + use_placeholder=True) def testRandom(self): np.random.seed(12345) -- cgit v1.2.3