-rw-r--r--  tensorflow/c/eager/c_api.cc | 24
-rw-r--r--  tensorflow/c/eager/c_api.h | 8
-rw-r--r--  tensorflow/c/eager/c_api_internal.h | 6
-rw-r--r--  tensorflow/c/eager/c_api_test.cc | 6
-rw-r--r--  tensorflow/compiler/jit/BUILD | 7
-rw-r--r--  tensorflow/compiler/jit/xla_interpreter_device.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 207
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 60
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD | 2
-rw-r--r--  tensorflow/contrib/distributions/BUILD | 34
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py | 80
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/__init__.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py | 153
-rw-r--r--  tensorflow/contrib/distributions/python/ops/kumaraswamy.py | 89
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils.py | 224
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils_test.py | 196
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist.py | 99
-rw-r--r--  tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py | 16
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50.py | 153
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py | 4
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py | 6
-rw-r--r--  tensorflow/contrib/factorization/python/ops/kmeans.py | 61
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 8
-rw-r--r--  tensorflow/contrib/lite/kernels/mean_test.cc | 72
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD | 5
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 3
-rw-r--r--  tensorflow/contrib/opt/BUILD | 3
-rw-r--r--  tensorflow/contrib/py2tf/impl/conversion.py | 22
-rw-r--r--  tensorflow/contrib/py2tf/pyct/inspect_utils.py | 121
-rw-r--r--  tensorflow/contrib/py2tf/pyct/inspect_utils_test.py | 107
-rw-r--r--  tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py | 2
-rw-r--r--  tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py | 1
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 11
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc | 2
-rw-r--r--  tensorflow/core/framework/function.cc | 8
-rw-r--r--  tensorflow/core/framework/function.h | 10
-rw-r--r--  tensorflow/core/graph/control_flow.cc | 11
-rw-r--r--  tensorflow/core/graph/control_flow.h | 16
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 56
-rw-r--r--  tensorflow/core/grappler/optimizers/custom_graph_optimizer.h | 35
-rw-r--r--  tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc | 61
-rw-r--r--  tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h | 65
-rw-r--r--  tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc | 87
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 21
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer_test.cc | 77
-rw-r--r--  tensorflow/core/kernels/function_ops.cc | 34
-rw-r--r--  tensorflow/core/kernels/reduction_gpu_kernels.cu.h | 4
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 3
-rw-r--r--  tensorflow/docs_src/get_started/datasets_quickstart.md | 3
-rw-r--r--  tensorflow/docs_src/install/install_mac.md | 8
-rw-r--r--  tensorflow/python/eager/context.py | 3
-rw-r--r--  tensorflow/python/eager/core_test.py | 16
-rw-r--r--  tensorflow/python/eager/pywrap_tensor.cc | 89
-rw-r--r--  tensorflow/python/framework/test_util.py | 25
-rw-r--r--  tensorflow/python/grappler/hierarchical_controller.py | 12
-rw-r--r--  tensorflow/python/grappler/tf_optimizer.i | 1
-rwxr-xr-x  tensorflow/python/keras/BUILD | 32
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/densenet.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/inception_v3.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/mobilenet.py | 4
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/nasnet.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/resnet50.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/vgg16.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/vgg19.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/xception.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/__init__.py | 15
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/base_layer.py | 504
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/input_layer.py | 230
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/network.py (renamed from tensorflow/python/keras/_impl/keras/engine/topology.py) | 1059
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/saving.py | 671
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/saving_test.py | 375
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/sequential.py | 997
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/sequential_test.py | 152
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/topology_test.py | 169
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training.py | 4
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/advanced_activations.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/embeddings.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/local.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/merge.py | 4
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/noise.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/wrappers.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/models.py | 1325
-rw-r--r--  tensorflow/python/keras/_impl/keras/models_test.py | 348
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/generic_utils.py | 17
-rw-r--r--  tensorflow/python/kernel_tests/dynamic_partition_op_test.py | 12
-rw-r--r--  tensorflow/python/ops/array_ops.py | 12
-rw-r--r--  tensorflow/python/ops/math_grad.py | 6
-rw-r--r--  tensorflow/python/training/optimizer.py | 3
-rw-r--r--  tensorflow/python/training/saver.py | 10
-rw-r--r--  tensorflow/python/training/saver_test.py | 88
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt | 6
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt | 6
-rw-r--r--  tensorflow/workspace.bzl | 11
-rw-r--r--  third_party/cub/BUILD | 0
-rw-r--r--  third_party/cub/fix_compilation_in_clang.patch | 23
201 files changed, 5237 insertions, 3571 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 98ef6f0d0a..c27a7129fa 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -154,16 +154,22 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->t.dtype());
}
-int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
+int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
+ status->status = tensorflow::Status::OK();
+ return h->t.dims();
+}
-int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
+int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
+ TF_Status* status) {
+ status->status = tensorflow::Status::OK();
return h->t.dim_size(dim_index);
}
-const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
+const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
// TODO(apassos) this will be potentially incorrect in the distributed case as
// our local device will have a name which depends on the ClusterSpec and
// hence will require the context to resolve.
+ status->status = tensorflow::Status::OK();
return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: h->d->name().c_str();
}
@@ -297,11 +303,9 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
// Questionable heuristic ...
- //
- // Motivation: After an 'op' is placed on GPU because some of its earlier
- // inputs are on GPU, we want to keep the 'op' there, even if some later
- // inputs of it are not on GPU.
- if (IsCPU(op->device) && !IsCPU(h->d)) {
+ // - If a device was explicitly set on the op, always use that.
+ // - If not, place on the first non-host device seen.
+ if (op->device == nullptr && !IsCPU(h->d)) {
op->device = h->d;
}
if (!status->status.ok()) return;
@@ -802,6 +806,10 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+ if (ctx->log_device_placement) {
+ LOG(INFO) << "Executing op " << ndef.op() << " in device "
+ << device->name();
+ }
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
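The placement change above replaces the old "once any input is on GPU, keep the op on GPU" rule. A minimal Python sketch of the new rule (illustrative only; the helper names are hypothetical, this is not TensorFlow source):

    def choose_op_device(op_device, input_device, is_cpu):
        # An explicitly set device always wins.
        if op_device is not None:
            return op_device
        # Otherwise pin the op to the first non-host input device seen.
        if input_device is not None and not is_cpu(input_device):
            return input_device
        return op_device  # still unplaced; a default is chosen later

Under the old heuristic an op already considered CPU-placed could be silently moved to a GPU input's device; with this rule, only ops with no explicit device get moved.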
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 7a321b54da..90cfb7500e 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -119,11 +119,13 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
-TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
+TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
+ TF_Status* status);
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
- int dim_index);
+ int dim_index,
+ TF_Status* status);
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
- TFE_TensorHandle* h);
+ TFE_TensorHandle* h, TF_Status* status);
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 7b9f1db02e..3356054cd0 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -50,7 +50,9 @@ struct TFE_Context {
rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
pflr(new tensorflow::ProcessFunctionLibraryRuntime(
session->device_mgr, opts.session_options.options.env,
- TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
+ TF_GRAPH_DEF_VERSION, &func_lib_def, {})),
+ log_device_placement(
+ opts.session_options.options.config.log_device_placement()) {}
const TFE_ContextDevicePlacementPolicy policy;
@@ -88,6 +90,8 @@ struct TFE_Context {
std::atomic<bool> should_store_metadata{false};
tensorflow::mutex metadata_mu;
tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
+
+ const bool log_device_placement;
};
struct TFE_TensorHandle {
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 4a3ecbc0ab..00fb7e68d0 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -932,7 +932,8 @@ TEST(CAPI, Variables) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
- EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
+ EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status));
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float value = 0.0f;
TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -974,7 +975,8 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(1, num_retvals);
CHECK(h);
CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
- CHECK_EQ(0, TFE_TensorHandleNumDims(h));
+ CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr;
}
tensorflow::testing::StopTiming();
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index a711319607..af259e0564 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -102,12 +102,17 @@ cc_library(
cc_library(
name = "xla_interpreter_device",
srcs = ["xla_interpreter_device.cc"],
+ visibility = [":friends"],
deps = [
+ ":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
+ "//tensorflow/core:lib",
],
- alwayslink = True,
+ alwayslink = 1,
)
cc_library(
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 2614deefd8..a329451b14 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -25,8 +25,8 @@ namespace tensorflow {
const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
-constexpr std::array<DataType, 5> kExecAllTypes = {
- {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};
+constexpr std::array<DataType, 6> kExecAllTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 37ca1b893a..e6a6e54927 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -126,7 +126,9 @@ cc_library(
":bfloat16_support",
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_dce",
":hlo_pass",
+ ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4391462c1c..5ddd8ec377 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -122,6 +122,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleBitcastConvert(HloInstruction* bitcast) override;
+
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
@@ -411,6 +413,13 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
+ HloInstruction* bitcast) {
+ // Eliminate bitcast converts between same shape.
+ ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
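The new HandleBitcastConvert mirrors HandleBitcast: a bitcast-convert whose operand already has the same shape changes nothing and can be eliminated. A hedged Python sketch of the rewrite (helper names are hypothetical stand-ins for the XLA utilities):

    def handle_bitcast_convert(instr):
        operand = instr.operand(0)
        # A bitcast-convert between identical shapes is a no-op; replace
        # all uses of the convert with its operand.
        if shapes_compatible(instr.shape, operand.shape):
            replace_instruction(instr, operand)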
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 9246cb25d2..6145c690b9 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -17,8 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -229,55 +231,10 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
return true;
}
-// The algorithm first does a forward pass (parameters to root) to determine a
-// set of instructions to consider using bfloat16, then does a backward pass to
-// determine the precisions of those instructions according to the need of
-// their users.
-StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
- TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
-
+Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
+ HloModule* module) {
std::list<HloComputation*> computations_topological_order =
module->MakeComputationPostOrder();
- // The first step is a forward pass (parameters to root), where we determine
- // the potential candidate instructions to use bfloat16 in the outputs that
- // are not likely to cause overhead from extra explicit conversions. This is
- // done forwardly because we determine whether an HLO is a candidate partially
- // based on whether its operands are candidates.
- for (auto computation : computations_topological_order) {
- for (auto inst : computation->MakeInstructionPostOrder()) {
- if (InstructionIsCandidateForBF16Output(inst)) {
- consider_using_bfloat16_.insert(inst);
- }
- }
- }
-
- // The second step is a backward pass (root to parameters), where we modify
- // the precisions of the instructions identified in the first step when
- // feasible. This is done backwardly because we determine the precision of an
- // HLO's output based on how it is later used.
- //
- // The precision of an instruction is determined by its users, so we do the
- // propagation in reverse topological order.
- for (auto comp_it = computations_topological_order.rbegin();
- comp_it != computations_topological_order.rend(); ++comp_it) {
- if ((*comp_it)->IsFusionComputation()) {
- // Fusion computations are handled when visiting the fusion instruction.
- continue;
- }
- auto insts = (*comp_it)->MakeInstructionPostOrder();
- for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
- DetermineAndMutateInstructionPrecision(*inst_it,
- /*skip_parameters=*/true);
- }
- }
-
- if (!changed_) {
- return false;
- }
-
- // It's possible that an instruction does not define a buffer, but the
- // defining instruction's shape has changed. So we need to adjust the output
- // shapes of instructions according to the HLO values they refer to.
for (auto comp_it = computations_topological_order.rbegin();
comp_it != computations_topological_order.rend(); ++comp_it) {
auto insts = (*comp_it)->MakeInstructionPostOrder();
@@ -328,6 +285,162 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
}
}
}
+
+ // We could have changed a fusion computation's root shape to have a different
+ // precision than the fusion node's output, if the fusion root does not
+ // define a buffer (e.g., a tuple). Now we add conversions after such fusion
+ // roots to make them match the fusion output. If the fusion output is a
+ // (possibly nested) tuple, we first create get-tuple-elements, then convert
+ // the unmatching leaf nodes, and finally create a new tuple as the fusion
+ // computation's root. If tuples and get-tuple-elements are created, we will
+ // run tuple simplifier and dead code elimination at the end (dead code is not
+ // allowed in fusion computation). E.g.,
+ //
+ // (1) (2) (3)
+ // a b a b a b
+ // |\ | |\ | |\ |
+ // \ add -> |add -> | add
+ // \ | \ | convert |
+ // tuple tuple \ |
+ // / \ tuple
+ // gte gte
+ // | |
+ // convert |
+ // \ /
+ // tuple
+ // (1) a is F32 but tuple is BF16
+ // (2) after adding conversion
+ // (3) after tuple simplifier and DCE.
+ bool needs_tuple_simplifier = false;
+ for (auto computation : computations_topological_order) {
+ auto insts = computation->MakeInstructionPostOrder();
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ auto hlo = *inst_it;
+ if (hlo->opcode() != HloOpcode::kFusion) {
+ continue;
+ }
+ auto fusion_computation = hlo->fused_instructions_computation();
+ auto fusion_root = fusion_computation->root_instruction();
+ if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
+ continue;
+ }
+ ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
+ // Iterate through nodes in the shape tree in pre-order and initialize
+ // each non-root node with a corresponding get-tuple-element. For a leaf
+ // node, if its shape does not match the fusion output, create a
+ // conversion node to overwrite the node value.
+ for (auto it = converted_outputs.begin(); it != converted_outputs.end();
+ ++it) {
+ ShapeIndex output_index = it->first;
+ HloInstruction*& output = it->second;
+ const Shape subshape =
+ ShapeUtil::GetSubshape(hlo->shape(), output_index);
+ if (output_index.empty()) {
+ output = fusion_root;
+ } else {
+ ShapeIndex parent_index = output_index;
+ parent_index.pop_back();
+ output = fusion_computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ subshape, converted_outputs.element(parent_index),
+ output_index.back()));
+ }
+ if (ShapeUtil::IsTuple(subshape)) {
+ continue;
+ }
+ if (!ShapeUtil::Compatible(
+ subshape,
+ ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) {
+ output = fusion_computation->AddInstruction(
+ HloInstruction::CreateConvert(subshape, output));
+ }
+ }
+ // Iterate through nodes in the shape tree in reverse pre-order and create
+ // a tuple instruction for each non-leaf node where the elements are the
+ // values of its child nodes.
+ for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend();
+ ++it) {
+ ShapeIndex output_index = it->first;
+ HloInstruction*& output = it->second;
+ const Shape& subshape =
+ ShapeUtil::GetSubshape(hlo->shape(), output_index);
+ if (!ShapeUtil::IsTuple(subshape)) {
+ continue;
+ }
+ std::vector<HloInstruction*> elements(
+ ShapeUtil::TupleElementCount(subshape));
+ ShapeIndex child_index = output_index;
+ for (int64 i = 0; i < elements.size(); ++i) {
+ child_index.push_back(i);
+ elements[i] = converted_outputs.element(child_index);
+ child_index.pop_back();
+ }
+ output = fusion_computation->AddInstruction(
+ HloInstruction::CreateTuple(elements));
+ }
+ fusion_computation->set_root_instruction(converted_outputs.element({}));
+ needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
+ }
+ }
+ if (needs_tuple_simplifier) {
+ TupleSimplifier tuple_simplifier;
+ TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+ HloDCE dce;
+ TF_RETURN_IF_ERROR(dce.Run(module).status());
+ }
+ return Status::OK();
+}
+
+// The algorithm first does a forward pass (parameters to root) to determine a
+// set of instructions to consider using bfloat16, then does a backward pass to
+// determine the precisions of those instructions according to the need of
+// their users.
+StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
+
+ std::list<HloComputation*> computations_topological_order =
+ module->MakeComputationPostOrder();
+ // The first step is a forward pass (parameters to root), where we determine
+ // the potential candidate instructions to use bfloat16 in the outputs that
+ // are not likely to cause overhead from extra explicit conversions. This is
+ // done forwardly because we determine whether an HLO is a candidate partially
+ // based on whether its operands are candidates.
+ for (auto computation : computations_topological_order) {
+ for (auto inst : computation->MakeInstructionPostOrder()) {
+ if (InstructionIsCandidateForBF16Output(inst)) {
+ consider_using_bfloat16_.insert(inst);
+ }
+ }
+ }
+
+ // The second step is a backward pass (root to parameters), where we modify
+ // the precisions of the instructions identified in the first step when
+ // feasible. This is done backwardly because we determine the precision of an
+ // HLO's output based on how it is later used.
+ //
+ // The precision of an instruction is determined by its users, so we do the
+ // propagation in reverse topological order.
+ for (auto comp_it = computations_topological_order.rbegin();
+ comp_it != computations_topological_order.rend(); ++comp_it) {
+ if ((*comp_it)->IsFusionComputation()) {
+ // Fusion computations are handled when visiting the fusion instruction.
+ continue;
+ }
+ auto insts = (*comp_it)->MakeInstructionPostOrder();
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ DetermineAndMutateInstructionPrecision(*inst_it,
+ /*skip_parameters=*/true);
+ }
+ }
+
+ if (!changed_) {
+ return false;
+ }
+
+ // It's possible that an instruction does not define a buffer, but the
+ // defining instruction's shape has changed. So we need to adjust the output
+ // shapes of instructions according to the HLO values they refer to.
+ TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module));
return true;
}
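Structurally, Run() is unchanged except that the shape-fixup logic now lives in its own method, which additionally patches fusion roots and runs cleanup passes. A hedged Python outline of the overall pass (method names mirror the C++ but are illustrative, not XLA source):

    def run(module):
        order = module.make_computation_post_order()
        # Forward pass (parameters to root): collect instructions whose
        # outputs could use BF16 without extra conversion overhead.
        candidates = set()
        for comp in order:
            for inst in comp.make_instruction_post_order():
                if instruction_is_candidate_for_bf16_output(inst):
                    candidates.add(inst)
        # Backward pass (root to parameters): an instruction's precision is
        # decided by its users, so walk in reverse topological order.
        for comp in reversed(order):
            if comp.is_fusion_computation():
                continue  # handled when visiting the fusion instruction
            for inst in reversed(comp.make_instruction_post_order()):
                determine_and_mutate_instruction_precision(inst)
        # New final step: make aliasing buffers agree on precision, insert
        # converts at fusion roots, then run tuple simplifier and DCE.
        resolve_inconsistency_of_aliasing_buffers(module)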
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index aa81dde3b0..ccf77d7b4e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -94,10 +94,21 @@ class BFloat16Propagation : public HloPassInterface {
// Special handling in the mutation pass for fusion computations.
void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
+ // ***************************
+ // Functions called by the final inconsistency resolving pass.
+
+ // Adjusts the output shapes of HloInstructions such that if two
+ // HloInstructions have aliasing buffers in their outputs, they must have the
+ // same precision.
+ Status ResolveInconsistencyOfAliasingBuffers(HloModule* module);
+
// Makes the fusion parameters match the precision of the actual parameters
// passed to the fusion node.
void AdjustFusionParameters(HloInstruction* fusion);
+ // ***************************
+ // Functions called and state used by two or more passes.
+
// Returns whether all uses of the given HloInstruction can consume BF16
// input.
bool AllUsersConsumeBF16(const HloInstruction& hlo,
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 4c86c6b26e..2047e2053a 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -68,7 +68,7 @@ class BFloat16PropagationTest : public HloTestBase {
// Returns whether the given HloInstruction's output element type is BF16 or
// the only use of it is converting to BF16.
- bool OutputsBF16(HloInstruction* inst) {
+ bool OutputsBF16(const HloInstruction* inst) {
if (inst->shape().element_type() == BF16) {
return true;
}
@@ -287,6 +287,64 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
EXPECT_TRUE(OutputsBF16(b_f1));
}
+// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
+// outputs are only used by a dot, and 3) one element of the tuple is used by
+// an add in the fusion computation, then the propagation pass should create a
+// convert in the fusion computation to keep the add's operand in F32 but change
+// the fusion output to BF16. E.g., the following fusion computation
+// (F32, F32) fusion_computation(F32 a, F32 b)
+// = tuple(F32 a, F32 add(F32 a, F32 b))
+// will be changed to
+// (BF16, BF16) fusion_computation(F32 a, F32 b)
+// = tuple(BF16 convert(a), BF16 add(F32 a, F32 b))
+TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
+
+ auto builder_f = HloComputation::Builder("fusion0");
+ HloInstruction* a_f =
+ builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b_f =
+ builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* add_f = builder_f.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
+ HloInstruction* tuple_f =
+ builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f}));
+ auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
+ auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+ tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add},
+ comp_f));
+
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, fusion, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, fusion, 1));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_TRUE(OutputsBF16(gte0));
+ EXPECT_TRUE(OutputsBF16(gte1));
+ EXPECT_FALSE(OutputsBF16(a_f));
+ EXPECT_FALSE(OutputsBF16(b_f));
+ EXPECT_TRUE(OutputsBF16(add_f));
+ auto new_fusion_root = comp_f->root_instruction();
+ EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple);
+ EXPECT_EQ(new_fusion_root->operand(1), add_f);
+ EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert);
+ EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0)));
+}
+
// A select over tuples does not define the leaf buffers, so the types in
// on_true and on_false must match, so that as long as one of them is F32, the
// other must be F32 as well.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 623d6714de..04b37d913e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -46,12 +46,14 @@ namespace {
class HloExecutionProfiler {
public:
// If profiling is enabled, start an execution timer running.
- explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile,
- se::Stream* stream,
- const HloComputation* computation)
+ explicit HloExecutionProfiler(
+ bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const HloComputation* computation)
: do_profile_(do_profile),
profile_(profile),
stream_(stream),
+ sub_streams_(sub_streams),
computation_(computation) {
if (do_profile_) {
clock_rate_ghz_ =
@@ -70,6 +72,7 @@ class HloExecutionProfiler {
CHECK(!finished_execution_) << "Call FinishExecution only once!";
finished_execution_ = true;
if (do_profile_) {
+ stream_->ThenWaitFor(&sub_streams_);
stream_->ThenStopTimer(execution_timer_.get());
stream_->BlockHostUntilDone().IgnoreError();
profile_->set_total_cycles_executed(
@@ -88,6 +91,7 @@ class HloExecutionProfiler {
// that the hlo_instruction took to execute in the profile.
void FinishOperation(const HloInstruction* hlo_instruction) {
if (do_profile_) {
+ stream_->ThenWaitFor(&sub_streams_);
stream_->ThenStopTimer(per_op_timer_.get());
stream_->BlockHostUntilDone().IgnoreError();
profile_->SetCyclesTakenBy(
@@ -100,6 +104,7 @@ class HloExecutionProfiler {
double clock_rate_ghz_;
HloExecutionProfile* profile_;
se::Stream* stream_;
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
const HloComputation* computation_;
std::unique_ptr<se::Timer> execution_timer_;
std::unique_ptr<se::Timer> per_op_timer_;
@@ -147,13 +152,9 @@ Status GpuExecutable::ExecuteThunks(
LOG(WARNING) << "PROFILING: profiling is enabled";
}
- HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream,
- hlo_module_->entry_computation());
-
- uint64 start_micros = tensorflow::Env::Default()->NowMicros();
-
// Stream 0 indicates `main_stream` and substreams start from stream 1.
std::vector<Pool<se::Stream>::SmartPtr> sub_streams;
+ sub_streams.reserve(thunk_schedule_->StreamCount() - 1);
while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) {
sub_streams.emplace_back();
TF_ASSIGN_OR_RETURN(
@@ -161,6 +162,10 @@ Status GpuExecutable::ExecuteThunks(
run_options->BorrowStream(main_stream->parent()->device_ordinal()));
}
+ HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream,
+ sub_streams, hlo_module_->entry_computation());
+ uint64 start_micros = tensorflow::Env::Default()->NowMicros();
+
// The next event enqueued on stream N must not run until the thunk at
// last_blocking_thunk_for_stream[N] completes.
std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
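The profiler now receives the sub-streams so it can join them before stopping a timer; otherwise thunks still running concurrently on sub-streams would be excluded from the measured interval. An illustrative Python sketch of the corrected pattern (the stream API here is hypothetical):

    def finish_operation(main_stream, sub_streams, timer, profile, hlo):
        main_stream.wait_for(sub_streams)   # join concurrent work first...
        main_stream.stop_timer(timer)       # ...then stop the clock
        main_stream.block_host_until_done()
        profile.set_cycles_taken_by(hlo, timer.elapsed_cycles())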
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index 1e5f0f797a..fcd723af14 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -40,7 +40,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
VLOG(2) << "Before dce:";
XLA_VLOG_LINES(2, module->ToString());
- for (auto* computation : module->MakeNonfusionComputations()) {
+ for (auto* computation : module->MakeComputationPostOrder()) {
std::unordered_set<HloInstruction*> live_instructions;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
[&live_instructions](HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 0cb9b5d810..883063d0f0 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -93,7 +93,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> InterpreterExecutable::ExecuteOnStream(
TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> result,
transfer_manager->AllocateShapedBuffer(
result_literal->shape(), run_options->allocator(),
- run_options->device_ordinal()));
+ executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
executor, *result_literal, *result));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 1958e5abf6..97abf217d7 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1011,6 +1011,7 @@ xla_test(
shard_count = 40,
tags = [
"enable_for_xla_interpreter",
+ "optonly",
],
deps = [
"//tensorflow/compiler/xla:array2d",
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index d7beb26e1b..08b29fb6bc 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -39,7 +39,7 @@ py_library(
cuda_py_test(
name = "metropolis_hastings_test",
- size = "medium",
+ size = "large",
srcs = ["python/kernel_tests/metropolis_hastings_test.py"],
additional_deps = [
":bayesflow_py",
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 35dd2ee439..ed79ef70f8 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -252,6 +252,21 @@ cuda_py_test(
)
cuda_py_test(
+ name = "kumaraswamy_test",
+ srcs = ["python/kernel_tests/kumaraswamy_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "moving_stats_test",
size = "small",
srcs = ["python/kernel_tests/moving_stats_test.py"],
@@ -916,6 +931,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "kumaraswamy_bijector_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "masked_autoregressive_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/masked_autoregressive_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
new file mode 100644
index 0000000000..ad11d9f248
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Kumaraswamy Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class KumaraswamyBijectorTest(test.TestCase):
+ """Tests correctness of the Kumaraswamy bijector."""
+
+ def testBijector(self):
+ with self.test_session():
+ a = 2.
+ b = 0.3
+ bijector = Kumaraswamy(
+ concentration1=a, concentration0=b,
+ event_ndims=0, validate_args=True)
+ self.assertEqual("kumaraswamy", bijector.name)
+ x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32)
+ # Kumaraswamy cdf. This is the same as inverse(x).
+ y = 1. - (1. - x ** a) ** b
+ self.assertAllClose(y, bijector.inverse(x).eval())
+ self.assertAllClose(x, bijector.forward(y).eval())
+ kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) +
+ (b - 1) * np.log1p(-x ** a))
+
+ self.assertAllClose(
+ # We should lose a dimension from calculating the determinant of the
+ # jacobian.
+ kumaraswamy_log_pdf,
+ bijector.inverse_log_det_jacobian(x).eval())
+ self.assertAllClose(
+ -bijector.inverse_log_det_jacobian(x).eval(),
+ bijector.forward_log_det_jacobian(y).eval(),
+ rtol=1e-4,
+ atol=0.)
+
+ def testScalarCongruency(self):
+ with self.test_session():
+ assert_scalar_congruency(
+ Kumaraswamy(concentration1=0.5, concentration0=1.1),
+ lower_x=0., upper_x=1., n=int(10e3), rtol=0.02)
+
+ def testBijectiveAndFinite(self):
+ with self.test_session():
+ concentration1 = 1.2
+ concentration0 = 2.
+ bijector = Kumaraswamy(
+ concentration1=concentration1,
+ concentration0=concentration0, validate_args=True)
+ # Omitting the endpoints 0 and 1, since the ildj will be infinity at these
+ # endpoints.
+ y = np.linspace(.01, 0.99, num=10).astype(np.float32)
+ x = 1 - (1 - y ** concentration1) ** concentration0
+ assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
index ea3c86b5c0..2980e2bfe9 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
@@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase):
dist.prob([.1, .3, .6]).eval()
dist.prob([.2, .3, .5]).eval()
# Either condition can trigger.
- with self.assertRaisesOpError("sample must be positive"):
+ with self.assertRaisesOpError("sample must be non-negative"):
dist.prob([-1., 0.1, 0.5]).eval()
- with self.assertRaisesOpError("sample must be positive"):
- dist.prob([0., 0.1, 0.5]).eval()
with self.assertRaisesOpError("sample must be no larger than `1`"):
dist.prob([.1, .2, 1.2]).eval()
@@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase):
a = np.array([1., 2, 3])
b = np.array([2., 4, 1.2])
dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
+ with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."):
dist.mode().eval()
a = np.array([2., 2, 3])
b = np.array([1., 4, 1.2])
dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
+ with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."):
dist.mode().eval()
def testKumaraswamyModeEnableAllowNanStats(self):
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 93923c3f08..9437f56b1e 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -26,6 +26,7 @@
@@Identity
@@Inline
@@Invert
+@@Kumaraswamy
@@MaskedAutoregressiveFlow
@@Permute
@@PowerTransform
@@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
+from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
new file mode 100644
index 0000000000..f5de052c9e
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
@@ -0,0 +1,153 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kumaraswamy bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+__all__ = [
+ "Kumaraswamy",
+]
+
+
+class Kumaraswamy(bijector.Bijector):
+ """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`.
+
+ This bijector maps inputs from `[0, 1]` to `[0, 1]`. The inverse of the
+ bijector applied to a uniform random variable `X ~ U(0, 1)` gives back a
+ random variable with the [Kumaraswamy distribution](
+ https://en.wikipedia.org/wiki/Kumaraswamy_distribution):
+
+ ```none
+ Y ~ Kumaraswamy(a, b)
+ pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1)
+ ```
+ """
+
+ def __init__(self,
+ concentration1=None,
+ concentration0=None,
+ event_ndims=0,
+ validate_args=False,
+ name="kumaraswamy"):
+ """Instantiates the `Kumaraswamy` bijector.
+
+ Args:
+ concentration1: Python `float` scalar indicating the transform power,
+ i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is
+ `concentration1`.
+ concentration0: Python `float` scalar indicating the transform power,
+ i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is
+ `concentration0`.
+ event_ndims: Python scalar indicating the number of dimensions associated
+ with a particular draw from the distribution. Currently only zero is
+ supported.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+
+ Raises:
+ ValueError: If `event_ndims` is not zero.
+ """
+ self._graph_parents = []
+ self._name = name
+ self._validate_args = validate_args
+
+ event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ event_ndims_const = tensor_util.constant_value(event_ndims)
+ if event_ndims_const is not None and event_ndims_const not in (0,):
+ raise ValueError("event_ndims(%s) was not 0" % event_ndims_const)
+ else:
+ if validate_args:
+ event_ndims = control_flow_ops.with_dependencies(
+ [check_ops.assert_equal(
+ event_ndims, 0, message="event_ndims was not 0")],
+ event_ndims)
+
+ with self._name_scope("init", values=[concentration1, concentration0]):
+ concentration1 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration1, name="concentration1"),
+ validate_args=validate_args)
+ concentration0 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration0, name="concentration0"),
+ validate_args=validate_args)
+
+ self._concentration1 = concentration1
+ self._concentration0 = concentration0
+ super(Kumaraswamy, self).__init__(
+ event_ndims=0,
+ validate_args=validate_args,
+ name=name)
+
+ @property
+ def concentration1(self):
+ """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`."""
+ return self._concentration1
+
+ @property
+ def concentration0(self):
+ """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`."""
+ return self._concentration0
+
+ def _forward(self, x):
+ x = self._maybe_assert_valid(x)
+ return math_ops.exp(
+ math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0))
+ / self.concentration1)
+
+ def _inverse(self, y):
+ y = self._maybe_assert_valid(y)
+ return math_ops.exp(math_ops.log1p(
+ -(1 - y**self.concentration1)**self.concentration0))
+
+ def _inverse_log_det_jacobian(self, y):
+ y = self._maybe_assert_valid(y)
+ event_dims = self._event_dims_tensor(y)
+ return math_ops.reduce_sum(
+ math_ops.log(self.concentration1) + math_ops.log(self.concentration0) +
+ (self.concentration1 - 1) * math_ops.log(y) +
+ (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1),
+ axis=event_dims)
+
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of a concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ ], concentration)
+
+ def _maybe_assert_valid(self, x):
+ if not self.validate_args:
+ return x
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_non_negative(
+ x,
+ message="sample must be non-negative"),
+ check_ops.assert_less_equal(
+ x, array_ops.ones([], self.concentration0.dtype),
+ message="sample must be no larger than `1`."),
+ ], x)
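A quick NumPy check of the math above (illustrative, not part of the commit): the forward map sends a uniform variate to a Kumaraswamy variate, the inverse is the Kumaraswamy CDF, and a round trip is the identity.

    import numpy as np

    a, b = 2.0, 0.3
    x = np.linspace(0.01, 0.99, 7)                    # uniform samples
    y = (1.0 - (1.0 - x) ** (1.0 / b)) ** (1.0 / a)   # forward: Kumaraswamy
    x_back = 1.0 - (1.0 - y ** a) ** b                # inverse: the CDF
    assert np.allclose(x, x_back, atol=1e-8)

    # The inverse log-det-Jacobian at y equals the Kumaraswamy log-pdf,
    # matching _inverse_log_det_jacobian above.
    ildj = (np.log(a) + np.log(b) + (a - 1.0) * np.log(y)
            + (b - 1.0) * np.log1p(-y ** a))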
diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
index 74d5d8773c..120b38db3c 100644
--- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
+++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
@@ -20,15 +20,17 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops.distributions import beta
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.distributions import transformed_distribution
+from tensorflow.python.ops.distributions import uniform
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -60,7 +62,7 @@ def _harmonic_number(x):
@tf_export("distributions.Kumaraswamy")
-class Kumaraswamy(beta.Beta):
+class Kumaraswamy(transformed_distribution.TransformedDistribution):
"""Kumaraswamy distribution.
The Kumaraswamy distribution is defined over the `(0, 1)` interval using
@@ -151,59 +153,32 @@ class Kumaraswamy(beta.Beta):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
+ concentration1 = ops.convert_to_tensor(
+ concentration1, name="concentration1")
+ concentration0 = ops.convert_to_tensor(
+ concentration0, name="concentration0")
super(Kumaraswamy, self).__init__(
- concentration1=concentration1,
- concentration0=concentration0,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
+ distribution=uniform.Uniform(
+ low=array_ops.zeros([], dtype=concentration1.dtype),
+ high=array_ops.ones([], dtype=concentration1.dtype),
+ allow_nan_stats=allow_nan_stats),
+ bijector=bijectors.Kumaraswamy(
+ concentration1=concentration1, concentration0=concentration0,
+ validate_args=validate_args),
+ batch_shape=distribution_util.get_broadcast_shape(
+ concentration1, concentration0),
name=name)
self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
- def _sample_n(self, n, seed=None):
- expanded_concentration1 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration1
- expanded_concentration0 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration0
- shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
- uniform_sample = random_ops.random_uniform(
- shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed)
-
- kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**(
- 1. / expanded_concentration1)
- return kumaraswamy_sample
-
- @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
- def _log_cdf(self, x):
- a = self.concentration1
- b = self.concentration0
- return math_ops.log1p(-(1 - x**a)**b)
+ @property
+ def concentration1(self):
+ """Concentration parameter associated with a `1` outcome."""
+ return self.bijector.concentration1
- @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
- def _cdf(self, x):
- a = self.concentration1
- b = self.concentration0
- return 1 - (1 - x**a)**b
-
- def _survival_function(self, x):
- a = self.concentration1
- b = self.concentration0
- return (1 - x**a)**b
-
- def _log_survival_function(self, x):
- a = self.concentration1
- b = self.concentration0
- return b * math_ops.log1p(-x**a)
-
- def _log_unnormalized_prob(self, x):
- x = self._maybe_assert_valid_sample(x)
- a = self.concentration1
- b = self.concentration0
- return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a)
-
- def _log_normalization(self):
- a = self.concentration1
- b = self.concentration0
- return -(math_ops.log(a) + math_ops.log(b))
+ @property
+ def concentration0(self):
+ """Concentration parameter associated with a `0` outcome."""
+ return self.bijector.concentration0
def _entropy(self):
a = self.concentration1
@@ -213,10 +188,11 @@ class Kumaraswamy(beta.Beta):
def _moment(self, n):
"""Compute the n'th (uncentered) moment."""
+ total_concentration = self.concentration1 + self.concentration0
expanded_concentration1 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration1
+ total_concentration, dtype=self.dtype) * self.concentration1
expanded_concentration0 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration0
+ total_concentration, dtype=self.dtype) * self.concentration0
beta_arg0 = 1 + n / expanded_concentration1
beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1)
log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta(
@@ -246,13 +222,14 @@ class Kumaraswamy(beta.Beta):
name="nan")
is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.)
return array_ops.where(is_defined, mode, nan)
+
return control_flow_ops.with_dependencies([
check_ops.assert_less(
- array_ops.ones([], dtype=self.dtype),
+ array_ops.ones([], dtype=self.concentration1.dtype),
self.concentration1,
message="Mode undefined for concentration1 <= 1."),
check_ops.assert_less(
- array_ops.ones([], dtype=self.dtype),
+ array_ops.ones([], dtype=self.concentration0.dtype),
self.concentration0,
message="Mode undefined for concentration0 <= 1.")
], mode)
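[Editorial note] The `_moment` hunk above relies on the closed form `E[X**n] = b * B(1 + n/a, b)` for `Kumaraswamy(a, b)`. A scipy cross-check against inverse-CDF sampling (illustration only, not part of the patch):

```
import numpy as np
from scipy.special import beta as beta_fn

a, b = 2.0, 3.0  # concentration1, concentration0

def moment(n):
    # E[X**n] = b * B(1 + n/a, b)
    return b * beta_fn(1.0 + n / a, b)

# Inverse-CDF sampling: F(x) = 1 - (1 - x**a)**b.
u = np.random.RandomState(0).uniform(size=2000000)
samples = (1.0 - u**(1.0 / b))**(1.0 / a)

for n in (1, 2):
    assert abs((samples**n).mean() - moment(n)) < 1e-3
```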
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py
index d9648ffb03..e57093bdbc 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import weakref
@@ -26,6 +27,7 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import control_flow_ops
@@ -37,6 +39,7 @@ from tensorflow.python.training import checkpointable as core_checkpointable
from tensorflow.python.training import checkpointable_utils as core_checkpointable_utils
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import deprecation
_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names.
@@ -278,6 +281,37 @@ def _serialize_object_graph(root_checkpointable):
slot_variables=slot_variables)
+def gather_initializers(root_checkpointable):
+ """Traverse the object graph and find initialization ops.
+
+  Looks for `Checkpointable` objects which are dependencies of
+  `root_checkpointable` and which have an `initializer` property. Includes
+  initializers for slot variables only if the variable they slot for and the
+  optimizer are both dependencies of `root_checkpointable` (i.e. if they would
+  be saved with a checkpoint).
+
+  Args:
+    root_checkpointable: A `Checkpointable` object to gather initializers for.
+
+  Returns:
+    A list of initialization ops.
+ """
+ # TODO(allenl): Extract out gathering logic so the naming logic doesn't have
+ # to run.
+ checkpointable_objects, path_to_root = (
+ _breadth_first_checkpointable_traversal(root_checkpointable))
+ object_names = {
+ obj: _object_prefix_from_path(path)
+ for obj, path in path_to_root.items()}
+ node_ids = {node: node_id for node_id, node
+ in enumerate(checkpointable_objects)}
+ _serialize_slot_variables(
+ checkpointable_objects=checkpointable_objects,
+ node_ids=node_ids,
+ object_names=object_names)
+ return [c.initializer for c in checkpointable_objects
+ if hasattr(c, "initializer") and c.initializer is not None]
+
+
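[Editorial note] A short usage sketch for `gather_initializers` (illustrative; `root` and `v` are made-up names). It replaces blanket `tf.global_variables_initializer()` calls in the tests further down, initializing only variables reachable from the root object:

```
import tensorflow as tf
from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.python.training import checkpointable

root = checkpointable.Checkpointable()
root.v = checkpointable_utils.add_variable(root, name="v", initializer=42.)

with tf.Session() as session:
  # Only initializers for dependencies of `root` are run.
  session.run(checkpointable_utils.gather_initializers(root))
  print(session.run(root.v))  # 42.0
```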
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
def __init__(self, tensor, name):
@@ -288,7 +322,26 @@ class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
return control_flow_ops.no_op()
-class CheckpointLoadStatus(object):
+class _LoadStatus(object):
+ """Abstract base for load status callbacks."""
+
+ @abc.abstractmethod
+ def assert_consumed(self):
+ """Raises an exception unless a non-trivial restoration has completed."""
+ pass
+
+ @abc.abstractmethod
+ def run_restore_ops(self, session=None):
+ """Runs restore ops from the checkpoint. Requires a valid checkpoint."""
+ pass
+
+ @abc.abstractmethod
+ def initialize_or_restore(self, session=None):
+ """Runs restore ops from the checkpoint, or initializes variables."""
+ pass
+
+
+class CheckpointLoadStatus(_LoadStatus):
"""Checks the status of checkpoint loading and manages restore ops.
Returned from `Saver.restore`. Since `restore` may defer the loading of values
@@ -348,6 +401,105 @@ class CheckpointLoadStatus(object):
session = ops.get_default_session()
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
+ def initialize_or_restore(self, session=None):
+ """Alias for `run_restore_ops`.
+
+ This method has a sibling in `InitializationOnlyStatus` which instead
+ initializes variables. That type is returned if no checkpoint is specified
+ in `Saver.restore`.
+
+ Args:
+ session: The session to run restore ops in. If `None`, uses the default
+ session.
+ """
+ self.run_restore_ops(session=session)
+
+
+class InitializationOnlyStatus(_LoadStatus):
+ """Returned from `Saver.restore` when no checkpoint has been specified.
+
+  Objects of this type have the same `assert_consumed` method as
+  `CheckpointLoadStatus`, but it always fails. However,
+  `initialize_or_restore` works on objects of both types: it initializes
+  variables when called on an `InitializationOnlyStatus` and restores them
+  when called on a `CheckpointLoadStatus`.
+  """
+
+ def __init__(self, root_checkpointable):
+ self._root_checkpointable = root_checkpointable
+
+ def assert_consumed(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "No checkpoint specified (save_path=None); nothing is being restored.")
+
+ def run_restore_ops(self, session=None):
+    """For consistency with `CheckpointLoadStatus`; always raises.
+
+    Use `initialize_or_restore` instead, which initializes variables if no
+    checkpoint was passed to `Saver.restore` and restores them otherwise.
+
+    Args:
+      session: Not used.
+    """
+ raise AssertionError(
+ "No checkpoint specified, so no restore ops are available "
+ "(save_path=None to Saver.restore).")
+
+ def initialize_or_restore(self, session=None):
+ """Runs initialization ops for variables.
+
+ Only objects which would be saved by `Saver.save` will be initialized. See
+ `gather_initializers` for details.
+
+    This method does nothing when executing eagerly, since initializers run
+    as soon as variables are created in that mode.
+
+ Args:
+ session: The session to run initialization ops in. If `None`, uses the
+ default session.
+ """
+    if context.in_eager_mode():
+      return  # Initializers have already run when executing eagerly.
+ if session is None:
+ session = ops.get_default_session()
+ session.run(gather_initializers(self._root_checkpointable))
+
+
+_DEPRECATED_RESTORE_INSTRUCTIONS = (
+ "Restoring a name-based tf.train.Saver checkpoint using the object-based "
+ "restore API. This mode uses global names to match variables, and so is "
+ "somewhat fragile. It also adds new restore ops to the graph each time it "
+ "is called. Prefer re-encoding training checkpoints in the object-based "
+ "format: run save() on the object-based saver (the same one this message "
+ "is coming from) and use that checkpoint in the future.")
+
+
+class NameBasedSaverStatus(_LoadStatus):
+ """Status for loading a name-based training checkpoint."""
+
+ def __init__(self, object_saver, save_path):
+ self._object_saver = object_saver
+ self._save_path = save_path
+
+ def assert_consumed(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "Restoring a name-based checkpoint. No load status is available.")
+
+ @deprecation.deprecated(
+ date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
+ def run_restore_ops(self, session=None):
+ """Load the name-based training checkpoint using a new `tf.train.Saver`."""
+ if session is None and context.in_graph_mode():
+ session = ops.get_default_session()
+ saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
+ sess=session, save_path=self._save_path)
+
+ def initialize_or_restore(self, session=None):
+ """Alias for `run_restore_ops`."""
+ self.run_restore_ops(session=session)
+
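[Editorial note] Taken together, the three `_LoadStatus` implementations let callers write one restore path that covers first runs, object-based checkpoints, and legacy name-based checkpoints. A hedged sketch of the intended pattern (the path and the `root.v` variable are illustrative):

```
import tensorflow as tf
from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.python.training import checkpointable

root = checkpointable.Checkpointable()
root.v = checkpointable_utils.add_variable(root, name="v", initializer=0.)
saver = checkpointable_utils.Saver(root)

with tf.Session() as session:
  # First run: latest_checkpoint is None -> InitializationOnlyStatus.
  # Later runs: CheckpointLoadStatus (object-based checkpoint) or
  # NameBasedSaverStatus (legacy tf.train.Saver checkpoint).
  status = saver.restore(tf.train.latest_checkpoint("/tmp/ckpts"))
  status.initialize_or_restore(session=session)
  saver.save(file_prefix="/tmp/ckpts/ckpt", session=session)
```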
class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
"""Pretends to be a session, inserts extra feeds on run()."""
@@ -429,7 +581,7 @@ class Saver(object):
Args:
file_prefix: A prefix to use for the checkpoint filenames
(/path/to/directory/and_a_prefix). Names are generated based on this
- prefix and the global step, if provided.
+ prefix and `checkpoint_number`, if provided.
checkpoint_number: An integer variable or Tensor, used to number
checkpoints. Typically this value is saved along with other variables in
training checkpoints, which will happen automatically if it was created
@@ -483,6 +635,17 @@ class Saver(object):
global_step=checkpoint_number)
return save_path
+ def _global_variable_names(self):
+ """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
+ named_saveables, graph_proto = _serialize_object_graph(
+ self._root_checkpointable)
+ saver_names = {}
+ for object_proto in graph_proto.nodes:
+ for attribute_proto in object_proto.attributes:
+ saver_names[attribute_proto.full_name] = named_saveables[
+ attribute_proto.checkpoint_key]
+ return saver_names
+
def restore(self, save_path, session=None):
"""Restore a training checkpoint.
@@ -518,20 +681,35 @@ class Saver(object):
If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph.
+ Name-based `tf.train.Saver` checkpoints can be loaded using this
+ method. There is no deferred loading, and names are used to match
+    variables. No restore ops are created/run until `run_restore_ops()` or
+    `initialize_or_restore()` is called on the returned status object, even
+ when executing eagerly. Re-encode name-based checkpoints using this
+ object-based `Saver.save` as soon as possible.
+
Args:
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
- checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
+ checkpoint for `tf.train.latest_checkpoint` to return), returns an
+ object which may run initializers for objects in the dependency
+ graph. If the checkpoint was written by the name-based `tf.train.Saver`,
+ names are used to match variables.
session: The session to retrieve metadata with. Ignored when executing
eagerly. If not provided when graph building, the default session is
used.
Returns:
- A `CheckpointLoadStatus` object, which can be used to make assertions
- about the status of checkpoint restoration and run restore ops.
+ A load status object, which can be used to make assertions about the
+ status of checkpoint restoration and run initialization/restore ops
+ (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
+ `save_path` is `None`).
+
+ If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
+ object is returned which runs restore ops from a name-based saver.
"""
if save_path is None:
- return
+ return InitializationOnlyStatus(self._root_checkpointable)
in_graph_mode = context.in_graph_mode()
if in_graph_mode:
if session is None:
@@ -542,21 +720,27 @@ class Saver(object):
session = None
file_prefix_tensor = constant_op.constant(save_path)
file_prefix_feed_dict = None
- if not in_graph_mode or self._object_graph_restore_tensor is None:
- object_graph_string, = io_ops.restore_v2(
- prefix=file_prefix_tensor,
- tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
- shape_and_slices=[""],
- dtypes=[dtypes.string],
- name="object_graph_proto_read")
+ try:
+ if not in_graph_mode or self._object_graph_restore_tensor is None:
+ object_graph_string, = io_ops.restore_v2(
+ prefix=file_prefix_tensor,
+ tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
+ shape_and_slices=[""],
+ dtypes=[dtypes.string],
+ name="object_graph_proto_read")
+ if in_graph_mode:
+ self._object_graph_restore_tensor = object_graph_string
if in_graph_mode:
- self._object_graph_restore_tensor = object_graph_string
- if in_graph_mode:
- object_graph_string = session.run(
- self._object_graph_restore_tensor,
- feed_dict=file_prefix_feed_dict)
- else:
- object_graph_string = object_graph_string.numpy()
+ object_graph_string = session.run(
+ self._object_graph_restore_tensor,
+ feed_dict=file_prefix_feed_dict)
+ else:
+ object_graph_string = object_graph_string.numpy()
+    except errors_impl.NotFoundError:
+      # The object graph proto does not exist in this checkpoint. Fall back to
+      # name-based variable matching.
+      return NameBasedSaverStatus(self, save_path)
+
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
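[Editorial note] The `NotFoundError` fallback above keys off whether the checkpoint contains the serialized object graph. An illustrative probe from outside the class (this assumes `_OBJECT_GRAPH_PROTO_KEY` is reachable as a module attribute; treat the key as an implementation detail):

```
import tensorflow as tf
from tensorflow.contrib.eager.python import checkpointable_utils

def is_object_based_checkpoint(save_path):
  # Object-based checkpoints carry the object graph proto under a reserved
  # key; name-based tf.train.Saver checkpoints do not, which is what
  # triggers the NotFoundError fallback in restore().
  reader = tf.train.NewCheckpointReader(save_path)
  return reader.has_tensor(checkpointable_utils._OBJECT_GRAPH_PROTO_KEY)
```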
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index b7554defde..3d6a200276 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
from tensorflow.python.training import adam
from tensorflow.python.training import checkpointable
from tensorflow.python.training import saver as core_saver
@@ -140,7 +139,7 @@ class Checkpoint(checkpointable.Checkpointable):
super(Checkpoint, self).__init__()
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
setattr(self, k, v)
- self._save_counter = None
+ self._save_counter = None # Created lazily for restore-on-create.
self._saver = checkpointable_utils.Saver(weakref.ref(self))
@property
@@ -170,8 +169,12 @@ class Checkpoint(checkpointable.Checkpointable):
session=session)
def restore(self, save_path):
- return self._saver.restore(
- save_path=save_path)
+ status = self._saver.restore(save_path=save_path)
+ # Create the save counter now so it gets initialized with other variables
+ # when graph building. Creating it earlier would lead to double
+ # initialization when executing eagerly.
+ self.save_counter # pylint: disable=pointless-statement
+ return status
class InterfaceTests(test.TestCase):
@@ -206,8 +209,7 @@ class InterfaceTests(test.TestCase):
with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
checkpointable_utils.add_variable(obj, name="duplicate", shape=[])
- if context.in_graph_mode():
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(obj))
self.assertEqual("constant_initializer:0", constant_initializer.name)
self.assertEqual(1, self.evaluate(constant_initializer))
self.assertEqual("some_variable_scope/ones_initializer:0",
@@ -287,7 +289,8 @@ class CheckpointingTests(test.TestCase):
optimizer.minimize(
other_network(input_value),
global_step=optimizer_step)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(
+ root_checkpointable))
self.evaluate(train_op)
named_variables, serialized_graph = (
checkpointable_utils._serialize_object_graph(root_checkpointable))
@@ -385,7 +388,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(network(input_value))
# TODO(allenl): Make initialization more pleasant when graph building.
root_checkpointable.save_counter # pylint: disable=pointless-statement
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(
+ root_checkpointable))
self.evaluate(train_op)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.]))
@@ -429,6 +433,7 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
+ # TODO(allenl): Debug garbage created by this test in python3.
def testDeferredRestorationUsageEager(self):
"""An idiomatic eager execution example."""
num_training_steps = 10
@@ -468,28 +473,57 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
network(input_value),
global_step=root.global_step)
- root.save_counter # pylint: disable=pointless-statement
- init_op = variables.global_variables_initializer()
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
+ status = root.restore(save_path=checkpoint_path)
+ status.initialize_or_restore(session=session)
if checkpoint_path is None:
self.assertEqual(0, training_continuation)
- session.run(init_op)
- # Another alternative would be to run initializers automatically
- # if no checkpoint is being loaded. This would make deferred
- # loading a bit more useful with graph execution.
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
else:
- status = root.restore(save_path=checkpoint_path).assert_consumed()
- status.run_restore_ops()
+ status.assert_consumed()
for _ in range(num_training_steps):
session.run(train_op)
- root.save(file_prefix=checkpoint_prefix,
- session=session)
+ root.save(file_prefix=checkpoint_prefix, session=session)
self.assertEqual((training_continuation + 1) * num_training_steps,
session.run(root.global_step))
self.assertEqual(training_continuation + 1,
session.run(root.save_counter))
+ @test_util.run_in_graph_and_eager_modes()
+ def testAgnosticUsage(self):
+ """Graph/eager agnostic usage."""
+ # Does create garbage when executing eagerly due to ops.Graph() creation.
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()):
+ network = MyNetwork()
+ optimizer = CheckpointableAdam(0.001)
+ root = Checkpoint(
+ optimizer=optimizer, network=network,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(network, input_value),
+ global_step=root.global_step)
+ if context.in_graph_mode():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+
def _get_checkpoint_name(self, name):
root = checkpointable.Checkpointable()
checkpointable_utils.add_variable(
@@ -602,7 +636,11 @@ class CheckpointingTests(test.TestCase):
optimizer = CheckpointableAdam(0.1)
if context.in_graph_mode():
train_op = optimizer.minimize(root.var)
- self.evaluate(variables.global_variables_initializer())
+ # Note that `optimizer` has not been added as a dependency of
+ # `root`. Create a one-off grouping so that slot variables for `root.var`
+ # get initialized too.
+ self.evaluate(checkpointable_utils.gather_initializers(
+ Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
else:
optimizer.minimize(root.var.read_value)
@@ -709,7 +747,7 @@ class CheckpointingTests(test.TestCase):
save_root.dep_one.dep_three = dep_three
save_root.dep_two.dep_three = dep_three
checkpointable_utils.add_variable(dep_three, name="var", initializer=0.)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(save_root))
save_path = checkpointable_utils.Saver(save_root).save(
os.path.join(checkpoint_directory, "ckpt"))
load_root = checkpointable.Checkpointable()
@@ -732,7 +770,7 @@ class CheckpointingTests(test.TestCase):
save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
checkpointable_utils.add_variable(
save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(save_root))
save_path = checkpointable_utils.Saver(save_root).save(
os.path.join(checkpoint_directory, "ckpt"))
load_root = checkpointable.Checkpointable()
@@ -760,7 +798,7 @@ class CheckpointingTests(test.TestCase):
first, "v1", initializer=[3., 1., 4.])
second.v = checkpointable_utils.add_variable(
second, "v2", initializer=[1., 1., 2., 3.])
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(first))
checkpoint_directory = self.get_temp_dir()
save_path = checkpointable_utils.Saver(first).save(
os.path.join(checkpoint_directory, "ckpt"))
@@ -835,7 +873,7 @@ class CheckpointingTests(test.TestCase):
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = CheckpointableAdam(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(obj))
saver = checkpointable_utils.Saver(obj)
saver.save(checkpoint_prefix)
before_ops = graph.get_operations()
@@ -853,7 +891,7 @@ class CheckpointingTests(test.TestCase):
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = CheckpointableAdam(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(checkpointable_utils.gather_initializers(obj))
saver = checkpointable_utils.Saver(obj)
save_path = saver.save(checkpoint_prefix)
saver.restore(save_path)
@@ -861,5 +899,115 @@ class CheckpointingTests(test.TestCase):
saver.restore(save_path)
self.assertEqual(before_ops, graph.get_operations())
+
+class CheckpointCompatibilityTests(test.TestCase):
+
+ def _initialized_model(self):
+ input_value = constant_op.constant([[3.]])
+ network = MyNetwork()
+ optimizer = CheckpointableAdam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = Checkpoint(
+ optimizer=optimizer, network=network, optimizer_step=optimizer_step)
+ train_op = optimizer.minimize(
+ functools.partial(network, input_value),
+ global_step=optimizer_step)
+ self.evaluate(checkpointable_utils.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ # A regular variable, a slot variable, and a non-slot Optimizer variable
+ # with known values to check when loading.
+ self.evaluate(network._named_dense.bias.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=network._named_dense.bias, name="m").assign([2.]))
+ beta1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta1_power.assign(3.))
+ return root_checkpointable
+
+ def _set_sentinels(self, root_checkpointable):
+ self.evaluate(root_checkpointable.network._named_dense.bias.assign([101.]))
+ self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.network._named_dense.bias, name="m")
+ .assign([102.]))
+ beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta1_power.assign(103.))
+
+ def _check_sentinels(self, root_checkpointable):
+ self.assertAllEqual(
+ [1.], self.evaluate(root_checkpointable.network._named_dense.bias))
+ self.assertAllEqual([2.], self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.network._named_dense.bias, name="m")))
+ beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta1_power))
+
+ def _write_name_based_checkpoint(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ name_saver = core_saver.Saver()
+ return name_saver.save(
+ sess=session, save_path=checkpoint_prefix,
+ global_step=root.optimizer_step)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testLoadFromNameBasedSaver(self):
+ """Save a name-based checkpoint, load it using the object-based API."""
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = checkpointable_utils.Saver(root)
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status.initialize_or_restore()
+ self._check_sentinels(root)
+
+ # TODO(allenl): Test for the core name-based saver loading object-based
+ # checkpoints once object-based checkpointing is in core.
+
+ def testSaveGraphLoadEager(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ object_saver = checkpointable_utils.Saver(root)
+ save_path = object_saver.save(
+ session=session, file_prefix=checkpoint_prefix)
+ with context.eager_mode():
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed()
+ self._check_sentinels(root)
+
+ def testSaveEagerLoadGraph(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.eager_mode():
+ root = self._initialized_model()
+ object_saver = checkpointable_utils.Saver(root)
+ save_path = object_saver.save(file_prefix=checkpoint_prefix)
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph):
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed().run_restore_ops()
+ self._check_sentinels(root)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index b9ac79f46c..5f51d52622 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -35,7 +35,7 @@ from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None
-class Discriminator(tfe.Network):
+class Discriminator(tf.keras.Model):
"""GAN Discriminator.
A network to differentiate between generated and real handwritten digits.
@@ -56,19 +56,15 @@ class Discriminator(tfe.Network):
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
- self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME',
- data_format=data_format,
- activation=tf.tanh))
- self.pool1 = self.track_layer(
- tf.layers.AveragePooling2D(2, 2, data_format=data_format))
- self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5,
- data_format=data_format,
- activation=tf.tanh))
- self.pool2 = self.track_layer(
- tf.layers.AveragePooling2D(2, 2, data_format=data_format))
- self.flatten = self.track_layer(tf.layers.Flatten())
- self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh))
- self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None))
+ self.conv1 = tf.layers.Conv2D(
+ 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
+ self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.conv2 = tf.layers.Conv2D(
+ 128, 5, data_format=data_format, activation=tf.tanh)
+ self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.flatten = tf.layers.Flatten()
+ self.fc1 = tf.layers.Dense(1024, activation=tf.tanh)
+ self.fc2 = tf.layers.Dense(1, activation=None)
def call(self, inputs):
"""Return two logits per image estimating input authenticity.
@@ -95,7 +91,7 @@ class Discriminator(tfe.Network):
return x
-class Generator(tfe.Network):
+class Generator(tf.keras.Model):
"""Generator of handwritten digits similar to the ones in the MNIST dataset.
"""
@@ -116,18 +112,17 @@ class Generator(tfe.Network):
else:
assert data_format == 'channels_last'
self._pre_conv_shape = [-1, 6, 6, 128]
- self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128,
- activation=tf.tanh))
+ self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh)
# In call(), we reshape the output of fc1 to _pre_conv_shape
# Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
- self.conv1 = self.track_layer(tf.layers.Conv2DTranspose(
- 64, 4, strides=2, activation=None, data_format=data_format))
+ self.conv1 = tf.layers.Conv2DTranspose(
+ 64, 4, strides=2, activation=None, data_format=data_format)
# Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
- self.conv2 = self.track_layer(tf.layers.Conv2DTranspose(
- 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format))
+ self.conv2 = tf.layers.Conv2DTranspose(
+ 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)
def call(self, inputs):
"""Return a batch of generated images.
@@ -168,7 +163,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
"""
loss_on_real = tf.losses.sigmoid_cross_entropy(
- tf.ones_like(discriminator_real_outputs), discriminator_real_outputs,
+ tf.ones_like(discriminator_real_outputs),
+ discriminator_real_outputs,
label_smoothing=0.25)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
@@ -198,9 +194,8 @@ def generator_loss(discriminator_gen_outputs):
return loss
-def train_one_epoch(generator, discriminator,
- generator_optimizer, discriminator_optimizer,
- dataset, log_interval, noise_dim):
+def train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer, dataset, log_interval, noise_dim):
"""Trains `generator` and `discriminator` models on `dataset`.
Args:
@@ -222,14 +217,18 @@ def train_one_epoch(generator, discriminator,
with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
current_batch_size = images.shape[0]
- noise = tf.random_uniform(shape=[current_batch_size, noise_dim],
- minval=-1., maxval=1., seed=batch_index)
+ noise = tf.random_uniform(
+ shape=[current_batch_size, noise_dim],
+ minval=-1.,
+ maxval=1.,
+ seed=batch_index)
with tfe.GradientTape(persistent=True) as g:
generated_images = generator(noise)
- tf.contrib.summary.image('generated_images',
- tf.reshape(generated_images, [-1, 28, 28, 1]),
- max_images=10)
+ tf.contrib.summary.image(
+ 'generated_images',
+ tf.reshape(generated_images, [-1, 28, 28, 1]),
+ max_images=10)
discriminator_gen_outputs = discriminator(generated_images)
discriminator_real_outputs = discriminator(images)
@@ -245,17 +244,17 @@ def train_one_epoch(generator, discriminator,
discriminator.variables)
with tf.variable_scope('generator'):
- generator_optimizer.apply_gradients(zip(generator_grad,
- generator.variables))
+ generator_optimizer.apply_gradients(
+ zip(generator_grad, generator.variables))
with tf.variable_scope('discriminator'):
- discriminator_optimizer.apply_gradients(zip(discriminator_grad,
- discriminator.variables))
+ discriminator_optimizer.apply_gradients(
+ zip(discriminator_grad, discriminator.variables))
if log_interval and batch_index > 0 and batch_index % log_interval == 0:
print('Batch #%d\tAverage Generator Loss: %.6f\t'
- 'Average Discriminator Loss: %.6f' % (
- batch_index, total_generator_loss/batch_index,
- total_discriminator_loss/batch_index))
+ 'Average Discriminator Loss: %.6f' %
+ (batch_index, total_generator_loss / batch_index,
+ total_discriminator_loss / batch_index))
def main(_):
@@ -266,10 +265,9 @@ def main(_):
# Load the datasets
data = input_data.read_data_sets(FLAGS.data_dir)
- dataset = (tf.data.Dataset
- .from_tensor_slices(data.train.images)
- .shuffle(60000)
- .batch(FLAGS.batch_size))
+ dataset = (
+ tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000)
+ .batch(FLAGS.batch_size))
# Create the models and optimizers
generator = Generator(data_format)
@@ -294,20 +292,17 @@ def main(_):
start = time.time()
with summary_writer.as_default():
train_one_epoch(generator, discriminator, generator_optimizer,
- discriminator_optimizer,
- dataset, FLAGS.log_interval, FLAGS.noise)
+ discriminator_optimizer, dataset, FLAGS.log_interval,
+ FLAGS.noise)
end = time.time()
- print('\nTrain time for epoch #%d (global step %d): %f' % (
- epoch, global_step.numpy(), end - start))
+ print('\nTrain time for epoch #%d (global step %d): %f' %
+ (epoch, global_step.numpy(), end - start))
all_variables = (
- generator.variables
- + discriminator.variables
- + generator_optimizer.variables()
- + discriminator_optimizer.variables()
- + [global_step])
- tfe.Saver(all_variables).save(
- checkpoint_prefix, global_step=global_step)
+ generator.variables + discriminator.variables +
+ generator_optimizer.variables() +
+ discriminator_optimizer.variables() + [global_step])
+ tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
if __name__ == '__main__':
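[Editorial note] The `tfe.Network`/`track_layer` to `tf.keras.Model` migration above works because a Keras `Model` automatically tracks layers assigned as attributes. A minimal sketch of the pattern (the class name is illustrative):

```
import tensorflow as tf

class TwoLayerNet(tf.keras.Model):

  def __init__(self):
    super(TwoLayerNet, self).__init__()
    # No track_layer() needed: attribute assignment registers the layers.
    self.dense1 = tf.layers.Dense(16, activation=tf.nn.relu)
    self.dense2 = tf.layers.Dense(1)

  def call(self, inputs):
    return self.dense2(self.dense1(inputs))

model = TwoLayerNet()
outputs = model(tf.zeros([2, 8]))  # Builds the layers on first call.
print(len(model.variables))        # 4: two kernels, two biases.
```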
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 6ce4de6ee0..157a6360ea 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -33,23 +33,13 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
-class LinearModel(tfe.Network):
- """A TensorFlow linear regression model.
-
- Uses TensorFlow's eager execution.
-
- For those familiar with TensorFlow graphs, notice the absence of
- `tf.Session`. The `forward()` method here immediately executes and
- returns output values. The `loss()` method immediately compares the
- output of `forward()` with the target and returns the MSE loss value.
- The `fit()` performs gradient-descent training on the model's weights
- and bias.
- """
+class LinearModel(tf.keras.Model):
+ """A TensorFlow linear regression model."""
def __init__(self):
"""Constructs a LinearModel object."""
super(LinearModel, self).__init__()
- self._hidden_layer = self.track_layer(tf.layers.Dense(1))
+ self._hidden_layer = tf.layers.Dense(1)
def call(self, xs):
"""Invoke the linear model.
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index 9982fdb07e..6b59413141 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -27,10 +27,9 @@ from __future__ import print_function
import functools
import tensorflow as tf
-import tensorflow.contrib.eager as tfe
-class _IdentityBlock(tfe.Network):
+class _IdentityBlock(tf.keras.Model):
"""_IdentityBlock is the block that has no conv layer at shortcut.
Args:
@@ -50,31 +49,24 @@ class _IdentityBlock(tfe.Network):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = self.track_layer(
- tf.layers.Conv2D(
- filters1, (1, 1),
- name=conv_name_base + '2a',
- data_format=data_format))
- self.bn2a = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
-
- self.conv2b = self.track_layer(
- tf.layers.Conv2D(
- filters2,
- kernel_size,
- padding='same',
- data_format=data_format,
- name=conv_name_base + '2b'))
- self.bn2b = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
-
- self.conv2c = self.track_layer(
- tf.layers.Conv2D(
- filters3, (1, 1),
- name=conv_name_base + '2c',
- data_format=data_format))
- self.bn2c = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
+ self.conv2a = tf.layers.Conv2D(
+ filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
+ self.bn2a = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2a')
+
+ self.conv2b = tf.layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ data_format=data_format,
+ name=conv_name_base + '2b')
+ self.bn2b = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2b')
+
+ self.conv2c = tf.layers.Conv2D(
+ filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+ self.bn2c = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2c')
def call(self, input_tensor, training=False):
x = self.conv2a(input_tensor)
@@ -92,7 +84,7 @@ class _IdentityBlock(tfe.Network):
return tf.nn.relu(x)
-class _ConvBlock(tfe.Network):
+class _ConvBlock(tf.keras.Model):
"""_ConvBlock is the block that has a conv layer at shortcut.
Args:
@@ -121,41 +113,35 @@ class _ConvBlock(tfe.Network):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = self.track_layer(
- tf.layers.Conv2D(
- filters1, (1, 1),
- strides=strides,
- name=conv_name_base + '2a',
- data_format=data_format))
- self.bn2a = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
-
- self.conv2b = self.track_layer(
- tf.layers.Conv2D(
- filters2,
- kernel_size,
- padding='same',
- name=conv_name_base + '2b',
- data_format=data_format))
- self.bn2b = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
-
- self.conv2c = self.track_layer(
- tf.layers.Conv2D(
- filters3, (1, 1),
- name=conv_name_base + '2c',
- data_format=data_format))
- self.bn2c = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
-
- self.conv_shortcut = self.track_layer(
- tf.layers.Conv2D(
- filters3, (1, 1),
- strides=strides,
- name=conv_name_base + '1',
- data_format=data_format))
- self.bn_shortcut = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1'))
+ self.conv2a = tf.layers.Conv2D(
+ filters1, (1, 1),
+ strides=strides,
+ name=conv_name_base + '2a',
+ data_format=data_format)
+ self.bn2a = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2a')
+
+ self.conv2b = tf.layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ name=conv_name_base + '2b',
+ data_format=data_format)
+ self.bn2b = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2b')
+
+ self.conv2c = tf.layers.Conv2D(
+ filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+ self.bn2c = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '2c')
+
+ self.conv_shortcut = tf.layers.Conv2D(
+ filters3, (1, 1),
+ strides=strides,
+ name=conv_name_base + '1',
+ data_format=data_format)
+ self.bn_shortcut = tf.layers.BatchNormalization(
+ axis=bn_axis, name=bn_name_base + '1')
def call(self, input_tensor, training=False):
x = self.conv2a(input_tensor)
@@ -176,7 +162,8 @@ class _ConvBlock(tfe.Network):
return tf.nn.relu(x)
-class ResNet50(tfe.Network):
+# pylint: disable=not-callable
+class ResNet50(tf.keras.Model):
"""Instantiates the ResNet50 architecture.
Args:
@@ -220,32 +207,28 @@ class ResNet50(tfe.Network):
self.include_top = include_top
def conv_block(filters, stage, block, strides=(2, 2)):
- l = _ConvBlock(
+ return _ConvBlock(
3,
filters,
stage=stage,
block=block,
data_format=data_format,
strides=strides)
- return self.track_layer(l)
def id_block(filters, stage, block):
- l = _IdentityBlock(
+ return _IdentityBlock(
3, filters, stage=stage, block=block, data_format=data_format)
- return self.track_layer(l)
-
- self.conv1 = self.track_layer(
- tf.layers.Conv2D(
- 64, (7, 7),
- strides=(2, 2),
- data_format=data_format,
- padding='same',
- name='conv1'))
+
+ self.conv1 = tf.layers.Conv2D(
+ 64, (7, 7),
+ strides=(2, 2),
+ data_format=data_format,
+ padding='same',
+ name='conv1')
bn_axis = 1 if data_format == 'channels_first' else 3
- self.bn_conv1 = self.track_layer(
- tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1'))
- self.max_pool = self.track_layer(
- tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format))
+ self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+ self.max_pool = tf.layers.MaxPooling2D(
+ (3, 3), strides=(2, 2), data_format=data_format)
self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
self.l2b = id_block([64, 64, 256], stage=2, block='b')
@@ -267,13 +250,11 @@ class ResNet50(tfe.Network):
self.l5b = id_block([512, 512, 2048], stage=5, block='b')
self.l5c = id_block([512, 512, 2048], stage=5, block='c')
- self.avg_pool = self.track_layer(
- tf.layers.AveragePooling2D(
- (7, 7), strides=(7, 7), data_format=data_format))
+ self.avg_pool = tf.layers.AveragePooling2D(
+ (7, 7), strides=(7, 7), data_format=data_format)
if self.include_top:
- self.fc1000 = self.track_layer(
- tf.layers.Dense(classes, name='fc1000'))
+ self.fc1000 = tf.layers.Dense(classes, name='fc1000')
else:
reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
reduction_indices = tf.constant(reduction_indices)
@@ -288,7 +269,7 @@ class ResNet50(tfe.Network):
else:
self.global_pooling = None
- def call(self, input_tensor, training=False):
+ def call(self, input_tensor, training):
x = self.conv1(input_tensor)
x = self.bn_conv1(x, training=training)
x = tf.nn.relu(x)
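[Editorial note] Making `training` a required argument of `ResNet50.call` (rather than defaulting it) forces callers to choose batch-norm behavior explicitly, which is why the tests below now pass `training=False`. A hedged usage sketch (the import path mirrors the test files):

```
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50

model = resnet50.ResNet50(data_format='channels_last')
images = tf.zeros([2, 224, 224, 3])

logits = model(images, training=False)  # BN uses moving statistics.
# model(images) would now raise a TypeError: `training` must be supplied.
```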
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
index 23317886e7..551c76b0df 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
@@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase):
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
- predictions = model(images)
+ predictions = model(images, training=False)
init = tf.global_variables_initializer()
@@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
- predictions = model(images)
+ predictions = model(images, training=False)
init = tf.global_variables_initializer()
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 0ff8746884..c106ab0a06 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -71,7 +71,7 @@ class ResNet50Test(tf.test.TestCase):
model.call = tfe.defun(model.call)
with tf.device(device):
images, _ = random_batch(2)
- output = model(images)
+ output = model(images, training=False)
self.assertEqual((2, 1000), output.shape)
def test_apply(self):
@@ -85,7 +85,7 @@ class ResNet50Test(tf.test.TestCase):
model = resnet50.ResNet50(data_format, include_top=False)
with tf.device(device):
images, _ = random_batch(2)
- output = model(images)
+ output = model(images, training=False)
output_shape = ((2, 2048, 1, 1)
if data_format == 'channels_first' else (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
@@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase):
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
with tf.device(device):
images, _ = random_batch(2)
- output = model(images)
+ output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
def test_train(self):
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index c861cfff54..7319eaa7de 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -61,8 +61,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook):
loss = run_values.results
assert loss is not None
if self._prev_loss:
- relative_change = (abs(loss - self._prev_loss) /
- (1 + abs(self._prev_loss)))
+ relative_change = (
+ abs(loss - self._prev_loss) / (1 + abs(self._prev_loss)))
if relative_change < self._tolerance:
run_context.request_stop()
self._prev_loss = loss
@@ -233,7 +233,57 @@ class _ModelFn(object):
# TODO(agarwal,ands): support sharded input.
class KMeansClustering(estimator.Estimator):
- """An Estimator for K-Means clustering."""
+ """An Estimator for K-Means clustering.
+
+ Example:
+ ```
+ import numpy as np
+ import tensorflow as tf
+
+ num_points = 100
+ dimensions = 2
+ points = np.random.uniform(0, 1000, [num_points, dimensions])
+
+ def input_fn():
+ return tf.train.limit_epochs(
+ tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)
+
+ num_clusters = 5
+ kmeans = tf.contrib.factorization.KMeansClustering(
+ num_clusters=num_clusters, use_mini_batch=False)
+
+ # train
+ num_iterations = 10
+ previous_centers = None
+  for _ in range(num_iterations):
+    kmeans.train(input_fn)
+    cluster_centers = kmeans.cluster_centers()
+    if previous_centers is not None:
+      print('delta:', cluster_centers - previous_centers)
+    previous_centers = cluster_centers
+    print('score:', kmeans.score(input_fn))
+  print('cluster centers:', cluster_centers)
+
+  # map the input points to their clusters
+  cluster_indices = list(kmeans.predict_cluster_index(input_fn))
+  for i, point in enumerate(points):
+    cluster_index = cluster_indices[i]
+    center = cluster_centers[cluster_index]
+    print('point:', point, 'is in cluster', cluster_index,
+          'centered at', center)
+ ```
+
+  The `SavedModel` saved by the `export_savedmodel` method does not include the
+  cluster centers. However, the cluster centers may be retrieved from the
+  latest checkpoint saved during training. Specifically,
+ ```
+ kmeans.cluster_centers()
+ ```
+ is equivalent to
+ ```
+ tf.train.load_variable(
+ kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME)
+ ```
+ """
# Valid values for the distance_metric constructor argument.
SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
@@ -253,6 +303,9 @@ class KMeansClustering(estimator.Estimator):
CLUSTER_INDEX = 'cluster_index'
ALL_DISTANCES = 'all_distances'
+ # Variable name used by cluster_centers().
+ CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME
+
def __init__(self,
num_clusters,
model_dir=None,
@@ -406,4 +459,4 @@ class KMeansClustering(estimator.Estimator):
def cluster_centers(self):
"""Returns the cluster centers."""
- return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME)
+ return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME)
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index f5290a14d3..53de21697b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2899,9 +2899,11 @@ inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
for (int idx = 0; idx < num_resolved_axis; ++idx) {
num_elements_in_axis *= static_cast<size_t>(input_dims[resolved_axis[idx]]);
}
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- output_data[idx] = static_cast<T>(static_cast<float>(output_data[idx]) /
- num_elements_in_axis);
+ if (num_elements_in_axis > 0) {
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] = static_cast<T>(static_cast<float>(output_data[idx]) /
+ num_elements_in_axis);
+ }
}
}
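[Editorial note] The new `num_elements_in_axis > 0` guard matters when a reduced axis has zero extent: without it, every output element is divided by zero. A numpy illustration of the failure mode (illustration only; not TFLite code):

```
import numpy as np

x = np.zeros((4, 0, 2), dtype=np.float32)  # axis 1 has zero extent
sums = x.sum(axis=1)                       # shape (4, 2), all zeros
count = x.shape[1]                         # 0 elements were reduced

with np.errstate(invalid='ignore'):
    unguarded = sums / count               # 0.0 / 0 -> nan everywhere
print(np.isnan(unguarded).all())           # True; the guard keeps the zeros
```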
diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc
index c4c53c2ded..2d6d4bc2da 100644
--- a/tensorflow/contrib/lite/kernels/mean_test.cc
+++ b/tensorflow/contrib/lite/kernels/mean_test.cc
@@ -74,7 +74,7 @@ class MeanOpDynamicModel : public BaseMeanOpModel {
}
};
-TEST(ConstMeanOpTest, NotKeepDims) {
+TEST(ConstFloatMeanOpTest, NotKeepDims) {
std::initializer_list<float> data = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
@@ -86,7 +86,7 @@ TEST(ConstMeanOpTest, NotKeepDims) {
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
}
-TEST(ConstMeanOpTest, KeepDims) {
+TEST(ConstFloatMeanOpTest, KeepDims) {
std::initializer_list<float> data = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
@@ -99,7 +99,7 @@ TEST(ConstMeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
-TEST(DynamicMeanOpTest, NotKeepDims) {
+TEST(DynamicFloatMeanOpTest, NotKeepDims) {
std::initializer_list<float> data = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
@@ -114,7 +114,7 @@ TEST(DynamicMeanOpTest, NotKeepDims) {
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
}
-TEST(DynamicMeanOpTest, KeepDims) {
+TEST(DynamicFloatMeanOpTest, KeepDims) {
std::initializer_list<float> data = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
@@ -130,6 +130,70 @@ TEST(DynamicMeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
+TEST(DynamicFloatMeanOpTest, Scale) {
+ std::initializer_list<float> data = {9.527};
+ MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::initializer_list<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MeanOpTest, NotKeepDims) {
+ std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24};
+ MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+}
+
+TEST(ConstUint8MeanOpTest, KeepDims) {
+ std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24};
+ MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+}
+
+TEST(DynamicUint8MeanOpTest, NotKeepDims) {
+ std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24};
+ MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
+ {TensorType_INT32, {4}}, false);
+ std::initializer_list<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+}
+
+TEST(DynamicUint8MeanOpTest, KeepDims) {
+ std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24};
+ MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
+ {TensorType_INT32, {2}}, true);
+ std::initializer_list<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index b5960d6f8d..83b9e21427 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -317,7 +317,10 @@ tf_cc_test(
"//tensorflow/contrib/lite:testdata/multi_add.bin",
"//tensorflow/contrib/lite:testdata/multi_add.pb",
],
- tags = ["no_oss"],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_oss",
+ ],
deps = [
":tflite_diff_flags",
":tflite_diff_util",
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 976363fd44..86606d1239 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -92,6 +92,9 @@ std::map<string, string> kBrokenTests = {
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
+
+ // The LSTM kernel produces different results under tsan, asan and msan.
+ {R"(^\/lstmdtype=tf.float32.*)", "73830845"},
};
// Allows test data to be unzipped into a temporary directory and makes
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index bc374d66c3..827279bd47 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -70,9 +70,6 @@ py_test(
srcs = ["python/training/moving_average_optimizer_test.py"],
srcs_version = "PY2AND3",
tags = [
- "manual",
- "no_oss", # b/73507407
- "notap",
"notsan", # b/31055119
],
deps = [
diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py
index 4bf698f207..044de33568 100644
--- a/tensorflow/contrib/py2tf/impl/conversion.py
+++ b/tensorflow/contrib/py2tf/impl/conversion.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import gast
-import six
from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.converters import asserts
@@ -36,6 +35,7 @@ from tensorflow.contrib.py2tf.converters import side_effect_guards
from tensorflow.contrib.py2tf.impl import config
from tensorflow.contrib.py2tf.impl import naming
from tensorflow.contrib.py2tf.pyct import context
+from tensorflow.contrib.py2tf.pyct import inspect_utils
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct.static_analysis import activity
@@ -155,7 +155,7 @@ def class_to_graph(c, conversion_map):
if not members:
raise ValueError('Cannot convert %s: it has no member methods.' % c)
- class_globals = None
+ class_namespace = None
for _, m in members:
node, _ = function_to_graph(
m,
@@ -164,10 +164,10 @@ def class_to_graph(c, conversion_map):
arg_types={'self': (c.__name__, c)},
owner_type=c)
# TODO(mdan): Do not assume all members have the same view of globals.
- if class_globals is None:
- class_globals = six.get_function_globals(m)
+ if class_namespace is None:
+ class_namespace = inspect_utils.getnamespace(m)
converted_members[m] = node
- namer = conversion_map.new_namer(class_globals)
+ namer = conversion_map.new_namer(class_namespace)
class_name = namer.compiled_class_name(c.__name__, c)
node = gast.ClassDef(
class_name,
@@ -202,19 +202,11 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
"""Specialization of `entity_to_graph` for callable functions."""
node, source = parser.parse_entity(f)
node = node.body[0]
- namespace = six.get_function_globals(f)
-
- # This is needed for non-global functions.
- closure = six.get_function_closure(f)
- if closure:
- for e in closure:
- if callable(e.cell_contents):
- fn = e.cell_contents
- namespace[fn.__name__] = fn
+ namespace = inspect_utils.getnamespace(f)
_add_self_references(namespace, conversion_map.api_module)
-
namer = conversion_map.new_namer(namespace)
+
ctx = context.EntityContext(
namer=namer,
source_code=source,
diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils.py b/tensorflow/contrib/py2tf/pyct/inspect_utils.py
index 86cf52afd5..d19c6ed75e 100644
--- a/tensorflow/contrib/py2tf/pyct/inspect_utils.py
+++ b/tensorflow/contrib/py2tf/pyct/inspect_utils.py
@@ -21,33 +21,58 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
+
import six
from tensorflow.python.util import tf_inspect
-def getcallargs(c, *args, **kwargs):
- """Extension of getcallargs to non-function callables."""
- if tf_inspect.isfunction(c):
- # The traditional getcallargs
- return tf_inspect.getcallargs(c, *args, **kwargs)
+def getnamespace(f):
+ """Returns the complete namespace of a function.
+
+ Namespace is defined here as the mapping of all non-local variables to values.
+ This includes the globals and the closure variables. Note that this captures
+ the entire globals collection of the function, and may contain extra symbols
+ that the function does not actually use.
+
+ Args:
+ f: User defined function.
+ Returns:
+ A dict mapping symbol names to values.
+ """
+ namespace = dict(six.get_function_globals(f))
+ closure = six.get_function_closure(f)
+ freevars = six.get_function_code(f).co_freevars
+ if freevars and closure:
+ for name, cell in zip(freevars, closure):
+ namespace[name] = cell.cell_contents
+ return namespace
+
+
+def getmethodclass(m):
+ """Resolves a function's owner, e.g. a method's class.
+
+ Note that this returns the object that the function was retrieved from, not
+ necessarily the class where it was defined.
- if tf_inspect.isclass(c):
- # Constructors: pass a fake None for self, then remove it.
- arg_map = tf_inspect.getcallargs(c.__init__, None, *args, **kwargs)
- assert 'self' in arg_map, 'no "self" argument, is this not a constructor?'
- del arg_map['self']
- return arg_map
+ This function relies on Python stack frame support in the interpreter, and
+ has the same limitations as inspect.currentframe.
- if hasattr(c, '__call__'):
- # Callable objects: map self to the object itself
- return tf_inspect.getcallargs(c.__call__, *args, **kwargs)
+ Limitations: this function will only work correctly if the owning class is
+ visible in the caller's global or local variables.
- raise NotImplementedError('unknown callable "%s"' % type(c))
+ Args:
+ m: A user defined function.
+ Returns:
+ The class that the function was retrieved from, or None if the function
+ is not an object or class method, or if the class that owns the object or
+ method is not visible from the caller's namespace.
-def getmethodclass(m, namespace):
- """Resolves a function's owner, e.g. a method's class."""
+ Raises:
+ ValueError: if the class could not be resolved for any unexpected reason.
+ """
# Instance method and class methods: should be bound to a non-null "self".
# If self is a class, then it's a class method.
@@ -57,34 +82,38 @@ def getmethodclass(m, namespace):
return m.__self__
return type(m.__self__)
- # Class and static methods: platform specific.
- if hasattr(m, 'im_class'): # Python 2
- return m.im_class
-
- if hasattr(m, '__qualname__'): # Python 3
- qn = m.__qualname__.split('.')
- if len(qn) < 2:
- return None
- owner_name, func_name = qn[-2:]
- assert func_name == m.__name__, (
- 'inconsistent names detected '
- '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % (func_name,
- m.__name__, m))
- if owner_name == '<locals>':
- return None
- if owner_name not in namespace:
- raise ValueError(
- 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' %
- (owner_name, m, namespace))
- return namespace[owner_name]
-
- if six.PY2:
- # In Python 2 it's impossible, to our knowledge, to detect the class of a
- # static function. So we're forced to walk all the objects in the
- # namespace and see if they own it. If any reader finds a better solution,
- # please let us know.
- for _, v in namespace.items():
- if hasattr(v, m.__name__) and getattr(v, m.__name__) is m:
- return v
+ # Class, static and unbound methods: search the caller's local and global
+ # namespaces for an object that owns this function. This is inefficient, but
+ # it is the more robust method.
+ owners = []
+ caller_frame = tf_inspect.currentframe().f_back
+ try:
+ # TODO(mdan): This doesn't consider cell variables.
+ # TODO(mdan): This won't work if the owner is hidden inside a container.
+ # Cell variables may be pulled using co_freevars and the closure.
+ for v in itertools.chain(caller_frame.f_locals.values(),
+ caller_frame.f_globals.values()):
+ if hasattr(v, m.__name__):
+ candidate = getattr(v, m.__name__)
+ # Py2 methods may be bound or unbound, extract im_func to get the
+ # underlying function.
+ if hasattr(candidate, 'im_func'):
+ candidate = candidate.im_func
+ if hasattr(m, 'im_func'):
+ m = m.im_func
+ if candidate is m:
+ owners.append(v)
+ finally:
+ del caller_frame
+
+ if owners:
+ if len(owners) == 1:
+ return owners[0]
+
+ # If multiple owners are found and they are not related by subclassing,
+ # raise an error.
+ owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
+ for o in owner_types:
+ if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
+ return o
+ raise ValueError('Found too many owners of %s: %s' % (m, owners))
return None
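As a quick illustration of the semantics documented above, here is a minimal sketch of getnamespace and of the simple bound-method path of getmethodclass; the names offset, make_adder and Greeter are hypothetical, not part of this change:

from tensorflow.contrib.py2tf.pyct import inspect_utils

offset = 10  # a module-level global, captured via the function's globals

def make_adder(base):
  def add(x):
    return x + base + offset  # base is a closure cell, offset a global
  return add

ns = inspect_utils.getnamespace(make_adder(1))
assert ns['offset'] == 10  # globals are included
assert ns['base'] == 1     # closure cells are resolved via co_freevars
assert 'x' not in ns       # locals and arguments are not in the namespace

class Greeter(object):
  def hello(self):
    return 'hi'

# Bound methods resolve directly through __self__, without frame inspection.
assert inspect_utils.getmethodclass(Greeter().hello) is Greeter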
diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py b/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py
index 5d92e75b18..5528ac851f 100644
--- a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py
+++ b/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
from functools import wraps
+import six
+
from tensorflow.contrib.py2tf.pyct import inspect_utils
from tensorflow.python.platform import test
@@ -76,6 +78,10 @@ def free_function():
pass
+def factory():
+ return free_function
+
+
def free_factory():
def local_function():
pass
@@ -84,87 +90,87 @@ def free_factory():
class InspectUtilsTest(test.TestCase):
- def test_getcallargs_constructor(self):
-
- class TestSuperclass(object):
+ def test_getnamespace_globals(self):
+ ns = inspect_utils.getnamespace(factory)
+ self.assertEqual(ns['free_function'], free_function)
- def __init__(self, x):
- pass
-
- class TestCallable(TestSuperclass):
- pass
+ def test_getnamespace_hermetic(self):
- self.assertDictEqual({
- 'x': 1
- }, inspect_utils.getcallargs(TestCallable, 1))
+ # Intentionally hiding the global function to make sure we don't overwrite
+ # it in the global namespace.
+ free_function = object() # pylint:disable=redefined-outer-name
- def test_getcallargs_object(self):
+ def test_fn():
+ return free_function
- class TestCallable(object):
+ ns = inspect_utils.getnamespace(test_fn)
+ globs = six.get_function_globals(test_fn)
+ self.assertTrue(ns['free_function'] is free_function)
+ self.assertFalse(globs['free_function'] is free_function)
- def __call__(self, x):
- pass
+ def test_getnamespace_locals(self):
- obj = TestCallable()
- self.assertDictEqual({
- 'self': obj,
- 'x': 1
- }, inspect_utils.getcallargs(obj, 1))
+ def called_fn():
+ return 0
- def test_getcallargs_function(self):
+ closed_over_list = []
+ closed_over_primitive = 1
- def test_fn(x):
- return x + 1
+ def local_fn():
+ closed_over_list.append(1)
+ local_var = 1
+ return called_fn() + local_var + closed_over_primitive
- self.assertDictEqual({
- 'x': 1
- }, inspect_utils.getcallargs(test_fn, 1))
+ ns = inspect_utils.getnamespace(local_fn)
+ self.assertEqual(ns['called_fn'], called_fn)
+ self.assertEqual(ns['closed_over_list'], closed_over_list)
+ self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
+ self.assertTrue('local_var' not in ns)
def test_getmethodclass(self):
self.assertEqual(
- inspect_utils.getmethodclass(free_function, {}), None)
+ inspect_utils.getmethodclass(free_function), None)
self.assertEqual(
- inspect_utils.getmethodclass(free_factory(), {}), None)
+ inspect_utils.getmethodclass(free_factory()), None)
- ns = {'TestClass': TestClass}
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.member_function, ns),
+ inspect_utils.getmethodclass(TestClass.member_function),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.decorated_member, ns),
+ inspect_utils.getmethodclass(TestClass.decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.fn_decorated_member, ns),
+ inspect_utils.getmethodclass(TestClass.fn_decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.wrap_decorated_member, ns),
+ inspect_utils.getmethodclass(TestClass.wrap_decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.static_method, ns),
+ inspect_utils.getmethodclass(TestClass.static_method),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(TestClass.class_method, ns),
+ inspect_utils.getmethodclass(TestClass.class_method),
TestClass)
test_obj = TestClass()
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.member_function, ns),
+ inspect_utils.getmethodclass(test_obj.member_function),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.fn_decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.static_method, ns),
+ inspect_utils.getmethodclass(test_obj.static_method),
TestClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.class_method, ns),
+ inspect_utils.getmethodclass(test_obj.class_method),
TestClass)
def test_getmethodclass_locals(self):
@@ -190,34 +196,33 @@ class InspectUtilsTest(test.TestCase):
pass
self.assertEqual(
- inspect_utils.getmethodclass(local_function, {}), None)
+ inspect_utils.getmethodclass(local_function), None)
- ns = {'LocalClass': LocalClass}
self.assertEqual(
- inspect_utils.getmethodclass(LocalClass.member_function, ns),
+ inspect_utils.getmethodclass(LocalClass.member_function),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(LocalClass.decorated_member, ns),
+ inspect_utils.getmethodclass(LocalClass.decorated_member),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(LocalClass.fn_decorated_member, ns),
+ inspect_utils.getmethodclass(LocalClass.fn_decorated_member),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(LocalClass.wrap_decorated_member, ns),
+ inspect_utils.getmethodclass(LocalClass.wrap_decorated_member),
LocalClass)
test_obj = LocalClass()
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.member_function, ns),
+ inspect_utils.getmethodclass(test_obj.member_function),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.decorated_member),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.fn_decorated_member),
LocalClass)
self.assertEqual(
- inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns),
+ inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
LocalClass)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
index 9c0a9a9e74..0388be5d25 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
@@ -86,6 +86,7 @@ class LiveValueResolver(transformer.Base):
if not hasattr(parent_object, node.attr):
raise AttributeError('%s has no attribute %s' % (parent_object,
node.attr))
+ anno.setanno(node, 'parent_type', type(parent_object))
anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
# TODO(mdan): Investigate the role built-in annotations can play here.
@@ -96,6 +97,7 @@ class LiveValueResolver(transformer.Base):
# This would not hold for dynamic members like function attributes.
# For the dynamic case, we simply leave the node without an annotation,
# and let downstream consumers figure out what to do.
+ anno.setanno(node, 'parent_type', parent_type)
anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
anno.setanno(node, 'fqn',
anno.getanno(node.value, 'type_fqn') + (node.attr,))
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
index 1e81bc70a8..c133a455b3 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py
@@ -103,6 +103,7 @@ class LiveValuesResolverTest(test.TestCase):
arg_types={'self': (TestClass.__name__, TestClass)})
func_node = node.body[0].body[0].value.func
self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val'))
+ self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type'))
self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))
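The new parent_type annotation lets downstream passes recover the type through which an attribute was resolved. A minimal sketch of a hypothetical consumer, assuming the anno module's hasanno/getanno helpers:

from tensorflow.contrib.py2tf.pyct import anno

def attribute_owner(node):
  # Hypothetical helper: returns the type the attribute was resolved on, or
  # None when LiveValueResolver left the node unannotated (dynamic members).
  if anno.hasanno(node, 'parent_type'):
    return anno.getanno(node, 'parent_type')
  return None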
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index b941819838..3e937ceb64 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -42,11 +42,8 @@ limitations under the License.
namespace tensorflow {
// A few string constant used throughout this module.
-//
-// TODO(zhifengc): Dedup some of these constants into
-// framework/function.h
-static constexpr const char* const kArgOp = "_Arg";
-static constexpr const char* const kRetOp = "_Retval";
+static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kGradientOp =
FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
@@ -177,6 +174,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
}
Device* device() override { return device_; }
+ const DeviceMgr* device_mgr() const override { return device_mgr_; }
Env* env() override { return env_; }
int graph_def_version() override { return graph_def_version_; }
@@ -1580,9 +1578,6 @@ Status FunctionDefToBodyHelper(
// Call BuildControlFlowInfo to validate that this function body has
// well-formed control flow.
- // NOTE(skyewm): this is usually done in Partition(), but we don't partition
- // function bodies. This should be removed if function bodies ever go through
- // the Partition() path.
std::vector<ControlFlowInfo> dummy;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index 2ed07e3669..bb14e0197b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -34,7 +34,7 @@ namespace {
class GrpcWorkerCache : public WorkerCachePartial {
public:
// TODO(ncteisen): consider adding a config var or flag for this
- static constexpr const size_t kGrpcWorkerCacheThreadCount = 2;
+ static constexpr const size_t kGrpcWorkerCacheThreadCount = 8;
explicit GrpcWorkerCache(GrpcChannelCache* channel_cache,
WorkerInterface* local_worker,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 1beb198732..b20e744a97 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -52,7 +52,7 @@ namespace {
class GrpcWorkerService : public AsyncServiceInterface {
// TODO(ncteisen): consider adding a config var or flag for this
- static constexpr const size_t kGrpcWorkerServiceThreadCount = 2;
+ static constexpr const size_t kGrpcWorkerServiceThreadCount = 8;
public:
GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder)
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index eae8e6c3c1..3e7b89d4eb 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -168,7 +168,7 @@ class FunctionInstantiationHelper {
strings::StrAppend(&name, "_", i);
}
NodeDef* gnode = AddNode(name);
- gnode->set_op("_Arg");
+ gnode->set_op(FunctionLibraryDefinition::kArgOp);
AddAttr("T", dtypes[i], gnode);
AddAttr("index", arg_index, gnode);
result_.arg_types.push_back(dtypes[i]);
@@ -328,7 +328,7 @@ class FunctionInstantiationHelper {
strings::StrAppend(&name, "_", i);
}
NodeDef* gnode = AddNode(name);
- gnode->set_op("_Retval");
+ gnode->set_op(FunctionLibraryDefinition::kRetOp);
AddInput(nodes_.size() - 1, item->nid, item->idx + i);
AddAttr("T", dtypes[i], gnode);
AddAttr("index", (*ret_index)++, gnode);
@@ -558,9 +558,9 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
std::vector<const NodeDef*> ret;
std::vector<const NodeDef*> body;
for (const NodeDef* n : nodes) {
- if (n->op() == "_Arg") {
+ if (n->op() == FunctionLibraryDefinition::kArgOp) {
arg.push_back(n);
- } else if (n->op() == "_Retval") {
+ } else if (n->op() == FunctionLibraryDefinition::kRetOp) {
ret.push_back(n);
} else {
body.push_back(n);
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e27001133b..e00399f97d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -344,6 +344,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
Status LookUp(const string& op_type_name,
const OpRegistrationData** op_reg_data) const override;
+ // Ops created for function arguments bear the name given by `kArgOp`; those
+ // created for return values bear the name given by `kRetOp`.
+ static constexpr const char* const kArgOp = "_Arg";
+ static constexpr const char* const kRetOp = "_Retval";
+
static constexpr const char* const kGradientOp = "SymbolicGradient";
static constexpr const char* const kFuncAttr = "f";
@@ -404,6 +409,8 @@ struct FunctionBody;
// Forward declare. Defined in common_runtime/device.h
class Device;
+// Forward declare. Defined in common_runtime/device_mgr.h
+class DeviceMgr;
class FunctionLibraryRuntime {
public:
@@ -518,6 +525,9 @@ class FunctionLibraryRuntime {
// Returns the device on which the function executes.
virtual Device* device() = 0;
+ // Returns the DeviceMgr from which the device was obtained.
+ virtual const DeviceMgr* device_mgr() const = 0;
+
// Returns the function library definition that backs this runtime.
// NOTE(mrry): The returned library definition is the default function library
// for this runtime. The runtime may instantiate functions from separate
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index db6683d1e7..30ff19cd7e 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -24,23 +24,24 @@ limitations under the License.
namespace tensorflow {
-Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
+Status BuildControlFlowInfo(const Graph* g,
+ std::vector<ControlFlowInfo>* info) {
info->clear();
info->resize(g->num_node_ids());
std::vector<const Node*> parent_nodes;
parent_nodes.resize(g->num_node_ids());
- Node* src_node = g->source_node();
+ const Node* src_node = g->source_node();
ControlFlowInfo& src_info = (*info)[src_node->id()];
src_info.frame = src_node;
src_info.parent_frame = src_node;
string frame_name;
- std::deque<Node*> ready;
+ std::deque<const Node*> ready;
ready.push_back(src_node);
while (!ready.empty()) {
- Node* curr_node = ready.front();
+ const Node* curr_node = ready.front();
ready.pop_front();
const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
const Node* frame = curr_info.frame;
@@ -56,7 +57,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
}
for (const Edge* out_edge : curr_node->out_edges()) {
- Node* out = out_edge->dst();
+ const Node* out = out_edge->dst();
int out_id = out->id();
ControlFlowInfo* out_info = &(*info)[out_id];
const Node* out_parent = out_info->parent_frame;
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 372044f538..79e2be0d4b 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -30,14 +30,14 @@ struct ControlFlowInfo {
string frame_name; // frame name of a node
};
-// Assign to each node the name of the frame and the level it belongs to.
-// We check the well-formedness of the graph: All inputs to a node must
-// come from the same frame and have the same "static" iteration level.
-// `info` is cleared and populated by this function.
-// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
-// 0. This essentially means there can't be multiple serial Nexts in
-// an iteration, which all sane front-ends should satisfy.
-Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info);
+// Clear and populate `info` with each node's frame and the level it belongs to.
+// We check the well-formedness of the graph: All inputs to a node must come
+// from the same frame and have the same "static" iteration level.
+//
+// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0.
+// This essentially means there can't be multiple serial Nexts in an iteration,
+// which all sane front-ends should satisfy.
+Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info);
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e839630605..50ba48ea7a 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -158,6 +158,18 @@ cc_library(
)
cc_library(
+ name = "custom_graph_optimizer",
+ hdrs = [
+ "custom_graph_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "arithmetic_optimizer",
srcs = ["arithmetic_optimizer.cc"],
hdrs = [
@@ -368,6 +380,8 @@ cc_library(
":arithmetic_optimizer",
":auto_parallel",
":constant_folding",
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
":dependency_optimizer",
":graph_optimizer",
":layout_optimizer",
@@ -382,6 +396,48 @@ cc_library(
],
)
+tf_cc_test(
+ name = "meta_optimizer_test",
+ srcs = ["meta_optimizer_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":meta_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
+cc_library(
+ name = "custom_graph_optimizer_registry",
+ srcs = ["custom_graph_optimizer_registry.cc"],
+ hdrs = ["custom_graph_optimizer_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":custom_graph_optimizer",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "custom_graph_optimizer_registry_test",
+ size = "small",
+ srcs = ["custom_graph_optimizer_registry_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "loop_optimizer",
srcs = ["loop_optimizer.cc"],
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
new file mode 100644
index 0000000000..a80d46f416
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A custom optimizer that can be registered.
+class CustomGraphOptimizer : public GraphOptimizer {
+ public:
+ virtual ~CustomGraphOptimizer() {}
+ virtual Status Init() = 0;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
new file mode 100644
index 0000000000..6eed43c2b1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
+ RegistrationMap;
+RegistrationMap* registered_optimizers = nullptr;
+RegistrationMap* GetRegistrationMap() {
+ if (registered_optimizers == nullptr)
+ registered_optimizers = new RegistrationMap;
+ return registered_optimizers;
+}
+} // namespace
+
+std::unique_ptr<CustomGraphOptimizer>
+CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
+ const auto it = GetRegistrationMap()->find(name);
+ if (it == GetRegistrationMap()->end()) return nullptr;
+ return std::unique_ptr<CustomGraphOptimizer>(it->second());
+}
+
+std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
+ std::vector<string> optimizer_names;
+ optimizer_names.reserve(GetRegistrationMap()->size());
+ for (const auto& opt : *GetRegistrationMap())
+ optimizer_names.emplace_back(opt.first);
+ return optimizer_names;
+}
+
+void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+ const Creator& optimizer_creator, const string& name) {
+ const auto it = GetRegistrationMap()->find(name);
+ if (it != GetRegistrationMap()->end()) {
+ LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
+ }
+ GetRegistrationMap()->insert({name, optimizer_creator});
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
new file mode 100644
index 0000000000..796da91373
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class CustomGraphOptimizerRegistry {
+ public:
+ static std::unique_ptr<CustomGraphOptimizer> CreateByNameOrNull(
+ const string& name);
+
+ static std::vector<string> GetRegisteredOptimizers();
+
+ typedef std::function<CustomGraphOptimizer*()> Creator;
+ // Registers a graph optimizer; may be called during program initialization.
+ // This class is not thread-safe.
+ static void RegisterOptimizerOrDie(const Creator& optimizer_creator,
+ const string& name);
+};
+
+class CustomGraphOptimizerRegistrar {
+ public:
+ explicit CustomGraphOptimizerRegistrar(
+ const CustomGraphOptimizerRegistry::Creator& creator,
+ const string& name) {
+ CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(creator, name);
+ }
+};
+
+#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \
+ namespace { \
+ static CustomGraphOptimizerRegistrar \
+ MyCustomGraphOptimizerClass##_registrar( \
+ []() { return new MyCustomGraphOptimizerClass; }, (name)); \
+ } // namespace
+
+#define REGISTER_GRAPH_OPTIMIZER(MyCustomGraphOptimizerClass) \
+ REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \
+ #MyCustomGraphOptimizerClass)
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
new file mode 100644
index 0000000000..629f5e83c1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+static const char* kTestOptimizerName = "Test";
+
+class TestGraphOptimizer : public CustomGraphOptimizer {
+ public:
+ Status Init() override { return Status::OK(); }
+ string name() const override { return kTestOptimizerName; }
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ return Status::OK();
+ }
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+};
+
+REGISTER_GRAPH_OPTIMIZER_AS(TestGraphOptimizer, "StaticRegister");
+
+TEST(CustomGraphOptimizerRegistryTest, DynamicRegistration) {
+ std::vector<string> optimizers =
+ CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+ std::unique_ptr<const CustomGraphOptimizer> test_optimizer;
+ ASSERT_EQ(
+ 0, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister"));
+ test_optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister");
+ EXPECT_EQ(nullptr, test_optimizer);
+ CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+ []() { return new TestGraphOptimizer; }, "DynamicRegister");
+ optimizers = CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+ ASSERT_EQ(
+ 1, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister"));
+ test_optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister");
+ ASSERT_NE(nullptr, test_optimizer);
+ EXPECT_EQ(kTestOptimizerName, test_optimizer->name());
+}
+
+TEST(CustomGraphOptimizerRegistryTest, StaticRegistration) {
+ const std::vector<string> optimizers =
+ CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
+ EXPECT_EQ(1,
+ std::count(optimizers.begin(), optimizers.end(), "StaticRegister"));
+ std::unique_ptr<const CustomGraphOptimizer> test_optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull("StaticRegister");
+ ASSERT_NE(nullptr, test_optimizer);
+ EXPECT_EQ(kTestOptimizerName, test_optimizer->name());
+}
+
+TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) {
+ const auto creator = []() { return new TestGraphOptimizer; };
+ EXPECT_DEATH(CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
+ creator, "StaticRegister"),
+ "twice");
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e27b9df620..7ae77207af 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
@@ -126,14 +127,26 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
new AutoParallel(cfg_.auto_parallel().num_replicas())));
}
} else {
- std::set<string> available_optimizers = {
+ const std::set<string> available_optimizers = {
"pruning", "constfold", "layout", "memory",
"autoparallel", "arithmetic", "dependency", "loop"};
- for (const auto& optimizer : cfg_.optimizers()) {
- if (available_optimizers.find(optimizer) != available_optimizers.end()) {
- optimizers.push_back(NewOptimizer(optimizer));
+ std::vector<string> custom_optimizer_names;
+ for (const auto& optimizer_name : cfg_.optimizers()) {
+ if (available_optimizers.find(optimizer_name) !=
+ available_optimizers.end()) {
+ optimizers.push_back(NewOptimizer(optimizer_name));
+ } else {
+ custom_optimizer_names.push_back(optimizer_name);
}
}
+ // Now create and add the custom optimizers.
+ for (const auto& optimizer_name : custom_optimizer_names) {
+ std::unique_ptr<CustomGraphOptimizer> opt =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
+ if (opt == nullptr) continue;
+ TF_RETURN_IF_ERROR(opt->Init());
+ optimizers.push_back(std::move(opt));
+ }
}
if (optimizers.empty()) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
new file mode 100644
index 0000000000..536347d834
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class TestOptimizer : public CustomGraphOptimizer {
+ public:
+ static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
+ static bool IsOptimized() { return optimized_; }
+
+ TestOptimizer() {}
+ string name() const override { return "test_optimizer"; }
+
+ Status Init() override { return Status::OK(); }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ optimized_ = true;
+ *optimized_graph = item.graph;
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static bool optimized_;
+};
+
+bool TestOptimizer::optimized_;
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+
+TEST(MetaOptimizerTest, RunsCustomOptimizer) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizer");
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 9d4bc35ba8..a094ebe5e2 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -32,7 +32,9 @@ limitations under the License.
namespace tensorflow {
-static const char* const kGradientOp = "SymbolicGradient";
+static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;
class ArgOp : public OpKernel {
public:
@@ -89,26 +91,25 @@ class RetvalOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
-REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
-REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
#if TENSORFLOW_USE_SYCL
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
- Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
+ Name(kArgOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
-TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
+TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
.Device(DEVICE_SYCL)
.HostMemory("output")
.TypeConstraint<int32>("T"),
ArgOp);
#undef REGISTER
-#define REGISTER(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- RetvalOp);
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(kRetOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
-TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
+TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
.Device(DEVICE_SYCL)
.HostMemory("input")
.TypeConstraint<int32>("T"),
@@ -118,16 +119,16 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
- Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
+ Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
-TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
+TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
.Device(DEVICE_GPU)
.HostMemory("output")
.TypeConstraint<int32>("T"),
ArgOp);
#undef REGISTER
-REGISTER_KERNEL_BUILDER(Name("_Arg")
+REGISTER_KERNEL_BUILDER(Name(kArgOp)
.Device(DEVICE_GPU)
.HostMemory("output")
.TypeConstraint<ResourceHandle>("T"),
@@ -135,9 +136,9 @@ REGISTER_KERNEL_BUILDER(Name("_Arg")
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
- Name("_Retval").Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
+ Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
-TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
+TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
.Device(DEVICE_GPU)
.HostMemory("input")
.TypeConstraint<int32>("T"),
@@ -287,7 +288,8 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
class RemoteCallOp : public AsyncOpKernel {
public:
explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
}
~RemoteCallOp() override {}
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 15ae4c1fc5..9237fa51d8 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -280,8 +280,8 @@ __global__ void ColumnReduceMax16ColumnsKernel(
const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
// not the most efficient way to do this sum
for (int i = 1; i < rows_in_this_warp; ++i) {
- value_type tmp =
- cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
+ value_type tmp = cub::ShuffleIndex<32, value_type>(
+ sum, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
if (lane < num_cols) sum = op(sum, tmp);
}
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index a61eecaa29..504ed5d819 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -87,5 +87,8 @@ message RewriterConfig {
// ("autoparallel"). Memory optimization passes ("memory") invoked here are
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
+ //
+ // Custom registered optimizers will be run after the base optimizers, in
+ // the order that they are specified.
repeated string optimizers = 100;
}
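On the Python side, a registered custom optimizer is enabled by listing its registration name in this optimizers field. A minimal sketch against the TF 1.x session API; the name MyCustomOptimizer is hypothetical and must match a name registered through the C++ registry above:

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

rewrite_options = rewriter_config_pb2.RewriterConfig(
    optimizers=['pruning', 'MyCustomOptimizer'])  # custom optimizers run last
graph_options = tf.GraphOptions(rewrite_options=rewrite_options)
config = tf.ConfigProto(graph_options=graph_options)

with tf.Session(config=config) as sess:
  pass  # graphs launched in this session pass through the custom optimizer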
diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/get_started/datasets_quickstart.md
index bc69773d21..c972e5e555 100644
--- a/tensorflow/docs_src/get_started/datasets_quickstart.md
+++ b/tensorflow/docs_src/get_started/datasets_quickstart.md
@@ -265,9 +265,6 @@ ds = tf.data.TextLineDataset(train_path).skip(1)
### Build a csv line parser
-Ultimately we will need to parse each of the lines in the dataset, to
-produce the necessary `(features, label)` pairs.
-
We will start by building a function to parse a single line.
The following `iris_data.parse_line` function accomplishes this task using the
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 5be38ae1ef..623ca6bb79 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -118,8 +118,8 @@ Take the following steps to install TensorFlow with Virtualenv:
Python 2.7, the command to install
TensorFlow in the active Virtualenv is as follows:
- <pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl</b></pre>
+ <pre> $ <b>pip install --upgrade \
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -241,8 +241,8 @@ take the following steps:
you are installing TensorFlow for Mac OS and Python 2.7
issue the following command:
- <pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl</b> </pre>
+ <pre> $ <b>sudo pip install --upgrade \
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 07652d3e02..0e9c21b221 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -60,8 +60,7 @@ class _EagerContext(threading.local):
def __init__(self):
super(_EagerContext, self).__init__()
- self.device_spec = pydev.DeviceSpec.from_string(
- "/job:localhost/replica:0/task:0/device:CPU:0")
+ self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
self.mode = _default_mode
self.scope_name = ""
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index c68e2f422e..0e40d8a5c0 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
@@ -65,8 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
ctx.summary_writer_resource = 'mock'
self.assertEqual('mock', ctx.summary_writer_resource)
- self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
- ctx.device_name)
+ self.assertEqual('', ctx.device_name)
self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
with ctx.device('GPU:0'):
self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0',
@@ -100,6 +100,18 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertEqual(len(cpu_stats.node_stats), 1)
self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')
+ def testShouldCopy(self):
+ if not context.context().num_gpus():
+ self.skipTest('No devices other than CPUs found')
+ with ops.device('gpu:0'):
+ x = constant_op.constant(1.0)
+ y = array_ops.identity(x)
+ # The value we test y.device against depends on the behavior when no device
+ # is explicitly specified in the context. This behavior is subject to change
+ # (for example, in the future we may want to use GPUs, if available, when no
+ # device is explicitly provided).
+ self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
def testContextStackContainsEagerMode(self):
# Eager execution has been enabled, and no other context
# switch has occurred, so `context_stack` should contain
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 6fa076507d..3ec2109d32 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -185,6 +185,12 @@ typedef struct EagerTensor {
// This stores `_keras_mask` object and is set by Tensorflow layers.
PyObject* keras_mask;
+
+ // We store a status object here as an optimization to avoid allocating a new
+ // Status objects on different functions that operate on EagerTensor and need
+ // to use a TF_Status object. However note that accesses to `status` are not
+ // thread-safe.
+ TF_Status* status;
} EagerTensor;
// tp_init for EagerTensor.
@@ -195,6 +201,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
self->handle_data = Py_None;
Py_INCREF(Py_None);
self->keras_mask = Py_None;
+ self->status = TF_NewStatus();
PyObject* value;
PyObject* context = nullptr;
PyObject* device = nullptr;
@@ -269,17 +276,17 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
}
TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
- auto out_status = tensorflow::make_safe(TF_NewStatus());
handle = tensorflow::make_safe(
EagerCast(GetContext(context), handle.get(), handle_dtype,
- static_cast<TF_DataType>(desired_dtype), out_status.get()));
- if (TF_GetCode(out_status.get()) != TF_OK) {
- PyErr_SetString(
- PyExc_ValueError,
- tensorflow::strings::StrCat("Error while casting from DataType ",
- handle_dtype, " to ", desired_dtype, ". ",
- TF_Message(out_status.get()))
- .c_str());
+ static_cast<TF_DataType>(desired_dtype), self->status));
+ if (TF_GetCode(self->status) != TF_OK) {
+ PyErr_SetString(PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "Error while casting from DataType ", handle_dtype,
+ " to ", desired_dtype, ". ", TF_Message(self->status))
+ .c_str());
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
return -1;
}
handle_dtype = TFE_TensorHandleDataType(handle.get());
@@ -323,6 +330,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
// tp_dealloc for EagerTensor.
void EagerTensor_dealloc(EagerTensor* self) {
+ TF_DeleteStatus(self->status);
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
TFE_DeleteTensorHandle(self->handle);
@@ -348,12 +356,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
// Getter for `_shape_tuple`.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle);
+ int n = TFE_TensorHandleNumDims(handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
PyObject* shape = PyTuple_New(n);
if (PyErr_Occurred()) return nullptr;
for (int i = 0; i < n; ++i) {
- PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
- if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
+ PyObject* dim =
+ PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
+ dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
Py_DECREF(shape);
if (dim != nullptr) Py_DECREF(dim);
PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
@@ -365,10 +382,16 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
+ int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
#if PY_MAJOR_VERSION < 3
- return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
+ return PyInt_FromLong(num_dims);
#else
- return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
+ return PyLong_FromLong(num_dims);
#endif
}
@@ -437,10 +460,16 @@ static PyObject* EagerTensor_numpy(EagerTensor* self) {
// Getter `device`.
static PyObject* EagerTensor_device(EagerTensor* self) {
+ const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
#if PY_MAJOR_VERSION >= 3
- return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle));
+ return PyUnicode_FromString(device);
#else
- return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle));
+ return PyBytes_FromString(device);
#endif
}
@@ -576,6 +605,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
Py_INCREF(Py_None);
t->keras_mask = Py_None;
t->handle = handle;
+ t->status = TF_NewStatus();
}
return reinterpret_cast<PyObject*>(t);
}
@@ -673,6 +703,7 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
auto tensor = tensorflow::make_safe(TF_AllocateTensor(
TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
+ auto status = tensorflow::make_safe(TF_NewStatus());
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
if (!EagerTensor_CheckExact(tensor_obj)) {
@@ -687,21 +718,27 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
TFE_TensorHandle* handle = t->handle;
- if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
- PyErr_SetString(PyExc_IndexError,
- tensorflow::strings::StrCat(
- "Slice dimension (", slice_dim,
- ") must be smaller than rank of all "
- "tensors, but tensor at index ",
- i, " has rank ", TFE_TensorHandleNumDims(handle))
- .c_str());
+ int num_dims = TFE_TensorHandleNumDims(handle, status.get());
+ if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
+ return nullptr;
+ }
+ if (slice_dim >= num_dims) {
+ PyErr_SetString(
+ PyExc_IndexError,
+ tensorflow::strings::StrCat("Slice dimension (", slice_dim,
+ ") must be smaller than rank of all "
+ "tensors, but tensor at index ",
+ i, " has rank ", num_dims)
+ .c_str());
+ return nullptr;
+ }
+ int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
+ if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
return nullptr;
}
- int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
data[i] = dim;
}
- auto status = tensorflow::make_safe(TF_NewStatus());
TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
if (TF_GetCode(status.get()) != TF_OK) {
PyErr_SetString(
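The hunks above thread a per-tensor `TF_Status` through the `TFE_TensorHandle` accessors (`TFE_TensorHandleNumDims`, `TFE_TensorHandleDim`, `TFE_TensorHandleDeviceName`), so backend failures surface as Python exceptions and the status is reset to `TF_OK` before returning. A runnable Python sketch of that pattern; the `Status` stub and helper names here are hypothetical stand-ins, not the real bindings:

```python
class Status(object):
    """Hypothetical stand-in for TF_Status."""

    def __init__(self):
        self.code, self.message = 0, ""  # 0 plays the role of TF_OK

    def ok(self):
        return self.code == 0

    def reset(self):
        self.code, self.message = 0, ""


def num_dims(handle, status):
    # Stand-in for TFE_TensorHandleNumDims: record failures in `status`.
    if handle is None:
        status.code, status.message = 3, "invalid handle"
        return -1
    return len(handle)


def shape_tuple(handle, status):
    # Mirrors EagerTensor_shape_tuple: check the status after the call and
    # reset it to OK before raising (the TF_SetStatus(..., TF_OK, "") step).
    n = num_dims(handle, status)
    if not status.ok():
        message = status.message
        status.reset()
        raise ValueError(message)
    return tuple(handle[:n])
```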
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4fd7003981..7389730d91 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -506,6 +506,30 @@ def assert_no_garbage_created(f):
previous_garbage = len(gc.garbage)
f(self, **kwargs)
gc.collect()
+ if len(gc.garbage) > previous_garbage:
+ logging.error(
+ "The decorated test created work for Python's garbage collector, "
+ "likely due to a reference cycle. New objects in cycle(s):")
+ for i, obj in enumerate(gc.garbage[previous_garbage:]):
+ try:
+ logging.error(
+ "Object %d of %d" % (i, len(gc.garbage) - previous_garbage))
+ def _safe_object_str(obj):
+ return "<%s %d>" % (obj.__class__.__name__, id(obj))
+ logging.error(" Object type: %s" % (_safe_object_str(obj),))
+ logging.error(" Referrer types: %s" % (
+ ', '.join([_safe_object_str(ref)
+ for ref in gc.get_referrers(obj)]),))
+ logging.error(" Referent types: %s" % (
+ ', '.join([_safe_object_str(ref)
+ for ref in gc.get_referents(obj)]),))
+ logging.error(" Object attribute names: %s" % (dir(obj),))
+ logging.error(" Object __str__:")
+ logging.error(obj)
+ logging.error(" Object __repr__:")
+ logging.error(repr(obj))
+ except Exception:
+ logging.error("(Exception while printing object)")
# This will fail if any garbage has been created, typically because of a
# reference cycle.
self.assertEqual(previous_garbage, len(gc.garbage))
@@ -564,6 +588,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
# This decorator runs the wrapped test twice.
# Reset the test environment between runs.
self.tearDown()
+ self._tempdir = None
self.setUp()
def run_eager_mode(self, **kwargs):
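The new logging in `assert_no_garbage_created` fires whenever a test run grows `gc.garbage`. A standalone sketch of how a reference cycle ends up there; this demo enables `gc.DEBUG_SAVEALL` so every collected object is saved, which may differ from the decorator's exact gc configuration (not shown in this hunk):

```python
import gc

gc.set_debug(gc.DEBUG_SAVEALL)  # append every collected object to gc.garbage
gc.collect()
previous_garbage = len(gc.garbage)


class Node(object):
    pass


a, b = Node(), Node()
a.partner, b.partner = b, a  # reference cycle: a -> b -> a
del a, b                     # now reachable only through the cycle

gc.collect()
print(len(gc.garbage) - previous_garbage)  # > 0: the cycle became garbage
```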
diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py
index 655e43e78f..b06fb3c6d0 100644
--- a/tensorflow/python/grappler/hierarchical_controller.py
+++ b/tensorflow/python/grappler/hierarchical_controller.py
@@ -612,10 +612,10 @@ class HierarchicalController(Controller):
num_inter_group_connections = num_connections - num_intra_group_connections
if verbose:
print("grouping evaluation metric")
- print("num_connections={} num_intra_group_connections={} "
- "num_inter_group_connections={}").format(
- num_connections, num_intra_group_connections,
- num_inter_group_connections)
+ print(("num_connections={} num_intra_group_connections={} "
+ "num_inter_group_connections={}").format(
+ num_connections, num_intra_group_connections,
+ num_inter_group_connections))
self.dag_matrix = dag_matrix
# output_shape
@@ -972,8 +972,8 @@ class HierarchicalController(Controller):
controller_ops["reward"]["ph"][child_id]: reward,
})
if verbose:
- print("run_time={:<.5f} reward={:<.5f} "
- "best_reward={:<.5f}").format(run_time, reward, best_reward)
+ print(("run_time={:<.5f} reward={:<.5f} "
+ "best_reward={:<.5f}").format(run_time, reward, best_reward))
# Reward is a double, best_reward a float: allow for some slack in the
# comparison.
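Both hunks above fix the same Python 2 idiom: `print("...").format(...)` parses as a `print` statement applied to `("...").format(...)` and happens to work, but once `print` is a function (Python 3, or with `print_function` imported) it returns `None`, so the chained `.format()` raises `AttributeError`. The fixed form, runnable on both versions:

```python
from __future__ import print_function

run_time, reward, best_reward = 0.123456, 1.5, 2.0

# Format first, then print; `print(...).format(...)` would fail here
# because print() returns None.
print(("run_time={:<.5f} reward={:<.5f} "
       "best_reward={:<.5f}").format(run_time, reward, best_reward))
```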
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index 1b657983a4..de9326ccfc 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -100,6 +100,7 @@ PyObject* TF_OptimizeGraph(
tensorflow::grappler::ItemConfig item_config;
item_config.inline_functions = false;
item_config.apply_optimizations = false;
+ item_config.ignore_user_placement = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 16738066ce..a98d08f928 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -39,7 +39,11 @@ py_library(
"_impl/keras/datasets/mnist.py",
"_impl/keras/datasets/reuters.py",
"_impl/keras/engine/__init__.py",
- "_impl/keras/engine/topology.py",
+ "_impl/keras/engine/base_layer.py",
+ "_impl/keras/engine/input_layer.py",
+ "_impl/keras/engine/network.py",
+ "_impl/keras/engine/saving.py",
+ "_impl/keras/engine/sequential.py",
"_impl/keras/engine/training.py",
"_impl/keras/engine/training_eager.py",
"_impl/keras/estimator.py",
@@ -761,9 +765,31 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":keras",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "saving_test",
+ size = "small",
+ srcs = ["_impl/keras/engine/saving_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "sequential_test",
+ size = "small",
+ srcs = ["_impl/keras/engine/sequential_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py
index 6521f84104..ca83e86912 100644
--- a/tensorflow/python/keras/_impl/keras/applications/densenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py
@@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
index bf3901fc54..17e407dd58 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
@@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
index e268e97bc6..2897c6058e 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
@@ -37,7 +37,7 @@ from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 1bbbedb85e..ad96b53a45 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -79,8 +79,8 @@ from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine import InputSpec
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
index 08dae57f00..dd33230a7e 100644
--- a/tensorflow/python/keras/_impl/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
@@ -49,7 +49,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import add
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index a47dd657bb..46c0e63557 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -34,7 +34,7 @@ from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
index 9da74253ab..cefb25063e 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
@@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
from tensorflow.python.keras._impl.keras.layers import Flatten
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
index 961c1f9918..dadaf4fdf0 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
@@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
from tensorflow.python.keras._impl.keras.layers import Flatten
diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py
index 7e7ca5a18a..971063a16d 100644
--- a/tensorflow/python/keras/_impl/keras/applications/xception.py
+++ b/tensorflow/python/keras/_impl/keras/applications/xception.py
@@ -44,7 +44,7 @@ from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
diff --git a/tensorflow/python/keras/_impl/keras/engine/__init__.py b/tensorflow/python/keras/_impl/keras/engine/__init__.py
index 31f624f9af..1bc533ab8f 100644
--- a/tensorflow/python/keras/_impl/keras/engine/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/engine/__init__.py
@@ -18,13 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
-from tensorflow.python.keras._impl.keras.engine.topology import Input
-from tensorflow.python.keras._impl.keras.engine.topology import InputLayer
-from tensorflow.python.keras._impl.keras.engine.topology import InputSpec
-from tensorflow.python.keras._impl.keras.engine.topology import Layer
+from tensorflow.python.keras._impl.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
+from tensorflow.python.keras._impl.keras.engine.input_layer import Input
+from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.network import Network
from tensorflow.python.keras._impl.keras.engine.training import Model
-
-
-# Note: topology.Node is an internal class,
-# it isn't meant to be used by Keras users.
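The `topology.py` split means downstream code now imports these symbols from the new modules. A short sketch using only paths added in this diff:

```python
# Old: from tensorflow.python.keras._impl.keras.engine.topology import ...
from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
from tensorflow.python.keras._impl.keras.engine.input_layer import Input
from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.engine.training import Model
```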
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
new file mode 100644
index 0000000000..142325041b
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -0,0 +1,504 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Base layer code (`Layer`).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import zip # pylint: disable=redefined-builtin
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.layers import base as tf_base_layers
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
+
+
+# pylint: disable=invalid-name
+InputSpec = tf_base_layers.InputSpec
+Node = tf_base_layers.Node
+TFBaseLayer = tf_base_layers.Layer
+# pylint: enable=invalid-name
+
+
+@tf_export('keras.layers.Layer')
+class Layer(tf_base_layers.Layer):
+ """Abstract base layer class.
+
+ # Properties
+ name: String, must be unique within a model.
+ input_spec: List of InputSpec class instances
+ each entry describes one required input:
+ - ndim
+ - dtype
+ A layer with `n` input tensors must have
+ an `input_spec` of length `n`.
+ trainable: Boolean, whether the layer weights
+ will be updated during training.
+ uses_learning_phase: Whether any operation
+ of the layer uses `K.in_training_phase()`
+ or `K.in_test_phase()`.
+ input_shape: Shape tuple. Provided for convenience,
+ but note that there may be cases in which this
+ attribute is ill-defined (e.g. a shared layer
+ with multiple input shapes), in which case
+ requesting `input_shape` will raise an Exception.
+ Prefer using `layer.get_input_shape_for(input_shape)`,
+ or `layer.get_input_shape_at(node_index)`.
+ output_shape: Shape tuple. See above.
+ inbound_nodes: List of nodes.
+ outbound_nodes: List of nodes.
+ input, output: Input/output tensor(s). Note that if the layer is used
+ more than once (shared layer), this is ill-defined
+ and will raise an exception. In such cases, use
+ `layer.get_input_at(node_index)`.
+ input_mask, output_mask: Same as above, for masks.
+ trainable_weights: List of variables.
+ non_trainable_weights: List of variables.
+ weights: The concatenation of the lists trainable_weights and
+ non_trainable_weights (in this order).
+
+ # Methods
+ call(x, mask=None): Where the layer's logic lives.
+ __call__(x, mask=None): Wrapper around the layer logic (`call`).
+ If x is a Keras tensor:
+ - Connect current layer with last layer from tensor:
+ `self._add_inbound_node(last_layer)`
+ - Add layer to tensor history
+ If layer is not built:
+ - Build from inputs shape
+ get_weights()
+ set_weights(weights)
+ get_config()
+ count_params()
+ compute_output_shape(input_shape)
+ compute_mask(x, mask)
+ get_input_at(node_index)
+ get_output_at(node_index)
+ get_input_shape_at(node_index)
+ get_output_shape_at(node_index)
+ get_input_mask_at(node_index)
+ get_output_mask_at(node_index)
+
+ # Class Methods
+ from_config(config)
+
+ # Internal methods:
+ build(input_shape)
+ _add_inbound_node(layer, index=0)
+ """
+
+ def __init__(self, **kwargs):
+ # These properties should be set by the user via keyword arguments.
+ # note that 'dtype', 'input_shape' and 'batch_input_shape'
+ # are only applicable to input layers: do not pass these keywords
+ # to non-input layers.
+ allowed_kwargs = {
+ 'activity_regularizer',
+ 'input_shape',
+ 'batch_input_shape',
+ 'batch_size',
+ 'dtype',
+ 'name',
+ 'trainable',
+ 'weights',
+ }
+ # Validate optional keyword arguments.
+ for kwarg in kwargs:
+ if kwarg not in allowed_kwargs:
+ raise TypeError('Keyword argument not understood:', kwarg)
+
+ # Get layer name.
+ name = kwargs.get('name')
+
+ # Get `trainable` status.
+ trainable = kwargs.get('trainable', True)
+
+ # Get `dtype`.
+ dtype = kwargs.get('dtype')
+ if dtype is None:
+ dtype = K.floatx()
+
+ # Call super, which will set all properties common to Keras layers
+ # and core TF layers.
+ super(Layer, self).__init__(
+ name=name, dtype=dtype, trainable=trainable,
+ activity_regularizer=kwargs.get('activity_regularizer'))
+
+ # Add properties that are Keras-only for now.
+ self.supports_masking = False
+
+ # Manage input shape information if passed.
+ if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
+ # In this case we will later create an input layer
+ # to insert before the current layer
+ if 'batch_input_shape' in kwargs:
+ batch_input_shape = tuple(kwargs['batch_input_shape'])
+ elif 'input_shape' in kwargs:
+ if 'batch_size' in kwargs:
+ batch_size = kwargs['batch_size']
+ else:
+ batch_size = None
+ batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
+ self._batch_input_shape = batch_input_shape
+
+ # Manage initial weight values if passed.
+ if 'weights' in kwargs:
+ self._initial_weights = kwargs['weights']
+ else:
+ self._initial_weights = None
+
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ constraint=None):
+ """Adds a weight variable to the layer.
+
+ Arguments:
+ name: String, the name for the weight variable.
+ shape: The shape tuple of the weight.
+ dtype: The dtype of the weight.
+ initializer: An Initializer instance (callable).
+ regularizer: An optional Regularizer instance.
+ trainable: A boolean, whether the weight should
+ be trained via backprop or not (assuming
+ that the layer itself is also trainable).
+ constraint: An optional Constraint instance.
+
+ Returns:
+ The created weight variable.
+ """
+ if dtype is None:
+ dtype = K.floatx()
+ weight = self.add_variable(name, shape,
+ dtype=dtype,
+ initializer=initializers.get(initializer),
+ regularizer=regularizers.get(regularizer),
+ constraint=constraints.get(constraint),
+ trainable=trainable)
+ return weight
+
+ def call(self, inputs, **kwargs): # pylint: disable=unused-argument
+ """This is where the layer's logic lives.
+
+ Arguments:
+ inputs: Input tensor, or list/tuple of input tensors.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ A tensor or list/tuple of tensors.
+ """
+ return inputs
+
+ def __call__(self, inputs, **kwargs):
+ """Wrapper around self.call(), for handling internal references.
+
+ If a Keras tensor is passed:
+ - We call self._add_inbound_node().
+ - If necessary, we `build` the layer to match
+ the shape of the input(s).
+ - We update the _keras_history of the output tensor(s)
+ with the current layer.
+ This is done as part of _add_inbound_node().
+
+ Arguments:
+ inputs: Can be a tensor or list/tuple of tensors.
+ **kwargs: Additional keyword arguments to be passed to `call()`.
+
+ Returns:
+ Output of the layer's `call` method.
+
+ Raises:
+ ValueError: in case the layer is missing shape information
+ for its `build` call.
+ """
+ # Actually call the layer (optionally building it).
+ output = super(Layer, self).__call__(inputs, **kwargs)
+ if context.in_eager_mode():
+ return output
+
+ # Un-built subclassed network: build it
+ if hasattr(self, '_set_inputs') and not self.inputs:
+ self._set_inputs(inputs, training=kwargs.get('training'))
+
+ # Update learning phase info.
+ output_tensors = generic_utils.to_list(output)
+ uses_lp = any(
+ [getattr(x, '_uses_learning_phase', False)
+ for x in generic_utils.to_list(inputs)])
+ uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
+ for i in range(len(output_tensors)):
+ output_tensors[i]._uses_learning_phase = getattr(
+ output_tensors[i], '_uses_learning_phase', False) or uses_lp
+
+ # Optionally load weight values that were specified at layer instantiation.
+ if hasattr(self, '_initial_weights') and self._initial_weights is not None:
+ self.set_weights(self._initial_weights)
+ del self._initial_weights
+ return output
+
+ def compute_output_shape(self, input_shape):
+ """Computes the output shape of the layer.
+
+ Assumes that the layer will be built
+ to match that input shape provided.
+
+ Arguments:
+ input_shape: Shape tuple (tuple of integers)
+ or list of shape tuples (one per output tensor of the layer).
+ Shape tuples can include None for free dimensions,
+ instead of an integer.
+
+ Returns:
+ An input shape tuple.
+ """
+ logging.warning(
+ 'All custom layers should implement the '
+ '`compute_output_shape` method. This layer (' + self.name + ') '
+ 'is relying on the base `Layer.compute_output_shape` implementation, '
+ 'which will start raising a `NotImplementedError` '
+ 'as of July 1st, 2018.')
+ return input_shape
+
+ def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
+ """Computes an output mask tensor.
+
+ Arguments:
+ inputs: Tensor or list of tensors.
+ mask: Tensor or list of tensors.
+
+ Returns:
+ None or a tensor (or list of tensors,
+ one per output tensor of the layer).
+ """
+ if not self.supports_masking:
+ if mask is not None:
+ if isinstance(mask, list):
+ if any(m is not None for m in mask):
+ raise TypeError('Layer ' + self.name + ' does not support masking, '
+ 'but was passed an input_mask: ' + str(mask))
+ else:
+ raise TypeError('Layer ' + self.name + ' does not support masking, '
+ 'but was passed an input_mask: ' + str(mask))
+ # masking not explicitly supported: return None as mask
+ return None
+ # if masking is explicitly supported, by default
+ # carry over the input mask
+ return mask
+
+ def get_input_mask_at(self, node_index):
+ """Retrieves the input mask tensor(s) of a layer at a given node.
+
+ Arguments:
+ node_index: Integer, index of the node
+ from which to retrieve the attribute.
+ E.g. `node_index=0` will correspond to the
+ first time the layer was called.
+
+ Returns:
+ A mask tensor
+ (or list of tensors if the layer has multiple inputs).
+ """
+ inputs = self.get_input_at(node_index)
+ if isinstance(inputs, list):
+ return [getattr(x, '_keras_mask', None) for x in inputs]
+ else:
+ return getattr(inputs, '_keras_mask', None)
+
+ def get_output_mask_at(self, node_index):
+ """Retrieves the output mask tensor(s) of a layer at a given node.
+
+ Arguments:
+ node_index: Integer, index of the node
+ from which to retrieve the attribute.
+ E.g. `node_index=0` will correspond to the
+ first time the layer was called.
+
+ Returns:
+ A mask tensor
+ (or list of tensors if the layer has multiple outputs).
+ """
+ output = self.get_output_at(node_index)
+ if isinstance(output, list):
+ return [getattr(x, '_keras_mask', None) for x in output]
+ else:
+ return getattr(output, '_keras_mask', None)
+
+ @property
+ def input_mask(self):
+ """Retrieves the input mask tensor(s) of a layer.
+
+ Only applicable if the layer has exactly one inbound node,
+ i.e. if it is connected to one incoming layer.
+
+ Returns:
+ Input mask tensor (potentially None) or list of input
+ mask tensors.
+
+ Raises:
+ AttributeError: if the layer is connected to
+ more than one incoming layers.
+ """
+ inputs = self.input
+ if isinstance(inputs, list):
+ return [getattr(x, '_keras_mask', None) for x in inputs]
+ else:
+ return getattr(inputs, '_keras_mask', None)
+
+ @property
+ def output_mask(self):
+ """Retrieves the output mask tensor(s) of a layer.
+
+ Only applicable if the layer has exactly one inbound node,
+ i.e. if it is connected to one incoming layer.
+
+ Returns:
+ Output mask tensor (potentially None) or list of output
+ mask tensors.
+
+ Raises:
+ AttributeError: if the layer is connected to
+ more than one incoming layers.
+ """
+ output = self.output
+ if isinstance(output, list):
+ return [getattr(x, '_keras_mask', None) for x in output]
+ else:
+ return getattr(output, '_keras_mask', None)
+
+ def set_weights(self, weights):
+ """Sets the weights of the layer, from Numpy arrays.
+
+ Arguments:
+ weights: a list of Numpy arrays. The number
+ of arrays and their shape must match
+ number of the dimensions of the weights
+ of the layer (i.e. it should match the
+ output of `get_weights`).
+
+ Raises:
+ ValueError: If the provided weights list does not match the
+ layer's specifications.
+ """
+ params = self.weights
+ if len(params) != len(weights):
+ raise ValueError('You called `set_weights(weights)` on layer "' +
+ self.name + '" with a weight list of length ' +
+ str(len(weights)) + ', but the layer was expecting ' +
+ str(len(params)) + ' weights. Provided weights: ' +
+ str(weights)[:50] + '...')
+ if not params:
+ return
+ weight_value_tuples = []
+ param_values = K.batch_get_value(params)
+ for pv, p, w in zip(param_values, params, weights):
+ if pv.shape != w.shape:
+ raise ValueError('Layer weight shape ' + str(pv.shape) +
+ ' not compatible with '
+ 'provided weight shape ' + str(w.shape))
+ weight_value_tuples.append((p, w))
+ K.batch_set_value(weight_value_tuples)
+
+ def get_weights(self):
+ """Returns the current weights of the layer.
+
+ Returns:
+ Weights values as a list of numpy arrays.
+ """
+ params = self.weights
+ return K.batch_get_value(params)
+
+ def get_config(self):
+ """Returns the config of the layer.
+
+ A layer config is a Python dictionary (serializable)
+ containing the configuration of a layer.
+ The same layer can be reinstantiated later
+ (without its trained weights) from this configuration.
+
+ The config of a layer does not include connectivity
+ information, nor the layer class name. These are handled
+ by `Network` (one layer of abstraction above).
+
+ Returns:
+ Python dictionary.
+ """
+ config = {'name': self.name, 'trainable': self.trainable}
+ if hasattr(self, '_batch_input_shape'):
+ config['batch_input_shape'] = self._batch_input_shape
+ if hasattr(self, 'dtype'):
+ config['dtype'] = self.dtype
+ return config
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates a layer from its config.
+
+ This method is the reverse of `get_config`,
+ capable of instantiating the same layer from the config
+ dictionary. It does not handle layer connectivity
+ (handled by Network), nor weights (handled by `set_weights`).
+
+ Arguments:
+ config: A Python dictionary, typically the
+ output of get_config.
+
+ Returns:
+ A layer instance.
+ """
+ return cls(**config)
+
+ @tf_base_layers.Layer.activity_regularizer.setter
+ def activity_regularizer(self, activity_regularizer):
+ self._activity_regularizer = activity_regularizer
+
+
+def shape_type_conversion(fn):
+ """Decorator that handles tuple/TensorShape conversion.
+
+ Used in `compute_output_shape` and `build`.
+
+ Arguments:
+ fn: function to wrap.
+
+ Returns:
+ Wrapped function.
+ """
+
+ def wrapper(instance, input_shape):
+ if input_shape is not None:
+ if isinstance(input_shape, list):
+ input_shape = [
+ tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
+ else:
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ output_shape = fn(instance, input_shape)
+ if output_shape is not None:
+ if isinstance(output_shape, list):
+ return [tensor_shape.TensorShape(x) for x in output_shape]
+ return tensor_shape.TensorShape(output_shape)
+
+ return wrapper
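A minimal, hypothetical layer using the `shape_type_conversion` decorator defined above: the wrapper hands `compute_output_shape` plain tuples (so slicing and indexing work) and converts the returned tuple back into a `TensorShape`. The layer itself is illustrative, not part of this diff:

```python
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras.engine.base_layer import (
    Layer, shape_type_conversion)


class HalveLastDim(Layer):
  """Toy layer for illustration only: halves the last dimension."""

  @shape_type_conversion
  def compute_output_shape(self, input_shape):
    # input_shape arrives as a plain tuple such as (None, 8).
    return input_shape[:-1] + (input_shape[-1] // 2,)


layer = HalveLastDim()
print(layer.compute_output_shape(tensor_shape.TensorShape([None, 8])))
# -> TensorShape([None, 4]); the decorator handled both conversions.
```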
diff --git a/tensorflow/python/keras/_impl/keras/engine/input_layer.py b/tensorflow/python/keras/_impl/keras/engine/input_layer.py
new file mode 100644
index 0000000000..8f9ea6f7a4
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/input_layer.py
@@ -0,0 +1,230 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Input layer code (`Input` and `InputLayer`).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine import base_layer
+from tensorflow.python.layers import base as tf_base_layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class InputLayer(base_layer.Layer):
+ """Layer to be used as an entry point into a Network (a graph of layers).
+
+ It can either wrap an existing tensor (pass an `input_tensor` argument)
+ or create a placeholder tensor (pass arguments `input_shape`, and
+ optionally, `dtype`).
+
+ It is generally recommended to use the functional layer API via `Input`
+ (which creates an `InputLayer`) without directly using `InputLayer`.
+
+ Arguments:
+ input_shape: Shape tuple (not including the batch axis), or `TensorShape`
+ instance (not including the batch axis).
+ batch_size: Optional input batch size (integer or None).
+ dtype: Datatype of the input.
+ input_tensor: Optional tensor to use as layer input
+ instead of creating a placeholder.
+ sparse: Boolean, whether the placeholder created
+ is meant to be sparse.
+ name: Name of the layer (string).
+ """
+
+ def __init__(self,
+ input_shape=None,
+ batch_size=None,
+ dtype=None,
+ input_tensor=None,
+ sparse=False,
+ name=None,
+ **kwargs):
+ if 'batch_input_shape' in kwargs:
+ batch_input_shape = kwargs.pop('batch_input_shape')
+ if input_shape and batch_input_shape:
+ raise ValueError('Only provide the input_shape OR '
+ 'batch_input_shape argument to '
+ 'InputLayer, not both at the same time.')
+ batch_size = batch_input_shape[0]
+ input_shape = batch_input_shape[1:]
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
+
+ if not name:
+ prefix = 'input'
+ name = prefix + '_' + str(K.get_uid(prefix))
+
+ if not dtype:
+ if input_tensor is None:
+ dtype = K.floatx()
+ else:
+ dtype = K.dtype(input_tensor)
+ super(InputLayer, self).__init__(dtype=dtype, name=name)
+ self.built = True
+ self.sparse = sparse
+ self.batch_size = batch_size
+
+ if isinstance(input_shape, tensor_shape.TensorShape):
+ input_shape = tuple(input_shape.as_list())
+
+ if input_tensor is None:
+ if input_shape is not None:
+ batch_input_shape = (batch_size,) + tuple(input_shape)
+ else:
+ batch_input_shape = None
+
+ if context.in_eager_mode():
+ # In eager mode, create a temporary placeholder to call the layer on.
+ input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access
+ shape=batch_input_shape,
+ dtype=dtype,
+ name=self.name)
+ else:
+ # In graph mode, create a graph placeholder to call the layer on.
+ if sparse:
+ input_tensor = array_ops.sparse_placeholder(
+ shape=batch_input_shape,
+ dtype=dtype,
+ name=self.name)
+ else:
+ input_tensor = array_ops.placeholder(
+ shape=batch_input_shape,
+ dtype=dtype,
+ name=self.name)
+
+ # For compatibility with Keras API.
+ self.is_placeholder = True
+ self._batch_input_shape = batch_input_shape
+ else:
+ # For compatibility with Keras API.
+ self.is_placeholder = False
+ self._batch_input_shape = tuple(input_tensor.get_shape().as_list())
+
+ # Create an input node to add to self.outbound_node
+ # and set output_tensors' _keras_history.
+ input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access
+ tf_base_layers.Node(
+ self,
+ inbound_layers=[],
+ node_indices=[],
+ tensor_indices=[],
+ input_tensors=[input_tensor],
+ output_tensors=[input_tensor])
+
+ def get_config(self):
+ config = {
+ 'batch_input_shape': self._batch_input_shape,
+ 'dtype': self.dtype,
+ 'sparse': self.sparse,
+ 'name': self.name
+ }
+ return config
+
+
+@tf_export('keras.layers.Input', 'keras.Input')
+def Input( # pylint: disable=invalid-name
+ shape=None,
+ batch_size=None,
+ name=None,
+ dtype=None,
+ sparse=False,
+ tensor=None,
+ **kwargs):
+ """`Input()` is used to instantiate a Keras tensor.
+
+ A Keras tensor is a tensor object from the underlying backend
+ (Theano or TensorFlow), which we augment with certain
+ attributes that allow us to build a Keras model
+ just by knowing the inputs and outputs of the model.
+
+ For instance, if a, b and c are Keras tensors,
+ it becomes possible to do:
+ `model = Model(input=[a, b], output=c)`
+
+ The added Keras attribute is:
+ `_keras_history`: Last layer applied to the tensor.
+ The entire layer graph is retrievable from that layer,
+ recursively.
+
+ Arguments:
+ shape: A shape tuple (integers), not including the batch size.
+ For instance, `shape=(32,)` indicates that the expected input
+ will be batches of 32-dimensional vectors.
+ batch_size: optional static batch size (integer).
+ name: An optional name string for the layer.
+ Should be unique in a model (do not reuse the same name twice).
+ It will be autogenerated if it isn't provided.
+ dtype: The data type expected by the input, as a string
+ (`float32`, `float64`, `int32`...)
+ sparse: A boolean specifying whether the placeholder
+ to be created is sparse.
+ tensor: Optional existing tensor to wrap into the `Input` layer.
+ If set, the layer will not create a placeholder tensor.
+ **kwargs: deprecated arguments support.
+
+ Returns:
+ A tensor.
+
+ Example:
+
+ ```python
+ # this is a logistic regression in Keras
+ x = Input(shape=(32,))
+ y = Dense(16, activation='softmax')(x)
+ model = Model(x, y)
+ ```
+
+ Raises:
+ ValueError: in case of invalid arguments.
+ """
+ if 'batch_shape' in kwargs:
+ batch_shape = kwargs.pop('batch_shape')
+ if shape and batch_shape:
+ raise ValueError('Only provide the shape OR '
+ 'batch_shape argument to '
+ 'Input, not both at the same time.')
+ batch_size = batch_shape[0]
+ shape = batch_shape[1:]
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
+
+ if dtype is None:
+ dtype = K.floatx()
+ if not shape and tensor is None:
+ raise ValueError('Please provide to Input either a `shape`'
+ ' or a `tensor` argument. Note that '
+ '`shape` does not include the batch '
+ 'dimension.')
+ input_layer = InputLayer(
+ input_shape=shape,
+ batch_size=batch_size,
+ name=name,
+ dtype=dtype,
+ sparse=sparse,
+ input_tensor=tensor)
+ # Return tensor including `_keras_history`.
+ # Note that in this case train_output and test_output are the same pointer.
+ outputs = input_layer._inbound_nodes[0].output_tensors
+ if len(outputs) == 1:
+ return outputs[0]
+ else:
+ return outputs
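A hedged usage sketch for the `Input` function above: one placeholder-backed input built from `shape`, and one wrapping an existing tensor via the `tensor` argument (in which case no placeholder is created):

```python
from tensorflow.python.keras._impl.keras.engine.input_layer import Input
from tensorflow.python.ops import array_ops

x = Input(shape=(32,))   # float32 placeholder; batch dimension left free
t = array_ops.ones((2, 32))
y = Input(tensor=t)      # wraps `t` and attaches _keras_history to it
```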
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index f562a19cf5..453cc8f8b7 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
-"""Base layer code and base model (Network) code.
+"""A `Network` is way to compose layers: the topological form of a `Model`.
"""
from __future__ import absolute_import
from __future__ import division
@@ -30,19 +30,16 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
-from tensorflow.python.keras._impl.keras import constraints
-from tensorflow.python.keras._impl.keras import initializers
-from tensorflow.python.keras._impl.keras import regularizers
-from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.engine import base_layer
+from tensorflow.python.keras._impl.keras.engine import saving
+from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.layers import utils as tf_layers_util
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
-from tensorflow.python.util.tf_export import tf_export
# pylint: disable=g-import-not-at-top
@@ -57,684 +54,12 @@ except ImportError:
yaml = None
# pylint: enable=g-import-not-at-top
-# pylint: disable=invalid-name
-InputSpec = tf_base_layers.InputSpec
-Node = tf_base_layers.Node
-TFBaseLayer = tf_base_layers.Layer
-# pylint: enable=invalid-name
-
-
-@tf_export('keras.layers.Layer')
-class Layer(tf_base_layers.Layer):
- """Abstract base layer class.
-
- # Properties
- name: String, must be unique within a model.
- input_spec: List of InputSpec class instances
- each entry describes one required input:
- - ndim
- - dtype
- A layer with `n` input tensors must have
- an `input_spec` of length `n`.
- trainable: Boolean, whether the layer weights
- will be updated during training.
- uses_learning_phase: Whether any operation
- of the layer uses `K.in_training_phase()`
- or `K.in_test_phase()`.
- input_shape: Shape tuple. Provided for convenience,
- but note that there may be cases in which this
- attribute is ill-defined (e.g. a shared layer
- with multiple input shapes), in which case
- requesting `input_shape` will raise an Exception.
- Prefer using `layer.get_input_shape_for(input_shape)`,
- or `layer.get_input_shape_at(node_index)`.
- output_shape: Shape tuple. See above.
- inbound_nodes: List of nodes.
- outbound_nodes: List of nodes.
- input, output: Input/output tensor(s). Note that if the layer is used
- more than once (shared layer), this is ill-defined
- and will raise an exception. In such cases, use
- `layer.get_input_at(node_index)`.
- input_mask, output_mask: Same as above, for masks.
- trainable_weights: List of variables.
- non_trainable_weights: List of variables.
- weights: The concatenation of the lists trainable_weights and
- non_trainable_weights (in this order).
-
- # Methods
- call(x, mask=None): Where the layer's logic lives.
- __call__(x, mask=None): Wrapper around the layer logic (`call`).
- If x is a Keras tensor:
- - Connect current layer with last layer from tensor:
- `self._add_inbound_node(last_layer)`
- - Add layer to tensor history
- If layer is not built:
- - Build from inputs shape
- get_weights()
- set_weights(weights)
- get_config()
- count_params()
- compute_output_shape(input_shape)
- compute_mask(x, mask)
- get_input_at(node_index)
- get_output_at(node_index)
- get_input_shape_at(node_index)
- get_output_shape_at(node_index)
- get_input_mask_at(node_index)
- get_output_mask_at(node_index)
-
- # Class Methods
- from_config(config)
-
- # Internal methods:
- build(input_shape)
- _add_inbound_node(layer, index=0)
- """
-
- def __init__(self, **kwargs):
- # These properties should be set by the user via keyword arguments.
- # note that 'dtype', 'input_shape' and 'batch_input_shape'
- # are only applicable to input layers: do not pass these keywords
- # to non-input layers.
- allowed_kwargs = {
- 'activity_regularizer',
- 'input_shape',
- 'batch_input_shape',
- 'batch_size',
- 'dtype',
- 'name',
- 'trainable',
- 'weights',
- }
- # Validate optional keyword arguments.
- for kwarg in kwargs:
- if kwarg not in allowed_kwargs:
- raise TypeError('Keyword argument not understood:', kwarg)
-
- # Get layer name.
- name = kwargs.get('name')
-
- # Get `trainable` status.
- trainable = kwargs.get('trainable', True)
-
- # Get `dtype`.
- dtype = kwargs.get('dtype')
- if dtype is None:
- dtype = K.floatx()
-
- # Call super, which will set all properties common to Keras layers
- # and core TF layers.
- super(Layer, self).__init__(
- name=name, dtype=dtype, trainable=trainable,
- activity_regularizer=kwargs.get('activity_regularizer'))
-
- # Add properties that are Keras-only for now.
- self.supports_masking = False
-
- # Manage input shape information if passed.
- if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
- # In this case we will later create an input layer
- # to insert before the current layer
- if 'batch_input_shape' in kwargs:
- batch_input_shape = tuple(kwargs['batch_input_shape'])
- elif 'input_shape' in kwargs:
- if 'batch_size' in kwargs:
- batch_size = kwargs['batch_size']
- else:
- batch_size = None
- batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
- self._batch_input_shape = batch_input_shape
-
- # Manage initial weight values if passed.
- if 'weights' in kwargs:
- self._initial_weights = kwargs['weights']
- else:
- self._initial_weights = None
-
- def add_weight(self,
- name,
- shape,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=True,
- constraint=None):
- """Adds a weight variable to the layer.
-
- Arguments:
- name: String, the name for the weight variable.
- shape: The shape tuple of the weight.
- dtype: The dtype of the weight.
- initializer: An Initializer instance (callable).
- regularizer: An optional Regularizer instance.
- trainable: A boolean, whether the weight should
- be trained via backprop or not (assuming
- that the layer itself is also trainable).
- constraint: An optional Constraint instance.
-
- Returns:
- The created weight variable.
- """
- if dtype is None:
- dtype = K.floatx()
- weight = self.add_variable(name, shape,
- dtype=dtype,
- initializer=initializers.get(initializer),
- regularizer=regularizers.get(regularizer),
- constraint=constraints.get(constraint),
- trainable=trainable)
- return weight
-
- def call(self, inputs, **kwargs): # pylint: disable=unused-argument
- """This is where the layer's logic lives.
-
- Arguments:
- inputs: Input tensor, or list/tuple of input tensors.
- **kwargs: Additional keyword arguments.
-
- Returns:
- A tensor or list/tuple of tensors.
- """
- return inputs
-
- def __call__(self, inputs, **kwargs):
- """Wrapper around self.call(), for handling internal references.
-
- If a Keras tensor is passed:
- - We call self._add_inbound_node().
- - If necessary, we `build` the layer to match
- the shape of the input(s).
- - We update the _keras_history of the output tensor(s)
- with the current layer.
- This is done as part of _add_inbound_node().
-
- Arguments:
- inputs: Can be a tensor or list/tuple of tensors.
- **kwargs: Additional keyword arguments to be passed to `call()`.
-
- Returns:
- Output of the layer's `call` method.
-
- Raises:
- ValueError: in case the layer is missing shape information
- for its `build` call.
- """
- # Actually call the layer (optionally building it).
- output = super(Layer, self).__call__(inputs, **kwargs)
- if context.in_eager_mode():
- return output
-
- # Un-built subclassed network: build it
- if isinstance(self, Network) and not self.inputs:
- self._set_inputs(inputs, training=kwargs.get('training'))
-
- # Update learning phase info.
- output_tensors = to_list(output)
- uses_lp = any(
- [getattr(x, '_uses_learning_phase', False) for x in to_list(inputs)])
- uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
- for i in range(len(output_tensors)):
- output_tensors[i]._uses_learning_phase = getattr(
- output_tensors[i], '_uses_learning_phase', False) or uses_lp
-
- # Optionally load weight values that were specified at layer instantiation.
- if hasattr(self, '_initial_weights') and self._initial_weights is not None:
- self.set_weights(self._initial_weights)
- del self._initial_weights
- return output
-
- def compute_output_shape(self, input_shape):
- """Computes the output shape of the layer.
-
- Assumes that the layer will be built
- to match that input shape provided.
-
- Arguments:
- input_shape: Shape tuple (tuple of integers)
- or list of shape tuples (one per output tensor of the layer).
- Shape tuples can include None for free dimensions,
- instead of an integer.
-
- Returns:
- An input shape tuple.
- """
- logging.warning(
- 'All custom layers should implement the '
- '`compute_output_shape` method. This layer (' + self.name + ') '
- 'is relying on the base `Layer.compute_output_shape` implementation, '
- 'which will start raising a `NotImplementedError` '
- 'as of July 1st, 2018.')
- return input_shape
-
- def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
- """Computes an output mask tensor.
-
- Arguments:
- inputs: Tensor or list of tensors.
- mask: Tensor or list of tensors.
-
- Returns:
- None or a tensor (or list of tensors,
- one per output tensor of the layer).
- """
- if not self.supports_masking:
- if mask is not None:
- if isinstance(mask, list):
- if any(m is not None for m in mask):
- raise TypeError('Layer ' + self.name + ' does not support masking, '
- 'but was passed an input_mask: ' + str(mask))
- else:
- raise TypeError('Layer ' + self.name + ' does not support masking, '
- 'but was passed an input_mask: ' + str(mask))
- # masking not explicitly supported: return None as mask
- return None
- # if masking is explicitly supported, by default
- # carry over the input mask
- return mask
-
- def get_input_mask_at(self, node_index):
- """Retrieves the input mask tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A mask tensor
- (or list of tensors if the layer has multiple inputs).
- """
- inputs = self.get_input_at(node_index)
- if isinstance(inputs, list):
- return [getattr(x, '_keras_mask', None) for x in inputs]
- else:
- return getattr(inputs, '_keras_mask', None)
-
- def get_output_mask_at(self, node_index):
- """Retrieves the output mask tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A mask tensor
- (or list of tensors if the layer has multiple outputs).
- """
- output = self.get_output_at(node_index)
- if isinstance(output, list):
- return [getattr(x, '_keras_mask', None) for x in output]
- else:
- return getattr(output, '_keras_mask', None)
-
- @property
- def input_mask(self):
- """Retrieves the input mask tensor(s) of a layer.
-
- Only applicable if the layer has exactly one inbound node,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Input mask tensor (potentially None) or list of input
- mask tensors.
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- inputs = self.input
- if isinstance(inputs, list):
- return [getattr(x, '_keras_mask', None) for x in inputs]
- else:
- return getattr(inputs, '_keras_mask', None)
-
- @property
- def output_mask(self):
- """Retrieves the output mask tensor(s) of a layer.
-
- Only applicable if the layer has exactly one inbound node,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Output mask tensor (potentially None) or list of output
- mask tensors.
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- output = self.output
- if isinstance(output, list):
- return [getattr(x, '_keras_mask', None) for x in output]
- else:
- return getattr(output, '_keras_mask', None)
-
- def set_weights(self, weights):
- """Sets the weights of the layer, from Numpy arrays.
-
- Arguments:
- weights: a list of Numpy arrays. The number
- of arrays and their shape must match
- number of the dimensions of the weights
- of the layer (i.e. it should match the
- output of `get_weights`).
-
- Raises:
- ValueError: If the provided weights list does not match the
- layer's specifications.
- """
- params = self.weights
- if len(params) != len(weights):
- raise ValueError('You called `set_weights(weights)` on layer "' +
- self.name + '" with a weight list of length ' +
- str(len(weights)) + ', but the layer was expecting ' +
- str(len(params)) + ' weights. Provided weights: ' +
- str(weights)[:50] + '...')
- if not params:
- return
- weight_value_tuples = []
- param_values = K.batch_get_value(params)
- for pv, p, w in zip(param_values, params, weights):
- if pv.shape != w.shape:
- raise ValueError('Layer weight shape ' + str(pv.shape) +
- ' not compatible with '
- 'provided weight shape ' + str(w.shape))
- weight_value_tuples.append((p, w))
- K.batch_set_value(weight_value_tuples)
-
- def get_weights(self):
- """Returns the current weights of the layer.
-
- Returns:
- Weights values as a list of numpy arrays.
- """
- params = self.weights
- return K.batch_get_value(params)
-
- def get_config(self):
- """Returns the config of the layer.
-
- A layer config is a Python dictionary (serializable)
- containing the configuration of a layer.
- The same layer can be reinstantiated later
- (without its trained weights) from this configuration.
-
- The config of a layer does not include connectivity
- information, nor the layer class name. These are handled
- by `Network` (one layer of abstraction above).
-
- Returns:
- Python dictionary.
- """
- config = {'name': self.name, 'trainable': self.trainable}
- if hasattr(self, '_batch_input_shape'):
- config['batch_input_shape'] = self._batch_input_shape
- if hasattr(self, 'dtype'):
- config['dtype'] = self.dtype
- return config
-
- @classmethod
- def from_config(cls, config):
- """Creates a layer from its config.
-
- This method is the reverse of `get_config`,
- capable of instantiating the same layer from the config
- dictionary. It does not handle layer connectivity
- (handled by Network), nor weights (handled by `set_weights`).
-
- Arguments:
- config: A Python dictionary, typically the
- output of get_config.
-
- Returns:
- A layer instance.
- """
- return cls(**config)
-
- @tf_base_layers.Layer.activity_regularizer.setter
- def activity_regularizer(self, activity_regularizer):
- self._activity_regularizer = activity_regularizer
+class Network(base_layer.Layer):
+ """A `Network` is a composition of layers.
-class InputLayer(Layer):
- """Layer to be used as an entry point into a Network (a graph of layers).
-
- It can either wrap an existing tensor (pass an `input_tensor` argument)
- or create its a placeholder tensor (pass arguments `input_shape`, and
- optionally, `dtype`).
-
- It is generally recommended to use the functional layer API via `Input`
- (which creates an `InputLayer`) without directly using `InputLayer`.
-
- Arguments:
- input_shape: Shape tuple (not including the batch axis), or `TensorShape`
- instance (not including the batch axis).
- batch_size: Optional input batch size (integer or None).
- dtype: Datatype of the input.
- input_tensor: Optional tensor to use as layer input
- instead of creating a placeholder.
- sparse: Boolean, whether the placeholder created
- is meant to be sparse.
- name: Name of the layer (string).
- """
-
- def __init__(self,
- input_shape=None,
- batch_size=None,
- dtype=None,
- input_tensor=None,
- sparse=False,
- name=None,
- **kwargs):
- if 'batch_input_shape' in kwargs:
- batch_input_shape = kwargs.pop('batch_input_shape')
- if input_shape and batch_input_shape:
- raise ValueError('Only provide the input_shape OR '
- 'batch_input_shape argument to '
- 'InputLayer, not both at the same time.')
- batch_size = batch_input_shape[0]
- input_shape = batch_input_shape[1:]
- if kwargs:
- raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
-
- if not name:
- prefix = 'input'
- name = prefix + '_' + str(K.get_uid(prefix))
-
- if not dtype:
- if input_tensor is None:
- dtype = K.floatx()
- else:
- dtype = K.dtype(input_tensor)
- super(InputLayer, self).__init__(dtype=dtype, name=name)
- self.built = True
- self.sparse = sparse
- self.batch_size = batch_size
-
- if isinstance(input_shape, tensor_shape.TensorShape):
- input_shape = tuple(input_shape.as_list())
-
- if input_tensor is None:
- if input_shape is not None:
- batch_input_shape = (batch_size,) + tuple(input_shape)
- else:
- batch_input_shape = None
-
- if context.in_eager_mode():
- # In eager mode, create a temporary placeholder to call the layer on.
- input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access
- shape=batch_input_shape,
- dtype=dtype,
- name=self.name)
- else:
- # In graph mode, create a graph placeholder to call the layer on.
- if sparse:
- input_tensor = array_ops.sparse_placeholder(
- shape=batch_input_shape,
- dtype=dtype,
- name=self.name)
- else:
- input_tensor = array_ops.placeholder(
- shape=batch_input_shape,
- dtype=dtype,
- name=self.name)
-
- # For compatibility with Keras API.
- self.is_placeholder = True
- self._batch_input_shape = batch_input_shape
- else:
- # For compatibility with Keras API.
- self.is_placeholder = False
- self._batch_input_shape = tuple(input_tensor.get_shape().as_list())
-
- # Create an input node to add to self.outbound_node
- # and set output_tensors' _keras_history.
- input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access
- tf_base_layers.Node(
- self,
- inbound_layers=[],
- node_indices=[],
- tensor_indices=[],
- input_tensors=[input_tensor],
- output_tensors=[input_tensor])
-
- def get_config(self):
- config = {
- 'batch_input_shape': self._batch_input_shape,
- 'dtype': self.dtype,
- 'sparse': self.sparse,
- 'name': self.name
- }
- return config
-
-
-@tf_export('keras.layers.Input', 'keras.Input')
-def Input( # pylint: disable=invalid-name
- shape=None,
- batch_size=None,
- name=None,
- dtype=None,
- sparse=False,
- tensor=None,
- **kwargs):
- """`Input()` is used to instantiate a Keras tensor.
-
- A Keras tensor is a tensor object from the underlying backend
- (Theano or TensorFlow), which we augment with certain
- attributes that allow us to build a Keras model
- just by knowing the inputs and outputs of the model.
-
- For instance, if a, b and c are Keras tensors,
- it becomes possible to do:
- `model = Model(input=[a, b], output=c)`
-
- The added Keras attribute is:
- `_keras_history`: Last layer applied to the tensor.
- The entire layer graph is retrievable from that layer,
- recursively.
-
- Arguments:
- shape: A shape tuple (integers), not including the batch size.
- For instance, `shape=(32,)` indicates that the expected input
- will be batches of 32-dimensional vectors.
- batch_size: optional static batch size (integer).
- name: An optional name string for the layer.
- Should be unique in a model (do not reuse the same name twice).
- It will be autogenerated if it isn't provided.
- dtype: The data type expected by the input, as a string
- (`float32`, `float64`, `int32`...)
- sparse: A boolean specifying whether the placeholder
- to be created is sparse.
- tensor: Optional existing tensor to wrap into the `Input` layer.
- If set, the layer will not create a placeholder tensor.
- **kwargs: deprecated arguments support.
-
- Returns:
- A tensor.
-
- Example:
-
- ```python
- # this is a logistic regression in Keras
- x = Input(shape=(32,))
- y = Dense(16, activation='softmax')(x)
- model = Model(x, y)
- ```
-
- Raises:
- ValueError: in case of invalid arguments.
- """
- if 'batch_shape' in kwargs:
- batch_shape = kwargs.pop('batch_shape')
- if shape and batch_shape:
- raise ValueError('Only provide the shape OR '
- 'batch_shape argument to '
- 'Input, not both at the same time.')
- batch_size = batch_shape[0]
- shape = batch_shape[1:]
- if kwargs:
- raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
-
- if dtype is None:
- dtype = K.floatx()
- if not shape and tensor is None:
- raise ValueError('Please provide to Input either a `shape`'
- ' or a `tensor` argument. Note that '
- '`shape` does not include the batch '
- 'dimension.')
- input_layer = InputLayer(
- input_shape=shape,
- batch_size=batch_size,
- name=name,
- dtype=dtype,
- sparse=sparse,
- input_tensor=tensor)
- # Return tensor including `_keras_history`.
- # Note that in this case train_output and test_output are the same pointer.
- outputs = input_layer._inbound_nodes[0].output_tensors
- if len(outputs) == 1:
- return outputs[0]
- else:
- return outputs
-
-
-class Network(Layer):
- """A Network is a directed acyclic graph of layers.
-
- It is the topological form of a "model". A Model
- is simply a Network with added training routines.
-
- # Properties
- name
- inputs
- outputs
- input_layers
- output_layers
- input_spec (list of class instances)
- each entry describes one required input:
- - ndim
- - dtype
- trainable (boolean)
- input_shape
- output_shape
- inbound_nodes: list of nodes
- outbound_nodes: list of nodes
- trainable_weights (list of variables)
- non_trainable_weights (list of variables)
-
- # Methods
- summary
- get_layer
- get_weights
- set_weights
- get_config
- compute_output_shape
-
- # Class Methods
- from_config
+ It is the topological form of a "model". A `Model`
+ is simply a `Network` with added training routines.
"""
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
@@ -1053,11 +378,11 @@ class Network(Layer):
if not self._is_graph_network:
return None
- inputs = to_list(inputs)
+ inputs = generic_utils.to_list(inputs)
if mask is None:
masks = [None for _ in range(len(inputs))]
else:
- masks = to_list(mask)
+ masks = generic_utils.to_list(mask)
cache_key = (tf_layers_util.object_list_uid(inputs)
+ '_' + tf_layers_util.object_list_uid(masks))
if cache_key in self._output_mask_cache:
@@ -1818,7 +1143,7 @@ class Network(Layer):
if not proceed:
return
with h5py.File(filepath, 'w') as f:
- save_weights_to_hdf5_group(f, self.layers)
+ saving.save_weights_to_hdf5_group(f, self.layers)
def load_weights(self, filepath, by_name=False):
"""Loads all layer weights from a HDF5 save file.
@@ -1849,9 +1174,9 @@ class Network(Layer):
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
if by_name:
- load_weights_from_hdf5_group_by_name(f, self.layers)
+ saving.load_weights_from_hdf5_group_by_name(f, self.layers)
else:
- load_weights_from_hdf5_group(f, self.layers)
+ saving.load_weights_from_hdf5_group(f, self.layers)
def _updated_config(self):
"""Util hared between different serialization methods.
@@ -1989,364 +1314,6 @@ def get_source_inputs(tensor, layer=None, node_index=None):
return source_tensors
-def to_list(x):
- """Normalizes a list/tensor into a list.
-
- If a tensor is passed, we return
- a list of size 1 containing the tensor.
-
- Arguments:
- x: target object to be normalized.
-
- Returns:
- A list.
- """
- if isinstance(x, list):
- return x
- return [x]
-
-
-def save_weights_to_hdf5_group(f, layers):
- from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
-
- f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
- f.attrs['backend'] = K.backend().encode('utf8')
- f.attrs['keras_version'] = str(keras_version).encode('utf8')
-
- for layer in layers:
- g = f.create_group(layer.name)
- symbolic_weights = layer.weights
- weight_values = K.batch_get_value(symbolic_weights)
- weight_names = []
- for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
- if hasattr(w, 'name') and w.name:
- name = str(w.name)
- else:
- name = 'param_' + str(i)
- weight_names.append(name.encode('utf8'))
- g.attrs['weight_names'] = weight_names
- for name, val in zip(weight_names, weight_values):
- param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
- if not val.shape:
- # scalar
- param_dset[()] = val
- else:
- param_dset[:] = val
-
-
-def preprocess_weights_for_loading(layer,
- weights,
- original_keras_version=None,
- original_backend=None):
- """Converts layers weights from Keras 1 format to Keras 2.
-
- Arguments:
- layer: Layer instance.
- weights: List of weights values (Numpy arrays).
- original_keras_version: Keras version for the weights, as a string.
- original_backend: Keras backend the weights were trained with,
- as a string.
-
- Returns:
- A list of weights values (Numpy arrays).
- """
- if layer.__class__.__name__ == 'Bidirectional':
- num_weights_per_layer = len(weights) // 2
- forward_weights = preprocess_weights_for_loading(
- layer.forward_layer, weights[:num_weights_per_layer],
- original_keras_version, original_backend)
- backward_weights = preprocess_weights_for_loading(
- layer.backward_layer, weights[num_weights_per_layer:],
- original_keras_version, original_backend)
- weights = forward_weights + backward_weights
-
- if original_keras_version == '1':
- if layer.__class__.__name__ == 'TimeDistributed':
- weights = preprocess_weights_for_loading(
- layer.layer, weights, original_keras_version, original_backend)
-
- if layer.__class__.__name__ == 'Conv1D':
- shape = weights[0].shape
- # Handle Keras 1.1 format
- if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
- # Legacy shape:
- # (filters, input_dim, filter_length, 1)
- assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
- 1)
- weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
- weights[0] = weights[0][:, 0, :, :]
-
- if layer.__class__.__name__ == 'Conv2D':
- if layer.data_format == 'channels_first':
- # old: (filters, stack_size, kernel_rows, kernel_cols)
- # new: (kernel_rows, kernel_cols, stack_size, filters)
- weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
-
- if layer.__class__.__name__ == 'Conv2DTranspose':
- if layer.data_format == 'channels_last':
- # old: (kernel_rows, kernel_cols, stack_size, filters)
- # new: (kernel_rows, kernel_cols, filters, stack_size)
- weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
- if layer.data_format == 'channels_first':
- # old: (filters, stack_size, kernel_rows, kernel_cols)
- # new: (kernel_rows, kernel_cols, filters, stack_size)
- weights[0] = np.transpose(weights[0], (2, 3, 0, 1))
-
- if layer.__class__.__name__ == 'Conv3D':
- if layer.data_format == 'channels_first':
- # old: (filters, stack_size, ...)
- # new: (..., stack_size, filters)
- weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))
-
- if layer.__class__.__name__ == 'GRU':
- if len(weights) == 9:
- kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
- recurrent_kernel = np.concatenate(
- [weights[1], weights[4], weights[7]], axis=-1)
- bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
- weights = [kernel, recurrent_kernel, bias]
-
- if layer.__class__.__name__ == 'LSTM':
- if len(weights) == 12:
- # old: i, c, f, o
- # new: i, f, c, o
- kernel = np.concatenate(
- [weights[0], weights[6], weights[3], weights[9]], axis=-1)
- recurrent_kernel = np.concatenate(
- [weights[1], weights[7], weights[4], weights[10]], axis=-1)
- bias = np.concatenate(
- [weights[2], weights[8], weights[5], weights[11]], axis=-1)
- weights = [kernel, recurrent_kernel, bias]
-
- if layer.__class__.__name__ == 'ConvLSTM2D':
- if len(weights) == 12:
- kernel = np.concatenate(
- [weights[0], weights[6], weights[3], weights[9]], axis=-1)
- recurrent_kernel = np.concatenate(
- [weights[1], weights[7], weights[4], weights[10]], axis=-1)
- bias = np.concatenate(
- [weights[2], weights[8], weights[5], weights[11]], axis=-1)
- if layer.data_format == 'channels_first':
- # old: (filters, stack_size, kernel_rows, kernel_cols)
- # new: (kernel_rows, kernel_cols, stack_size, filters)
- kernel = np.transpose(kernel, (2, 3, 1, 0))
- recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
- weights = [kernel, recurrent_kernel, bias]
-
- if layer.__class__.__name__ in ['Model', 'Sequential']:
- new_weights = []
- # trainable weights
- for sublayer in layer.layers:
- num_weights = len(sublayer.trainable_weights)
- if num_weights > 0:
- new_weights.extend(
- preprocess_weights_for_loading(
- layer=sublayer,
- weights=weights[:num_weights],
- original_keras_version=original_keras_version,
- original_backend=original_backend))
- weights = weights[num_weights:]
-
- # non-trainable weights
- for sublayer in layer.layers:
- num_weights = len([
- l for l in sublayer.weights if l not in sublayer.trainable_weights
- ])
- if num_weights > 0:
- new_weights.extend(
- preprocess_weights_for_loading(
- layer=sublayer,
- weights=weights[:num_weights],
- original_keras_version=original_keras_version,
- original_backend=original_backend))
- weights = weights[num_weights:]
- weights = new_weights
-
- conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
- if layer.__class__.__name__ in conv_layers:
- if original_backend == 'theano':
- weights[0] = conv_utils.convert_kernel(weights[0])
- if layer.__class__.__name__ == 'ConvLSTM2D':
- weights[1] = conv_utils.convert_kernel(weights[1])
- if K.int_shape(layer.weights[0]) != weights[0].shape:
- weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
- if layer.__class__.__name__ == 'ConvLSTM2D':
- weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
-
- # Convert the weights of CuDNNLSTM so that they can be loaded into LSTM.
- if layer.__class__.__name__ == 'LSTM' and len(weights) == 3:
- # Determine from the number of bias weights whether a CuDNNLSTM layer is
- # being loaded: CuDNNLSTM has (units * 8) bias weights, while LSTM has
- # (units * 4). If there is no bias weight in the file, skip the conversion.
- units = weights[1].shape[0]
- bias = weights[2]
- if len(bias) == units * 8:
- # reshape the kernels
- kernels = np.split(weights[0], 4, axis=1)
- kernels = [
- kernel.reshape(-1).reshape(kernel.shape, order='F')
- for kernel in kernels
- ]
- weights[0] = np.concatenate(kernels, axis=1)
-
- # transpose the recurrent kernels
- recurrent_kernels = np.split(weights[1], 4, axis=1)
- recurrent_kernels = [kernel.T for kernel in recurrent_kernels]
- weights[1] = np.concatenate(recurrent_kernels, axis=1)
-
- # split the bias in half and sum the halves
- weights[2] = bias[:units * 4] + bias[units * 4:]
-
- return weights
-
-
-def load_weights_from_hdf5_group(f, layers):
- """Implements topological (order-based) weight loading.
-
- Arguments:
- f: A pointer to a HDF5 group.
- layers: a list of target layers.
-
- Raises:
- ValueError: in case of mismatch between provided layers
- and weights file.
- """
- if 'keras_version' in f.attrs:
- original_keras_version = f.attrs['keras_version'].decode('utf8')
- else:
- original_keras_version = '1'
- if 'backend' in f.attrs:
- original_backend = f.attrs['backend'].decode('utf8')
- else:
- original_backend = None
-
- filtered_layers = []
- for layer in layers:
- weights = layer.weights
- if weights:
- filtered_layers.append(layer)
-
- layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
- filtered_layer_names = []
- for name in layer_names:
- g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
- if weight_names:
- filtered_layer_names.append(name)
- layer_names = filtered_layer_names
- if len(layer_names) != len(filtered_layers):
- raise ValueError('You are trying to load a weight file '
- 'containing ' + str(len(layer_names)) +
- ' layers into a model with ' + str(len(filtered_layers)) +
- ' layers.')
-
- # We batch weight value assignments in a single backend call
- # which provides a speedup in TensorFlow.
- weight_value_tuples = []
- for k, name in enumerate(layer_names):
- g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
- weight_values = [g[weight_name] for weight_name in weight_names]
- layer = filtered_layers[k]
- symbolic_weights = layer.weights
- weight_values = preprocess_weights_for_loading(
- layer, weight_values, original_keras_version, original_backend)
- if len(weight_values) != len(symbolic_weights):
- raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
- '" in the current model) was found to '
- 'correspond to layer ' + name + ' in the save file. '
- 'However the new layer ' + layer.name + ' expects ' +
- str(len(symbolic_weights)) +
- ' weights, but the saved weights have ' +
- str(len(weight_values)) + ' elements.')
- weight_value_tuples += zip(symbolic_weights, weight_values)
- K.batch_set_value(weight_value_tuples)
-
-
-def load_weights_from_hdf5_group_by_name(f, layers):
- """Implements name-based weight loading.
-
- (instead of topological weight loading).
-
- Layers that have no matching name are skipped.
-
- Arguments:
- f: A pointer to a HDF5 group.
- layers: a list of target layers.
-
- Raises:
- ValueError: in case of mismatch between provided layers
- and weights file.
- """
- if 'keras_version' in f.attrs:
- original_keras_version = f.attrs['keras_version'].decode('utf8')
- else:
- original_keras_version = '1'
- if 'backend' in f.attrs:
- original_backend = f.attrs['backend'].decode('utf8')
- else:
- original_backend = None
-
- # New file format.
- layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
-
- # Reverse index of layer name to list of layers with that name.
- index = {}
- for layer in layers:
- if layer.name:
- index.setdefault(layer.name, []).append(layer)
-
- # We batch weight value assignments in a single backend call
- # which provides a speedup in TensorFlow.
- weight_value_tuples = []
- for k, name in enumerate(layer_names):
- g = f[name]
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
- weight_values = [g[weight_name] for weight_name in weight_names]
-
- for layer in index.get(name, []):
- symbolic_weights = layer.weights
- weight_values = preprocess_weights_for_loading(
- layer, weight_values, original_keras_version, original_backend)
- if len(weight_values) != len(symbolic_weights):
- raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
- '") expects ' + str(len(symbolic_weights)) +
- ' weight(s), but the saved weights' + ' have ' +
- str(len(weight_values)) + ' element(s).')
- # Set values.
- for i in range(len(weight_values)):
- weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
- K.batch_set_value(weight_value_tuples)
-
-
-def shape_type_conversion(fn):
- """Decorator that handles tuple/TensorShape conversion.
-
- Used in `compute_output_shape` and `build`.
-
- Arguments:
- fn: function to wrap.
-
- Returns:
- Wrapped function.
- """
-
- def wrapper(instance, input_shape):
- if input_shape is not None:
- if isinstance(input_shape, list):
- input_shape = [
- tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
- else:
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
- output_shape = fn(instance, input_shape)
- if output_shape is not None:
- if isinstance(output_shape, list):
- return [tensor_shape.TensorShape(x) for x in output_shape]
- return tensor_shape.TensorShape(output_shape)
-
- return wrapper
-
-
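A sketch of how the removed decorator was applied within this module; `Doubler` is a hypothetical layer used only for illustration:

```python
class Doubler(Layer):
  """Hypothetical layer: doubles the size of the last axis."""

  @shape_type_conversion
  def compute_output_shape(self, input_shape):
    # The decorator delivers `input_shape` as a plain tuple, e.g. (None, 16),
    # and wraps the returned tuple back into a `TensorShape`.
    return input_shape[:-1] + (2 * input_shape[-1],)
```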
def _make_node_key(layer_name, node_index):
return layer_name + '_ib-' + str(node_index)
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py
new file mode 100644
index 0000000000..52522e6935
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/saving.py
@@ -0,0 +1,671 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Model saving utilities.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+import numpy as np
+from six.moves import zip # pylint: disable=redefined-builtin
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
+
+# pylint: disable=g-import-not-at-top
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+try:
+ import yaml
+except ImportError:
+ yaml = None
+# pylint: enable=g-import-not-at-top
+
+
+@tf_export('keras.models.save_model')
+def save_model(model, filepath, overwrite=True, include_optimizer=True):
+ """Save a model to a HDF5 file.
+
+ The saved model contains:
+ - the model's configuration (topology)
+ - the model's weights
+ - the model's optimizer's state (if any)
+
+ Thus the saved model can be reinstantiated in
+ the exact same state, without any of the code
+ used for model definition or training.
+
+ Arguments:
+ model: Keras model instance to be saved.
+ filepath: String, path where to save the model.
+ overwrite: Whether we should overwrite any existing
+ model at the target location, or instead
+ ask the user with a manual prompt.
+ include_optimizer: If True, save the optimizer's state with the model.
+
+ Raises:
+ ImportError: if h5py is not available.
+ """
+
+ if h5py is None:
+ raise ImportError('`save_model` requires h5py.')
+
+ def get_json_type(obj):
+ """Serialize any object to a JSON-serializable structure.
+
+ Arguments:
+ obj: the object to serialize
+
+ Returns:
+ JSON-serializable structure representing `obj`.
+
+ Raises:
+ TypeError: if `obj` cannot be serialized.
+ """
+ # if obj is a serializable Keras class instance
+ # e.g. optimizer, layer
+ if hasattr(obj, 'get_config'):
+ return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
+
+ # if obj is any numpy type
+ if type(obj).__module__ == np.__name__:
+ if isinstance(obj, np.ndarray):
+ return {'type': type(obj), 'value': obj.tolist()}
+ else:
+ return obj.item()
+
+ # misc functions (e.g. loss function)
+ if callable(obj):
+ return obj.__name__
+
+ # if obj is a python 'type'
+ if type(obj).__name__ == type.__name__:
+ return obj.__name__
+
+ raise TypeError('Not JSON Serializable:', obj)
+
+ from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
+
+ # If file exists and should not be overwritten.
+ if not overwrite and os.path.isfile(filepath):
+ proceed = ask_to_proceed_with_overwrite(filepath)
+ if not proceed:
+ return
+
+ with h5py.File(filepath, mode='w') as f:
+ f.attrs['keras_version'] = str(keras_version).encode('utf8')
+ f.attrs['backend'] = K.backend().encode('utf8')
+ f.attrs['model_config'] = json.dumps(
+ {
+ 'class_name': model.__class__.__name__,
+ 'config': model.get_config()
+ },
+ default=get_json_type).encode('utf8')
+
+ model_weights_group = f.create_group('model_weights')
+ model_layers = model.layers
+ save_weights_to_hdf5_group(model_weights_group, model_layers)
+
+ if include_optimizer and hasattr(model, 'optimizer'):
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ logging.warning(
+ 'TensorFlow optimizers do not '
+ 'make it possible to access '
+ 'optimizer attributes or optimizer state '
+ 'after instantiation. '
+ 'As a result, we cannot save the optimizer '
+ 'as part of the model save file. '
+ 'You will have to compile your model again after loading it. '
+ 'Prefer using a Keras optimizer instead '
+ '(see keras.io/optimizers).')
+ else:
+ f.attrs['training_config'] = json.dumps(
+ {
+ 'optimizer_config': {
+ 'class_name': model.optimizer.__class__.__name__,
+ 'config': model.optimizer.get_config()
+ },
+ 'loss': model.loss,
+ 'metrics': model.metrics,
+ 'sample_weight_mode': model.sample_weight_mode,
+ 'loss_weights': model.loss_weights,
+ },
+ default=get_json_type).encode('utf8')
+
+ # Save optimizer weights.
+ symbolic_weights = getattr(model.optimizer, 'weights')
+ if symbolic_weights:
+ optimizer_weights_group = f.create_group('optimizer_weights')
+ weight_values = K.batch_get_value(symbolic_weights)
+ weight_names = []
+ for w, val in zip(symbolic_weights, weight_values):
+ name = str(w.name)
+ weight_names.append(name.encode('utf8'))
+ optimizer_weights_group.attrs['weight_names'] = weight_names
+ for name, val in zip(weight_names, weight_values):
+ param_dset = optimizer_weights_group.create_dataset(
+ name, val.shape, dtype=val.dtype)
+ if not val.shape:
+ # scalar
+ param_dset[()] = val
+ else:
+ param_dset[:] = val
+ f.flush()
+
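Typical use of `save_model`, sketched with an illustrative path (the `model.save` method offers the same functionality):

```python
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(3,))])
model.compile(optimizer='sgd', loss='mse')
keras.models.save_model(model, '/tmp/my_model.h5')  # topology, weights, optimizer
```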
+
+@tf_export('keras.models.load_model')
+def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin
+ """Loads a model saved via `save_model`.
+
+ Arguments:
+ filepath: String, path to the saved model.
+ custom_objects: Optional dictionary mapping names
+ (strings) to custom classes or functions to be
+ considered during deserialization.
+ compile: Boolean, whether to compile the model
+ after loading.
+
+ Returns:
+ A Keras model instance. If an optimizer was found
+ as part of the saved model, the model is already
+ compiled. Otherwise, the model is uncompiled and
+ a warning will be displayed. When `compile` is set
+ to False, the compilation is omitted without any
+ warning.
+
+ Raises:
+ ImportError: if h5py is not available.
+ ValueError: In case of an invalid savefile.
+ """
+ if h5py is None:
+ raise ImportError('`load_model` requires h5py.')
+
+ if not custom_objects:
+ custom_objects = {}
+
+ def convert_custom_objects(obj):
+ """Handles custom object lookup.
+
+ Arguments:
+ obj: object, dict, or list.
+
+ Returns:
+ The same structure, where occurrences
+ of a custom object name have been replaced
+ with the custom object.
+ """
+ if isinstance(obj, list):
+ deserialized = []
+ for value in obj:
+ deserialized.append(convert_custom_objects(value))
+ return deserialized
+ if isinstance(obj, dict):
+ deserialized = {}
+ for key, value in obj.items():
+ deserialized[key] = convert_custom_objects(value)
+ return deserialized
+ if obj in custom_objects:
+ return custom_objects[obj]
+ return obj
+
+ with h5py.File(filepath, mode='r') as f:
+ # instantiate model
+ model_config = f.attrs.get('model_config')
+ if model_config is None:
+ raise ValueError('No model found in config file.')
+ model_config = json.loads(model_config.decode('utf-8'))
+ model = model_from_config(model_config, custom_objects=custom_objects)
+
+ # set weights
+ load_weights_from_hdf5_group(f['model_weights'], model.layers)
+
+ # Early return if compilation is not required.
+ if not compile:
+ return model
+
+ # instantiate optimizer
+ training_config = f.attrs.get('training_config')
+ if training_config is None:
+ logging.warning('No training configuration found in save file: '
+ 'the model was *not* compiled. Compile it manually.')
+ return model
+ training_config = json.loads(training_config.decode('utf-8'))
+ optimizer_config = training_config['optimizer_config']
+ optimizer = optimizers.deserialize(
+ optimizer_config, custom_objects=custom_objects)
+
+ # Recover loss functions and metrics.
+ loss = convert_custom_objects(training_config['loss'])
+ metrics = convert_custom_objects(training_config['metrics'])
+ sample_weight_mode = training_config['sample_weight_mode']
+ loss_weights = training_config['loss_weights']
+
+ # Compile model.
+ model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ sample_weight_mode=sample_weight_mode)
+
+ # Set optimizer weights.
+ if 'optimizer_weights' in f:
+ # Build train function (to get weight updates).
+ model._make_train_function()
+ optimizer_weights_group = f['optimizer_weights']
+ optimizer_weight_names = [
+ n.decode('utf8')
+ for n in optimizer_weights_group.attrs['weight_names']
+ ]
+ optimizer_weight_values = [
+ optimizer_weights_group[n] for n in optimizer_weight_names
+ ]
+ try:
+ model.optimizer.set_weights(optimizer_weight_values)
+ except ValueError:
+ logging.warning('Error in loading the saved optimizer '
+ 'state. As a result, your model is '
+ 'starting with a freshly initialized '
+ 'optimizer.')
+ return model
+
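The matching `load_model` call; `custom_loss` stands in for any user-defined object referenced by the saved config, and the path is illustrative:

```python
from tensorflow.python.keras._impl import keras

def custom_loss(y_true, y_pred):
  return keras.losses.mse(y_true, y_pred)

model = keras.models.load_model(
    '/tmp/my_model.h5', custom_objects={'custom_loss': custom_loss})
weights_only = keras.models.load_model('/tmp/my_model.h5', compile=False)
```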
+
+@tf_export('keras.models.model_from_config')
+def model_from_config(config, custom_objects=None):
+ """Instantiates a Keras model from its config.
+
+ Arguments:
+ config: Configuration dictionary.
+ custom_objects: Optional dictionary mapping names
+ (strings) to custom classes or functions to be
+ considered during deserialization.
+
+ Returns:
+ A Keras model instance (uncompiled).
+
+ Raises:
+ TypeError: if `config` is not a dictionary.
+ """
+ if isinstance(config, list):
+ raise TypeError('`model_from_config` expects a dictionary, not a list. '
+ 'Maybe you meant to use '
+ '`Sequential.from_config(config)`?')
+ from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top
+ return deserialize(config, custom_objects=custom_objects)
+
+
+@tf_export('keras.models.model_from_yaml')
+def model_from_yaml(yaml_string, custom_objects=None):
+ """Parses a yaml model configuration file and returns a model instance.
+
+ Arguments:
+ yaml_string: YAML string encoding a model configuration.
+ custom_objects: Optional dictionary mapping names
+ (strings) to custom classes or functions to be
+ considered during deserialization.
+
+ Returns:
+ A Keras model instance (uncompiled).
+
+ Raises:
+ ImportError: if yaml module is not found.
+ """
+ if yaml is None:
+ raise ImportError('Requires yaml module installed.')
+ config = yaml.load(yaml_string)
+ from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top
+ return deserialize(config, custom_objects=custom_objects)
+
+
+@tf_export('keras.models.model_from_json')
+def model_from_json(json_string, custom_objects=None):
+ """Parses a JSON model configuration file and returns a model instance.
+
+ Arguments:
+ json_string: JSON string encoding a model configuration.
+ custom_objects: Optional dictionary mapping names
+ (strings) to custom classes or functions to be
+ considered during deserialization.
+
+ Returns:
+ A Keras model instance (uncompiled).
+ """
+ config = json.loads(json_string)
+ from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top
+ return deserialize(config, custom_objects=custom_objects)
+
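These three entry points mirror the `get_config`/`to_json`/`to_yaml` serializers on models; a sketch of the round trips:

```python
import json

from tensorflow.python.keras._impl import keras

model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(3,))])

json_string = model.to_json()
restored = keras.models.model_from_json(json_string)

# `model_from_config` consumes the already-parsed dictionary:
restored = keras.models.model_from_config(json.loads(json_string))

yaml_string = model.to_yaml()  # requires the yaml module
restored = keras.models.model_from_yaml(yaml_string)
```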
+
+def save_weights_to_hdf5_group(f, layers):
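+  """Saves the weights of a list of layers to a HDF5 group."""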
+ from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
+
+ f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
+ f.attrs['backend'] = K.backend().encode('utf8')
+ f.attrs['keras_version'] = str(keras_version).encode('utf8')
+
+ for layer in layers:
+ g = f.create_group(layer.name)
+ symbolic_weights = layer.weights
+ weight_values = K.batch_get_value(symbolic_weights)
+ weight_names = []
+ for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
+ if hasattr(w, 'name') and w.name:
+ name = str(w.name)
+ else:
+ name = 'param_' + str(i)
+ weight_names.append(name.encode('utf8'))
+ g.attrs['weight_names'] = weight_names
+ for name, val in zip(weight_names, weight_values):
+ param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
+ if not val.shape:
+ # scalar
+ param_dset[()] = val
+ else:
+ param_dset[:] = val
+
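The resulting HDF5 layout, one group per layer with one dataset per weight, can be inspected directly with h5py; a sketch assuming an illustrative path:

```python
import h5py

with h5py.File('/tmp/weights.h5', 'r') as f:
  for layer_name in f.attrs['layer_names']:      # one group per layer, in order
    g = f[layer_name]
    for weight_name in g.attrs['weight_names']:  # one dataset per weight
      print(layer_name, weight_name, g[weight_name].shape)
```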
+
+def preprocess_weights_for_loading(layer,
+ weights,
+ original_keras_version=None,
+ original_backend=None):
+ """Converts layers weights from Keras 1 format to Keras 2.
+
+ Arguments:
+ layer: Layer instance.
+ weights: List of weights values (Numpy arrays).
+ original_keras_version: Keras version for the weights, as a string.
+ original_backend: Keras backend the weights were trained with,
+ as a string.
+
+ Returns:
+ A list of weights values (Numpy arrays).
+ """
+ if layer.__class__.__name__ == 'Bidirectional':
+ num_weights_per_layer = len(weights) // 2
+ forward_weights = preprocess_weights_for_loading(
+ layer.forward_layer, weights[:num_weights_per_layer],
+ original_keras_version, original_backend)
+ backward_weights = preprocess_weights_for_loading(
+ layer.backward_layer, weights[num_weights_per_layer:],
+ original_keras_version, original_backend)
+ weights = forward_weights + backward_weights
+
+ if original_keras_version == '1':
+ if layer.__class__.__name__ == 'TimeDistributed':
+ weights = preprocess_weights_for_loading(
+ layer.layer, weights, original_keras_version, original_backend)
+
+ if layer.__class__.__name__ == 'Conv1D':
+ shape = weights[0].shape
+ # Handle Keras 1.1 format
+ if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
+ # Legacy shape:
+ # (filters, input_dim, filter_length, 1)
+ assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
+ 1)
+ weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
+ weights[0] = weights[0][:, 0, :, :]
+
+ if layer.__class__.__name__ == 'Conv2D':
+ if layer.data_format == 'channels_first':
+ # old: (filters, stack_size, kernel_rows, kernel_cols)
+ # new: (kernel_rows, kernel_cols, stack_size, filters)
+ weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
+
+ if layer.__class__.__name__ == 'Conv2DTranspose':
+ if layer.data_format == 'channels_last':
+ # old: (kernel_rows, kernel_cols, stack_size, filters)
+ # new: (kernel_rows, kernel_cols, filters, stack_size)
+ weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
+ if layer.data_format == 'channels_first':
+ # old: (filters, stack_size, kernel_rows, kernel_cols)
+ # new: (kernel_rows, kernel_cols, filters, stack_size)
+ weights[0] = np.transpose(weights[0], (2, 3, 0, 1))
+
+ if layer.__class__.__name__ == 'Conv3D':
+ if layer.data_format == 'channels_first':
+ # old: (filters, stack_size, ...)
+ # new: (..., stack_size, filters)
+ weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))
+
+ if layer.__class__.__name__ == 'GRU':
+ if len(weights) == 9:
+ kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
+ recurrent_kernel = np.concatenate(
+ [weights[1], weights[4], weights[7]], axis=-1)
+ bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
+ weights = [kernel, recurrent_kernel, bias]
+
+ if layer.__class__.__name__ == 'LSTM':
+ if len(weights) == 12:
+ # old: i, c, f, o
+ # new: i, f, c, o
+ kernel = np.concatenate(
+ [weights[0], weights[6], weights[3], weights[9]], axis=-1)
+ recurrent_kernel = np.concatenate(
+ [weights[1], weights[7], weights[4], weights[10]], axis=-1)
+ bias = np.concatenate(
+ [weights[2], weights[8], weights[5], weights[11]], axis=-1)
+ weights = [kernel, recurrent_kernel, bias]
+
+ if layer.__class__.__name__ == 'ConvLSTM2D':
+ if len(weights) == 12:
+ kernel = np.concatenate(
+ [weights[0], weights[6], weights[3], weights[9]], axis=-1)
+ recurrent_kernel = np.concatenate(
+ [weights[1], weights[7], weights[4], weights[10]], axis=-1)
+ bias = np.concatenate(
+ [weights[2], weights[8], weights[5], weights[11]], axis=-1)
+ if layer.data_format == 'channels_first':
+ # old: (filters, stack_size, kernel_rows, kernel_cols)
+ # new: (kernel_rows, kernel_cols, stack_size, filters)
+ kernel = np.transpose(kernel, (2, 3, 1, 0))
+ recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
+ weights = [kernel, recurrent_kernel, bias]
+
+ if layer.__class__.__name__ in ['Model', 'Sequential']:
+ new_weights = []
+ # trainable weights
+ for sublayer in layer.layers:
+ num_weights = len(sublayer.trainable_weights)
+ if num_weights > 0:
+ new_weights.extend(
+ preprocess_weights_for_loading(
+ layer=sublayer,
+ weights=weights[:num_weights],
+ original_keras_version=original_keras_version,
+ original_backend=original_backend))
+ weights = weights[num_weights:]
+
+ # non-trainable weights
+ for sublayer in layer.layers:
+ num_weights = len([
+ l for l in sublayer.weights if l not in sublayer.trainable_weights
+ ])
+ if num_weights > 0:
+ new_weights.extend(
+ preprocess_weights_for_loading(
+ layer=sublayer,
+ weights=weights[:num_weights],
+ original_keras_version=original_keras_version,
+ original_backend=original_backend))
+ weights = weights[num_weights:]
+ weights = new_weights
+
+ conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
+ if layer.__class__.__name__ in conv_layers:
+ if original_backend == 'theano':
+ weights[0] = conv_utils.convert_kernel(weights[0])
+ if layer.__class__.__name__ == 'ConvLSTM2D':
+ weights[1] = conv_utils.convert_kernel(weights[1])
+ if K.int_shape(layer.weights[0]) != weights[0].shape:
+ weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
+ if layer.__class__.__name__ == 'ConvLSTM2D':
+ weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
+
+ # Convert the weights of CuDNNLSTM so that they can be loaded into LSTM.
+ if layer.__class__.__name__ == 'LSTM' and len(weights) == 3:
+ # Determine from the number of bias weights whether a CuDNNLSTM layer is
+ # being loaded: CuDNNLSTM has (units * 8) bias weights, while LSTM has
+ # (units * 4). If there is no bias weight in the file, skip the conversion.
+ units = weights[1].shape[0]
+ bias = weights[2]
+ if len(bias) == units * 8:
+ # reshape the kernels
+ kernels = np.split(weights[0], 4, axis=1)
+ kernels = [
+ kernel.reshape(-1).reshape(kernel.shape, order='F')
+ for kernel in kernels
+ ]
+ weights[0] = np.concatenate(kernels, axis=1)
+
+ # transpose the recurrent kernels
+ recurrent_kernels = np.split(weights[1], 4, axis=1)
+ recurrent_kernels = [kernel.T for kernel in recurrent_kernels]
+ weights[1] = np.concatenate(recurrent_kernels, axis=1)
+
+ # split the bias in half and sum the halves
+ weights[2] = bias[:units * 4] + bias[units * 4:]
+
+ return weights
+
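A numeric sketch of the CuDNN bias merge performed above; the values are fabricated for illustration:

```python
import numpy as np

units = 2
bias = np.arange(units * 8, dtype='float32')  # CuDNN bias: two halves, stacked
merged = bias[:units * 4] + bias[units * 4:]  # canonical LSTM bias: their sum
assert merged.shape == (units * 4,)
```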
+
+def load_weights_from_hdf5_group(f, layers):
+ """Implements topological (order-based) weight loading.
+
+ Arguments:
+ f: A pointer to a HDF5 group.
+ layers: a list of target layers.
+
+ Raises:
+ ValueError: in case of mismatch between provided layers
+ and weights file.
+ """
+ if 'keras_version' in f.attrs:
+ original_keras_version = f.attrs['keras_version'].decode('utf8')
+ else:
+ original_keras_version = '1'
+ if 'backend' in f.attrs:
+ original_backend = f.attrs['backend'].decode('utf8')
+ else:
+ original_backend = None
+
+ filtered_layers = []
+ for layer in layers:
+ weights = layer.weights
+ if weights:
+ filtered_layers.append(layer)
+
+ layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
+ filtered_layer_names = []
+ for name in layer_names:
+ g = f[name]
+ weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ if weight_names:
+ filtered_layer_names.append(name)
+ layer_names = filtered_layer_names
+ if len(layer_names) != len(filtered_layers):
+ raise ValueError('You are trying to load a weight file '
+ 'containing ' + str(len(layer_names)) +
+ ' layers into a model with ' + str(len(filtered_layers)) +
+ ' layers.')
+
+ # We batch weight value assignments in a single backend call
+ # which provides a speedup in TensorFlow.
+ weight_value_tuples = []
+ for k, name in enumerate(layer_names):
+ g = f[name]
+ weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ weight_values = [g[weight_name] for weight_name in weight_names]
+ layer = filtered_layers[k]
+ symbolic_weights = layer.weights
+ weight_values = preprocess_weights_for_loading(
+ layer, weight_values, original_keras_version, original_backend)
+ if len(weight_values) != len(symbolic_weights):
+ raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
+ '" in the current model) was found to '
+ 'correspond to layer ' + name + ' in the save file. '
+ 'However the new layer ' + layer.name + ' expects ' +
+ str(len(symbolic_weights)) +
+ ' weights, but the saved weights have ' +
+ str(len(weight_values)) + ' elements.')
+ weight_value_tuples += zip(symbolic_weights, weight_values)
+ K.batch_set_value(weight_value_tuples)
+
+
+def load_weights_from_hdf5_group_by_name(f, layers):
+ """Implements name-based weight loading.
+
+ (instead of topological weight loading).
+
+ Layers that have no matching name are skipped.
+
+ Arguments:
+ f: A pointer to a HDF5 group.
+ layers: a list of target layers.
+
+ Raises:
+ ValueError: in case of mismatch between provided layers
+ and weights file.
+ """
+ if 'keras_version' in f.attrs:
+ original_keras_version = f.attrs['keras_version'].decode('utf8')
+ else:
+ original_keras_version = '1'
+ if 'backend' in f.attrs:
+ original_backend = f.attrs['backend'].decode('utf8')
+ else:
+ original_backend = None
+
+ # New file format.
+ layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
+
+ # Reverse index of layer name to list of layers with that name.
+ index = {}
+ for layer in layers:
+ if layer.name:
+ index.setdefault(layer.name, []).append(layer)
+
+ # We batch weight value assignments in a single backend call
+ # which provides a speedup in TensorFlow.
+ weight_value_tuples = []
+ for k, name in enumerate(layer_names):
+ g = f[name]
+ weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
+ weight_values = [g[weight_name] for weight_name in weight_names]
+
+ for layer in index.get(name, []):
+ symbolic_weights = layer.weights
+ weight_values = preprocess_weights_for_loading(
+ layer, weight_values, original_keras_version, original_backend)
+ if len(weight_values) != len(symbolic_weights):
+ raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
+ '") expects ' + str(len(symbolic_weights)) +
+ ' weight(s), but the saved weights' + ' have ' +
+ str(len(weight_values)) + ' element(s).')
+ # Set values.
+ for i in range(len(weight_values)):
+ weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
+ K.batch_set_value(weight_value_tuples)
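Both loaders back `Model.load_weights`; a sketch of when each applies (path illustrative):

```python
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(3,))])
model.load_weights('/tmp/weights.h5')                # topological: architectures match
model.load_weights('/tmp/weights.h5', by_name=True)  # name-based: e.g. fine-tuning
```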
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
new file mode 100644
index 0000000000..bdb17641b0
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
@@ -0,0 +1,375 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model saving."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+from tensorflow.python.training import training as training_module
+
+try:
+ import h5py # pylint:disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
+
+class TestWeightSavingAndLoading(test.TestCase):
+
+ def test_weight_loading(self):
+ with self.test_session():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3)(a)
+ b = keras.layers.Dense(1)(x)
+ model = keras.models.Model(a, b)
+
+ x = np.random.random((3, 2))
+ ref_y = model.predict(x)
+ weights = model.get_weights()
+ model.set_weights(weights)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[1:])
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[::-1])
+
+ if h5py is None:
+ return # Skip rest of test if H5py isn't available.
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ h5_path = os.path.join(temp_dir, 'test.h5')
+ model.save_weights(h5_path)
+ model.load_weights(h5_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ model.load_weights(h5_path, by_name=True)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ def test_weight_preprocessing(self):
+ input_dim = 3
+ output_dim = 3
+ size = 2
+ cases = [
+ [
+ (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
+ [np.random.random((2, 1)), np.random.random((2, 1))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.TimeDistributed(keras.layers.Dense(1))),
+ [np.random.random((2, 1)), np.random.random((1,))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.Conv1D(output_dim, size, use_bias=False)),
+ [np.random.random((output_dim, input_dim, size, 1))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv2D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_last')),
+ [np.random.random((size, size, input_dim, output_dim))],
+ (None, 4, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv3D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size, size))],
+ (None, input_dim, 4, 4, 4),
+ ],
+ [
+ (keras.layers.GRU(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.LSTM(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ ]
+ for layer, weights, input_shape in cases:
+ layer.build(input_shape)
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ layer, weights, original_keras_version='1')
+
+ model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
+
+ x = keras.Input((2,))
+ y = keras.layers.Dense(2)(x)
+ model = keras.models.Model(x, y)
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
+
+ def test_sequential_weight_loading(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+
+ x = np.random.random((batch_size, input_dim))
+ ref_y = model.predict(x)
+
+ model.save_weights(h5_path)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+ model.load_weights(h5_path)
+ y = model.predict(x)
+
+ self.assertAllClose(y, ref_y)
+
+
+class TestWholeModelSaving(test.TestCase):
+
+ def test_sequential_model_saving(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ new_model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = new_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # test that new updates are the same with both models
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ new_model.train_on_batch(x, y)
+ out = model.predict(x)
+ out2 = new_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_sequential_model_saving_2(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ # test with custom optimizer, loss
+
+ class CustomOp(keras.optimizers.RMSprop):
+ pass
+
+ def custom_loss(y_true, y_pred):
+ return keras.losses.mse(y_true, y_pred)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(
+ fname,
+ custom_objects={'CustomOp': CustomOp,
+ 'custom_loss': custom_loss})
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_functional_model_saving(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_saving_without_compilation(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_with_tf_optimizer(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse',
+ optimizer=training_module.AdadeltaOptimizer(0.1),
+ metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_right_after_compilation(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+ model.model._make_train_function()
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_lambda_numpy_array_arguments(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ mean = np.random.random((4, 2, 3))
+ std = np.abs(np.random.random((4, 2, 3))) + 1e-5
+ inputs = keras.layers.Input(shape=(4, 2, 3))
+ output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
+ arguments={'mu': mean, 'std': std})(inputs)
+ model = keras.models.Model(inputs, output)
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ self.assertAllClose(mean, model.layers[1].arguments['mu'])
+ self.assertAllClose(std, model.layers[1].arguments['std'])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/_impl/keras/engine/sequential.py
new file mode 100644
index 0000000000..db5e7754bc
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/sequential.py
@@ -0,0 +1,997 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Home of the `Sequential` model.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import os
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers as layer_module
+from tensorflow.python.keras._impl.keras.engine import base_layer
+from tensorflow.python.keras._impl.keras.engine import network
+from tensorflow.python.keras._impl.keras.engine import saving
+from tensorflow.python.keras._impl.keras.engine.input_layer import Input
+from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras._impl.keras.engine.training import Model
+from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
+
+try:
+ import h5py # pylint: disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
+
+@tf_export('keras.models.Sequential', 'keras.Sequential')
+class Sequential(Model):
+ """Linear stack of layers.
+
+ Arguments:
+ layers: list of layers to add to the model.
+
+ # Note
+ The first layer passed to a Sequential model
+ should have a defined input shape. What that
+ means is that it should have received an `input_shape`
+ or `batch_input_shape` argument,
+    or for some types of layers (recurrent, Dense...)
+ an `input_dim` argument.
+
+ Example:
+
+ ```python
+ model = Sequential()
+ # first layer must have a defined input shape
+ model.add(Dense(32, input_dim=500))
+ # afterwards, Keras does automatic shape inference
+ model.add(Dense(32))
+
+ # also possible (equivalent to the above):
+ model = Sequential()
+ model.add(Dense(32, input_shape=(500,)))
+ model.add(Dense(32))
+
+ # also possible (equivalent to the above):
+ model = Sequential()
+ # here the batch dimension is None,
+ # which means any batch size will be accepted by the model.
+ model.add(Dense(32, batch_input_shape=(None, 500)))
+ model.add(Dense(32))
+ ```
+ """
+
+ def __init__(self, layers=None, name=None):
+ self._is_graph_network = True
+ self._is_compiled = False
+ self._layers = [] # Stack of layers.
+ self.model = None # Internal Model instance.
+ self.inputs = [] # List of input tensors
+ self.outputs = [] # List of length 1: the output tensor (unique).
+ self._trainable = True
+ self._initial_weights = None
+ self._input_layers = []
+
+ # Model attributes.
+ self._inbound_nodes = []
+ self._outbound_nodes = []
+ self.built = False
+
+ # Set model name.
+ if not name:
+ prefix = 'sequential_'
+ name = prefix + str(K.get_uid(prefix))
+ self._name = name
+
+ # Used by Layer base class.
+ self._dtype = None
+ self._activity_regularizer = None
+
+ # The following properties are not actually used by Keras;
+ # they exist for compatibility with TF's variable scoping mechanism.
+ self._updates = []
+ self._losses = []
+ self._scope = None
+ self._reuse = None
+ self._base_name = name
+ self._graph = ops.get_default_graph()
+
+ # Add to the model any layers passed to the constructor.
+ if layers:
+ for layer in layers:
+ self.add(layer)
+
+ def add(self, layer):
+ """Adds a layer instance on top of the layer stack.
+
+ Arguments:
+ layer: layer instance.
+
+ Raises:
+ TypeError: If `layer` is not a layer instance.
+ ValueError: In case the `layer` argument does not
+ know its input shape.
+ ValueError: In case the `layer` argument has
+ multiple output tensors, or is already connected
+ somewhere else (forbidden in `Sequential` models).
+ """
+ if not isinstance(layer, (base_layer.Layer, base_layer.TFBaseLayer)):
+ raise TypeError('The added layer must be '
+ 'an instance of class Layer. '
+ 'Found: ' + str(layer))
+ if not self.outputs:
+ # First layer in model: check that it is an input layer.
+ if not isinstance(layer, InputLayer):
+ # Create an input layer.
+ # First, we need to infer its expected input shape and dtype.
+ if isinstance(layer, (Model, Sequential)):
+ # We were passed a model as first layer.
+ # This requires a specific way to figure out the
+ # input shape and dtype.
+ if not layer.layers:
+ raise ValueError('Cannot add an empty model '
+ 'to a `Sequential` model.')
+ # In case of nested models: recover the first layer
+ # of the deepest model to infer input shape and dtype.
+ first_layer = layer.layers[0]
+ while isinstance(first_layer, (Model, Sequential)):
+ first_layer = first_layer.layers[0]
+ batch_shape = first_layer._batch_input_shape
+ dtype = first_layer.dtype
+ else:
+ # We were passed a regular layer, and it should
+ # know about its input shape. Otherwise, that's an error.
+ if not hasattr(layer, '_batch_input_shape'):
+ raise ValueError('The first layer in a '
+ 'Sequential model must '
+ 'get an `input_shape` argument.')
+ batch_shape = layer._batch_input_shape
+ dtype = layer.dtype
+ # Instantiate the input layer.
+ x = Input(
+ batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input')
+ # This will build the current layer
+ # and create the node connecting the current layer
+ # to the input layer we just created.
+ layer(x)
+
+ if len(layer._inbound_nodes[-1].output_tensors) != 1:
+ raise ValueError('All layers in a Sequential model '
+ 'should have a single output tensor. '
+ 'For multi-output layers, '
+ 'use the functional API.')
+
+ self.outputs = [layer._inbound_nodes[-1].output_tensors[0]]
+ self.inputs = network.get_source_inputs(self.outputs[0])
+
+ # We create an input node, which we will keep updated
+ # as we add more layers
+ base_layer.Node(
+ outbound_layer=self,
+ inbound_layers=[],
+ node_indices=[],
+ tensor_indices=[],
+ input_tensors=self.inputs,
+ output_tensors=self.outputs)
+ else:
+ output_tensor = layer(self.outputs[0])
+ if isinstance(output_tensor, list):
+ raise TypeError('All layers in a Sequential model '
+ 'should have a single output tensor. '
+ 'For multi-output layers, '
+ 'use the functional API.')
+ self.outputs = [output_tensor]
+ # update self._inbound_nodes
+ self._inbound_nodes[0].output_tensors = self.outputs
+ self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])]
+
+ self._layers.append(layer)
+ self.built = False
+
+ def pop(self):
+ """Removes the last layer in the model.
+
+ Raises:
+ TypeError: if there are no layers in the model.
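+
+    Example (a minimal sketch of removing the top layer):
+
+    ```python
+    model = Sequential()
+    model.add(Dense(32, input_shape=(500,)))
+    model.add(Dense(10))
+    model.pop()  # removes the `Dense(10)` layer
+    # now: len(model.layers) == 1
+    ```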
+ """
+ if not self.layers:
+ raise TypeError('There are no layers in the model.')
+
+ self.layers.pop()
+ if not self.layers:
+ self.outputs = []
+ self._inbound_nodes = []
+ self._outbound_nodes = []
+ else:
+ self.layers[-1]._outbound_nodes = []
+ self.outputs = [self.layers[-1].output]
+ # update self._inbound_nodes
+ self._inbound_nodes[0].output_tensors = self.outputs
+ self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])]
+ self.built = False
+
+ def get_layer(self, name=None, index=None):
+ """Retrieve a layer that is part of the model.
+
+ Returns a layer based on either its name (unique)
+ or its index in the graph. Indices are based on
+ order of horizontal graph traversal (bottom-up).
+
+ Arguments:
+ name: string, name of layer.
+ index: integer, index of layer.
+
+ Returns:
+ A layer instance.
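+
+    Example (a sketch; the name 'dp' is only an illustration):
+
+    ```python
+    model = Sequential()
+    model.add(Dense(1, input_dim=2))
+    model.add(Dropout(0.3, name='dp'))  # 'dp' is an arbitrary layer name
+    layer = model.get_layer(name='dp')
+    ```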
+ """
+ if not self.built:
+ self.build()
+ return self.model.get_layer(name, index)
+
+ def call(self, inputs, **kwargs):
+ if not self.built:
+ self.build()
+ return self.model.call(inputs, **kwargs)
+
+ def build(self, input_shape=None):
+ if not self.inputs or not self.outputs:
+ raise TypeError('Sequential model cannot be built: model is empty.'
+ ' Add some layers first.')
+ # actually create the model
+ self.model = Model(self.inputs, self.outputs[0], name=self.name + '_model')
+ self.model.trainable = self.trainable
+
+ # mirror model attributes
+ self.supports_masking = self.model.supports_masking
+ self._output_mask_cache = self.model._output_mask_cache
+ self._output_tensor_cache = self.model._output_tensor_cache
+ self._output_shape_cache = self.model._output_shape_cache
+ self._input_layers = self.model._input_layers
+ self._output_layers = self.model._output_layers
+ self._input_coordinates = self.model._input_coordinates
+ self._output_coordinates = self.model._output_coordinates
+ self._nodes_by_depth = self.model._nodes_by_depth
+ self._network_nodes = self.model._network_nodes
+ self.output_names = self.model.output_names
+ self.input_names = self.model.input_names
+ self._feed_input_names = self.model._feed_input_names
+ self._feed_inputs = self.model._feed_inputs
+
+ # Make sure child model callbacks
+ # will call the parent Sequential model.
+ self.model.callback_model = self
+
+ self.built = True
+
+ @property
+ def uses_learning_phase(self):
+ if not self.built:
+ self.build()
+ return self.model.uses_learning_phase
+
+ def _gather_list_attr(self, attr):
+ all_attrs = []
+ for layer in self.layers:
+ all_attrs += getattr(layer, attr, [])
+ return all_attrs
+
+ def _make_train_function(self):
+ self.model._make_train_function()
+
+ def _make_test_function(self):
+ self.model._make_test_function()
+
+ def _make_predict_function(self):
+ self.model._make_predict_function()
+
+ @property
+ def trainable(self):
+ return self._trainable
+
+ @trainable.setter
+ def trainable(self, value):
+ if self.model:
+ self.model.trainable = value
+ self._trainable = value
+
+ @property
+ def trainable_weights(self):
+ if not self.trainable:
+ return []
+ return self._gather_list_attr('trainable_weights')
+
+ @property
+ def non_trainable_weights(self):
+ weights = self._gather_list_attr('non_trainable_weights')
+ if not self.trainable:
+ trainable_weights = self._gather_list_attr('trainable_weights')
+ return trainable_weights + weights
+ return weights
+
+ @property
+ def regularizers(self):
+ if not self.built:
+ self.build()
+ return self.model.regularizers
+
+ def get_weights(self):
+ """Retrieves the weights of the model.
+
+ Returns:
+ A flat list of Numpy arrays
+ (one array per model weight).
+ """
+ if not self.built:
+ self.build()
+ return self.model.get_weights()
+
+ def set_weights(self, weights):
+ """Sets the weights of the model.
+
+ Arguments:
+ weights: Should be a list
+ of Numpy arrays with shapes and types matching
+ the output of `model.get_weights()`.
+ """
+ if not self.built:
+ self.build()
+ self.model.set_weights(weights)
+
+  def load_weights(self, filepath, by_name=False):
+    """Loads all layer weights from an HDF5 file.
+
+    Arguments:
+      filepath: String, path of the HDF5 file to load weights from.
+      by_name: Boolean, whether to load weights by layer name
+          instead of by topological order.
+
+    Raises:
+      ImportError: if h5py is not available.
+    """
+    if h5py is None:
+      raise ImportError('`load_weights` requires h5py.')
+ f = h5py.File(filepath, mode='r')
+ if 'layer_names' not in f.attrs and 'model_weights' in f:
+ f = f['model_weights']
+ layers = self.layers
+ if by_name:
+ saving.load_weights_from_hdf5_group_by_name(f, layers)
+ else:
+ saving.load_weights_from_hdf5_group(f, layers)
+ if hasattr(f, 'close'):
+ f.close()
+
+  def save_weights(self, filepath, overwrite=True):
+    """Saves all layer weights to an HDF5 file.
+
+    Arguments:
+      filepath: String, path where to save the weights.
+      overwrite: Whether to silently overwrite any existing file
+          at the target location, or instead ask the user
+          with a manual prompt.
+
+    Raises:
+      ImportError: if h5py is not available.
+    """
+    if h5py is None:
+      raise ImportError('`save_weights` requires h5py.')
+ # If file exists and should not be overwritten:
+ if not overwrite and os.path.isfile(filepath):
+ proceed = ask_to_proceed_with_overwrite(filepath)
+ if not proceed:
+ return
+ layers = self.layers
+ f = h5py.File(filepath, 'w')
+ saving.save_weights_to_hdf5_group(f, layers)
+ f.flush()
+ f.close()
+
+ def compile(self,
+ optimizer,
+ loss,
+ metrics=None,
+ sample_weight_mode=None,
+ weighted_metrics=None,
+ target_tensors=None,
+ **kwargs):
+ """Configures the model for training.
+
+ Arguments:
+ optimizer: String (name of optimizer) or optimizer object.
+ See [optimizers](/optimizers).
+ loss: String (name of objective function) or objective function.
+ See [losses](/losses).
+ If the model has multiple outputs, you can use a different loss
+ on each output by passing a dictionary or a list of losses.
+ The loss value that will be minimized by the model
+ will then be the sum of all individual losses.
+ metrics: List of metrics to be evaluated by the model
+ during training and testing.
+ Typically you will use `metrics=['accuracy']`.
+ To specify different metrics for different outputs of a
+ multi-output model, you could also pass a dictionary,
+ such as `metrics={'output_a': 'accuracy'}`.
+ sample_weight_mode: If you need to do timestep-wise
+ sample weighting (2D weights), set this to `"temporal"`.
+ `None` defaults to sample-wise weights (1D).
+ If the model has multiple outputs, you can use a different
+ `sample_weight_mode` on each output by passing a
+ dictionary or a list of modes.
+ weighted_metrics: list of metrics to be evaluated and weighted
+ by `sample_weight` or `class_weight` during training and testing.
+ target_tensors: By default, Keras will create a placeholder for the
+ model's target, which will be fed with the target data during
+ training. If instead you would like to use your own
+ target tensor (in turn, Keras will not expect external
+ Numpy data for these targets at training time), you
+ can specify them via the `target_tensors` argument.
+ It should be a single tensor
+ (for a single-output `Sequential` model).
+ **kwargs: These arguments are passed into `tf.Session.run`.
+
+ Example:
+ ```python
+ model = Sequential()
+ model.add(Dense(32, input_shape=(500,)))
+ model.add(Dense(10, activation='softmax'))
+ model.compile(optimizer='rmsprop',
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+ ```
+ """
+ # create the underlying model
+ self.build()
+ # call compile method of Model class
+ self.model.compile(
+ optimizer,
+ loss,
+ metrics=metrics,
+ sample_weight_mode=sample_weight_mode,
+ weighted_metrics=weighted_metrics,
+ target_tensors=target_tensors,
+ **kwargs)
+ self.optimizer = self.model.optimizer
+ self.loss = self.model.loss
+ self.metrics = self.model.metrics
+ self.loss_weights = self.model.loss_weights
+ self.sample_weight_mode = self.model.sample_weight_mode
+ self.weighted_metrics = self.model.weighted_metrics
+ self.targets = self.model.targets
+ self.metrics_tensors = self.model.metrics_tensors
+ self.metrics_names = self.model.metrics_names
+ self.sample_weights = self.model.sample_weights
+ self.total_loss = self.model.total_loss
+
+ def fit(self,
+ x=None,
+ y=None,
+ batch_size=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ validation_split=0.,
+ validation_data=None,
+ shuffle=True,
+ class_weight=None,
+ sample_weight=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None,
+ **kwargs):
+ """Trains the model for a fixed number of epochs.
+
+ Arguments:
+ x: Numpy array of training data.
+ If the input layer in the model is named, you can also pass a
+ dictionary mapping the input name to a Numpy array.
+ `x` can be `None` (default) if feeding from
+ TensorFlow data tensors.
+ y: Numpy array of target (label) data.
+ If the output layer in the model is named, you can also pass a
+ dictionary mapping the output name to a Numpy array.
+ `y` can be `None` (default) if feeding from
+ TensorFlow data tensors.
+ batch_size: Integer or `None`.
+ Number of samples per gradient update.
+ If unspecified, it will default to 32.
+ epochs: Integer. Number of epochs to train the model.
+ An epoch is an iteration over the entire `x` and `y`
+ data provided.
+ Note that in conjunction with `initial_epoch`,
+ `epochs` is to be understood as "final epoch".
+ The model is not trained for a number of iterations
+ given by `epochs`, but merely until the epoch
+ of index `epochs` is reached.
+ verbose: 0, 1, or 2. Verbosity mode.
+ 0 = silent, 1 = progress bar, 2 = one line per epoch.
+ callbacks: List of `keras.callbacks.Callback` instances.
+ List of callbacks to apply during training.
+ See [callbacks](/callbacks).
+ validation_split: Float between 0 and 1:
+ Fraction of the training data to be used as validation data.
+ The model will set apart this fraction of the training data,
+ will not train on it, and will evaluate
+ the loss and any model metrics
+ on this data at the end of each epoch.
+ The validation data is selected from the last samples
+ in the `x` and `y` data provided, before shuffling.
+ validation_data: tuple `(x_val, y_val)` or tuple
+ `(x_val, y_val, val_sample_weights)` on which to evaluate
+ the loss and any model metrics at the end of each epoch.
+ The model will not be trained on this data.
+ This will override `validation_split`.
+ shuffle: Boolean (whether to shuffle the training data
+ before each epoch) or str (for 'batch').
+ 'batch' is a special option for dealing with the
+ limitations of HDF5 data; it shuffles in batch-sized chunks.
+ Has no effect when `steps_per_epoch` is not `None`.
+ class_weight: Optional dictionary mapping class indices (integers)
+ to a weight (float) value, used for weighting the loss function
+ (during training only).
+ This can be useful to tell the model to
+ "pay more attention" to samples from
+ an under-represented class.
+ sample_weight: Optional Numpy array of weights for
+ the training samples, used for weighting the loss function
+ (during training only). You can either pass a flat (1D)
+ Numpy array with the same length as the input samples
+ (1:1 mapping between weights and samples),
+ or in the case of temporal data,
+ you can pass a 2D array with shape
+ `(samples, sequence_length)`,
+ to apply a different weight to every timestep of every sample.
+ In this case you should make sure to specify
+ `sample_weight_mode="temporal"` in `compile()`.
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run).
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch. When training with input tensors such as
+ TensorFlow data tensors, the default `None` is equal to
+ the number of unique samples in your dataset divided by
+ the batch size, or 1 if that cannot be determined.
+ validation_steps: Only relevant if `steps_per_epoch`
+ is specified. Total number of steps (batches of samples)
+ to validate before stopping.
+ **kwargs: Used for backwards compatibility support.
+
+ Returns:
+ A `History` object. Its `History.history` attribute is
+ a record of training loss values and metrics values
+ at successive epochs, as well as validation loss values
+ and validation metrics values (if applicable).
+
+ Raises:
+ RuntimeError: If the model was never compiled.
+ ValueError: In case of mismatch between the provided input data
+ and what the model expects.
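+
+    Example (a minimal sketch on random Numpy data):
+
+    ```python
+    model = Sequential()
+    model.add(Dense(10, input_shape=(20,)))
+    model.compile(optimizer='sgd', loss='mse')
+    x = np.random.random((128, 20))  # random data for illustration
+    y = np.random.random((128, 10))
+    history = model.fit(x, y, batch_size=32, epochs=2)
+    ```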
+ """
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.fit(
+ x,
+ y,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_split=validation_split,
+ validation_data=validation_data,
+ shuffle=shuffle,
+ class_weight=class_weight,
+ sample_weight=sample_weight,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
+
+ def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
+ """Computes the loss on some input data, batch by batch.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+ y: labels, as a Numpy array.
+ batch_size: integer. Number of samples per gradient update.
+ verbose: verbosity mode, 0 or 1.
+ sample_weight: sample weights, as a Numpy array.
+
+ Returns:
+ Scalar test loss (if the model has no metrics)
+ or list of scalars (if the model computes other metrics).
+ The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ RuntimeError: if the model was never compiled.
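+
+    Example (a sketch, with `x` and `y` Numpy arrays as in `fit`):
+
+    ```python
+    loss = model.evaluate(x, y, batch_size=32)
+    ```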
+ """
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.evaluate(
+ x,
+ y,
+ batch_size=batch_size,
+ verbose=verbose,
+ sample_weight=sample_weight)
+
+ def predict(self, x, batch_size=32, verbose=0):
+ """Generates output predictions for the input samples.
+
+ The input samples are processed batch by batch.
+
+ Arguments:
+ x: the input data, as a Numpy array.
+ batch_size: integer.
+ verbose: verbosity mode, 0 or 1.
+
+ Returns:
+ A Numpy array of predictions.
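+
+    Example (a sketch; `x` stands for any Numpy array matching the
+    model's input shape):
+
+    ```python
+    preds = model.predict(x, batch_size=128)
+    ```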
+ """
+ if not self.built:
+ self.build()
+ return self.model.predict(x, batch_size=batch_size, verbose=verbose)
+
+ def predict_on_batch(self, x):
+ """Returns predictions for a single batch of samples.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+
+ Returns:
+ A Numpy array of predictions.
+ """
+ if not self.built:
+ self.build()
+ return self.model.predict_on_batch(x)
+
+ def train_on_batch(self, x, y, class_weight=None, sample_weight=None):
+ """Single gradient update over one batch of samples.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+ y: labels, as a Numpy array.
+ class_weight: dictionary mapping classes to a weight value,
+ used for scaling the loss function (during training only).
+ sample_weight: sample weights, as a Numpy array.
+
+ Returns:
+ Scalar training loss (if the model has no metrics)
+ or list of scalars (if the model computes other metrics).
+ The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ RuntimeError: if the model was never compiled.
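+
+    Example (a sketch on a single random batch):
+
+    ```python
+    x_batch = np.random.random((32, 20))  # random batch for illustration
+    y_batch = np.random.random((32, 10))
+    loss = model.train_on_batch(x_batch, y_batch)
+    ```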
+ """
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.train_on_batch(
+ x, y, sample_weight=sample_weight, class_weight=class_weight)
+
+ def test_on_batch(self, x, y, sample_weight=None):
+ """Evaluates the model over a single batch of samples.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+ y: labels, as a Numpy array.
+ sample_weight: sample weights, as a Numpy array.
+
+ Returns:
+ Scalar test loss (if the model has no metrics)
+ or list of scalars (if the model computes other metrics).
+ The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ RuntimeError: if the model was never compiled.
+ """
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.test_on_batch(x, y, sample_weight=sample_weight)
+
+ def predict_proba(self, x, batch_size=32, verbose=0):
+ """Generates class probability predictions for the input samples.
+
+ The input samples are processed batch by batch.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+ batch_size: integer.
+ verbose: verbosity mode, 0 or 1.
+
+ Returns:
+ A Numpy array of probability predictions.
+ """
+ preds = self.predict(x, batch_size, verbose)
+ if preds.min() < 0. or preds.max() > 1.:
+ logging.warning('Network returning invalid probability values. '
+ 'The last layer might not normalize predictions '
+ 'into probabilities '
+ '(like softmax or sigmoid would).')
+ return preds
+
+ def predict_classes(self, x, batch_size=32, verbose=0):
+ """Generate class predictions for the input samples.
+
+ The input samples are processed batch by batch.
+
+ Arguments:
+ x: input data, as a Numpy array or list of Numpy arrays
+ (if the model has multiple inputs).
+ batch_size: integer.
+ verbose: verbosity mode, 0 or 1.
+
+ Returns:
+      A Numpy array of class predictions.
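+
+    Example (a sketch: multi-unit outputs are reduced with `argmax`,
+    single-unit outputs are thresholded at 0.5):
+
+    ```python
+    classes = model.predict_classes(x, batch_size=128)
+    ```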
+ """
+ proba = self.predict(x, batch_size=batch_size, verbose=verbose)
+ if proba.shape[-1] > 1:
+ return proba.argmax(axis=-1)
+ else:
+ return (proba > 0.5).astype('int32')
+
+ def fit_generator(self,
+ generator,
+ steps_per_epoch=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ validation_data=None,
+ validation_steps=None,
+ class_weight=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
+ shuffle=True,
+ initial_epoch=0,
+ **kwargs):
+ """Fits the model on data generated batch-by-batch by a Python generator.
+
+ The generator is run in parallel to the model, for efficiency.
+ For instance, this allows you to do real-time data augmentation
+ on images on CPU in parallel to training your model on GPU.
+
+ Arguments:
+ generator: A generator.
+ The output of the generator must be either
+ - a tuple (inputs, targets)
+ - a tuple (inputs, targets, sample_weights).
+ All arrays should contain the same number of samples.
+ The generator is expected to loop over its data
+ indefinitely. An epoch finishes when `steps_per_epoch`
+ batches have been seen by the model.
+ steps_per_epoch: Total number of steps (batches of samples)
+ to yield from `generator` before declaring one epoch
+ finished and starting the next epoch. It should typically
+ be equal to the number of samples of your dataset
+ divided by the batch size.
+ Optional for `Sequence`: if unspecified, will use
+          `len(generator)` as the number of steps.
+      epochs: Integer, total number of iterations on the data.
+          Note that in conjunction with `initial_epoch`, the parameter
+          `epochs` is to be understood as "final epoch". The model is
+          not trained for a number of iterations given by `epochs`,
+          but merely until the epoch of index `epochs` is reached.
+ verbose: Verbosity mode, 0, 1, or 2.
+ callbacks: List of callbacks to be called during training.
+ validation_data: This can be either
+ - A generator for the validation data
+ - A tuple (inputs, targets)
+ - A tuple (inputs, targets, sample_weights).
+ validation_steps: Only relevant if `validation_data`
+ is a generator.
+ Number of steps to yield from validation generator
+ at the end of every epoch. It should typically
+ be equal to the number of samples of your
+ validation dataset divided by the batch size.
+ Optional for `Sequence`: if unspecified, will use
+          `len(validation_data)` as the number of steps.
+ class_weight: Dictionary mapping class indices to a weight
+ for the class.
+      max_queue_size: Maximum size for the generator queue.
+      workers: Maximum number of processes to spin up.
+      use_multiprocessing: If True, use process-based threading.
+          Note that because this implementation relies on
+          multiprocessing, you should not pass non-picklable
+          arguments to the generator, as they can't be passed
+          easily to child processes.
+ shuffle: Whether to shuffle the order of the batches at
+ the beginning of each epoch. Only used with instances
+ of `Sequence` (keras.utils.Sequence).
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run)
+ **kwargs: support for legacy arguments.
+
+ Returns:
+ A `History` object.
+
+ Raises:
+ RuntimeError: if the model was never compiled.
+ ValueError: In case the generator yields
+ data in an invalid format.
+
+ Example:
+
+ ```python
+        def generate_arrays_from_file(path):
+          while True:
+            with open(path) as f:
+              for line in f:
+                # create Numpy arrays of input data
+                # and labels, from each line in the file
+                x, y = process_line(line)
+                yield (x, y)
+
+ model.fit_generator(generate_arrays_from_file('/my_file.txt'),
+ steps_per_epoch=1000, epochs=10)
+ ```
+ """
+ # Legacy support
+ if 'max_q_size' in kwargs:
+ max_queue_size = kwargs.pop('max_q_size')
+ logging.warning('The argument `max_q_size` has been renamed '
+ '`max_queue_size`. Update your method calls accordingly.')
+ if 'pickle_safe' in kwargs:
+ use_multiprocessing = kwargs.pop('pickle_safe')
+ logging.warning('The argument `pickle_safe` has been renamed '
+ '`use_multiprocessing`. '
+ 'Update your method calls accordingly.')
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.fit_generator(
+ generator,
+ steps_per_epoch,
+ epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_data=validation_data,
+ validation_steps=validation_steps,
+ class_weight=class_weight,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch)
+
+ def evaluate_generator(self,
+ generator,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
+ **kwargs):
+ """Evaluates the model on a data generator.
+
+ The generator should return the same kind of data
+ as accepted by `test_on_batch`.
+
+ Arguments:
+ generator: Generator yielding tuples (inputs, targets)
+ or (inputs, targets, sample_weights)
+ steps: Total number of steps (batches of samples)
+ to yield from `generator` before stopping.
+ Optional for `Sequence`: if unspecified, will use
+          `len(generator)` as the number of steps.
+      max_queue_size: maximum size for the generator queue.
+      workers: maximum number of processes to spin up.
+      use_multiprocessing: if True, use process-based threading.
+          Note that because this implementation
+          relies on multiprocessing, you should not pass
+          non-picklable arguments to the generator,
+          as they can't be passed easily to child processes.
+ **kwargs: support for legacy arguments.
+
+ Returns:
+ Scalar test loss (if the model has no metrics)
+ or list of scalars (if the model computes other metrics).
+ The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ RuntimeError: if the model was never compiled.
+ ValueError: In case the generator yields
+ data in an invalid format.
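+
+    Example (a sketch; `eval_generator` is any generator yielding
+    `(inputs, targets)` batches):
+
+    ```python
+    loss = model.evaluate_generator(eval_generator, steps=100)
+    ```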
+ """
+ # Legacy support
+ if 'max_q_size' in kwargs:
+ max_queue_size = kwargs.pop('max_q_size')
+ logging.warning('The argument `max_q_size` has been renamed '
+ '`max_queue_size`. Update your method calls accordingly.')
+ if 'pickle_safe' in kwargs:
+ use_multiprocessing = kwargs.pop('pickle_safe')
+ logging.warning('The argument `pickle_safe` has been renamed '
+ '`use_multiprocessing`. '
+ 'Update your method calls accordingly.')
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ if not self.built:
+ raise RuntimeError('The model needs to be compiled before being used.')
+ return self.model.evaluate_generator(
+ generator,
+ steps,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
+ def predict_generator(self,
+ generator,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
+ verbose=0,
+ **kwargs):
+ """Generates predictions for the input samples from a data generator.
+
+ The generator should return the same kind of data as accepted by
+ `predict_on_batch`.
+
+ Arguments:
+ generator: generator yielding batches of input samples.
+ steps: Total number of steps (batches of samples)
+ to yield from `generator` before stopping.
+ Optional for `Sequence`: if unspecified, will use
+          `len(generator)` as the number of steps.
+      max_queue_size: maximum size for the generator queue.
+      workers: maximum number of processes to spin up.
+      use_multiprocessing: if True, use process-based threading.
+          Note that because this implementation
+          relies on multiprocessing, you should not pass
+          non-picklable arguments to the generator,
+          as they can't be passed easily to child processes.
+ verbose: verbosity mode, 0 or 1.
+ **kwargs: support for legacy arguments.
+
+ Returns:
+ A Numpy array of predictions.
+
+ Raises:
+ ValueError: In case the generator yields
+ data in an invalid format.
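+
+    Example (a sketch; `batch_generator` is any generator yielding
+    batches of input samples):
+
+    ```python
+    preds = model.predict_generator(batch_generator, steps=100)
+    ```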
+ """
+ # Legacy support
+ if 'max_q_size' in kwargs:
+ max_queue_size = kwargs.pop('max_q_size')
+ logging.warning('The argument `max_q_size` has been renamed '
+ '`max_queue_size`. Update your method calls accordingly.')
+ if 'pickle_safe' in kwargs:
+ use_multiprocessing = kwargs.pop('pickle_safe')
+ logging.warning('The argument `pickle_safe` has been renamed '
+ '`use_multiprocessing`. '
+ 'Update your method calls accordingly.')
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ if not self.built:
+ self.build()
+ return self.model.predict_generator(
+ generator,
+ steps,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ verbose=verbose)
+
+ def get_config(self):
+ config = []
+ for layer in self.layers:
+ config.append({
+ 'class_name': layer.__class__.__name__,
+ 'config': layer.get_config()
+ })
+ return copy.deepcopy(config)
+
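+  # A config round trip (sketch): `Sequential.from_config(model.get_config())`
+  # rebuilds an equivalent, uncompiled model from the per-layer configs.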
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ model = cls()
+ for conf in config:
+ layer = layer_module.deserialize(conf, custom_objects=custom_objects)
+ model.add(layer)
+ return model
diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py
new file mode 100644
index 0000000000..166634bd82
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests specific to `Sequential` model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class TestSequential(test.TestCase):
+ """Most Sequential model API tests are covered in `training_test.py`.
+ """
+
+ def test_basic_methods(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_dim=2))
+ model.add(keras.layers.Dropout(0.3, name='dp'))
+ model.add(keras.layers.Dense(2, kernel_regularizer='l2',
+ kernel_constraint='max_norm'))
+ model.build()
+ self.assertEqual(model.state_updates, model.model.state_updates)
+ self.assertEqual(model.get_layer(name='dp').name, 'dp')
+
+ def test_sequential_pop(self):
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+ model.compile(loss='mse', optimizer='sgd')
+ x = np.random.random((batch_size, input_dim))
+ y = np.random.random((batch_size, num_classes))
+ model.fit(x, y, epochs=1)
+ model.pop()
+ self.assertEqual(len(model.layers), 1)
+ self.assertEqual(model.output_shape, (None, num_hidden))
+ model.compile(loss='mse', optimizer='sgd')
+ y = np.random.random((batch_size, num_hidden))
+ model.fit(x, y, epochs=1)
+
+ # Test popping single-layer model
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.pop()
+ self.assertEqual(len(model.layers), 0)
+ self.assertEqual(len(model.outputs), 0)
+
+ # Invalid use case
+ model = keras.models.Sequential()
+ with self.assertRaises(TypeError):
+ model.pop()
+
+ def test_invalid_use_cases(self):
+ with self.test_session():
+ # Added objects must be layer instances
+ with self.assertRaises(TypeError):
+ model = keras.models.Sequential()
+ model.add(None)
+
+      # Added layers must have a known input shape
+ with self.assertRaises(ValueError):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1))
+
+ # Added layers cannot have multiple outputs
+ class MyLayer(keras.layers.Layer):
+
+ def call(self, inputs):
+ return [3 * inputs, 2 * inputs]
+
+ def compute_output_shape(self, input_shape):
+ return [input_shape, input_shape]
+
+ with self.assertRaises(ValueError):
+ model = keras.models.Sequential()
+ model.add(MyLayer(input_shape=(3,)))
+ with self.assertRaises(TypeError):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_dim=1))
+ model.add(MyLayer())
+
+ # Building empty model
+ model = keras.models.Sequential()
+ with self.assertRaises(TypeError):
+ model.build()
+
+ def test_nested_sequential_trainability(self):
+ input_dim = 20
+ num_units = 10
+ num_classes = 2
+
+ inner_model = keras.models.Sequential()
+ inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,)))
+
+ model = keras.models.Sequential()
+ model.add(inner_model)
+ model.add(keras.layers.Dense(num_classes))
+
+ self.assertEqual(len(model.trainable_weights), 4)
+ inner_model.trainable = False
+ self.assertEqual(len(model.trainable_weights), 2)
+ inner_model.trainable = True
+ self.assertEqual(len(model.trainable_weights), 4)
+
+ def test_sequential_update_disabling(self):
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.BatchNormalization(input_shape=(4,)))
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+ assert not model.model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+ assert model.model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index 139621db6d..04434323d6 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-import shutil
-
import numpy as np
from tensorflow.python.eager import context
@@ -28,7 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
-from tensorflow.python.layers import base as base_layers
+from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@@ -39,11 +36,6 @@ try:
except ImportError:
yaml = None
-try:
- import h5py # pylint:disable=g-import-not-at-top
-except ImportError:
- h5py = None
-
class TopologyConstructionTest(test.TestCase):
@@ -84,7 +76,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_updates_for(x2)), 1)
self.assertEqual(len(layer.get_updates_for(None)), 1)
- network = keras.engine.topology.Network(x2, y2)
+ network = keras.engine.Network(x2, y2)
self.assertEqual(len(network.updates), 2)
self.assertEqual(len(network.get_updates_for(x1)), 0)
self.assertEqual(len(network.get_updates_for(x2)), 1)
@@ -146,7 +138,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_losses_for(x2)), 1)
self.assertEqual(len(layer.get_losses_for(None)), 1)
- network = keras.engine.topology.Network(x2, y2)
+ network = keras.engine.Network(x2, y2)
self.assertEqual(len(network.losses), 2)
self.assertEqual(len(network.get_losses_for(x1)), 0)
self.assertEqual(len(network.get_losses_for(x2)), 1)
@@ -267,7 +259,7 @@ class TopologyConstructionTest(test.TestCase):
x = keras.Input(shape=(32,))
dense = keras.layers.Dense(2)
y = dense(x)
- network = keras.engine.topology.Network(x, y, name='dense_network')
+ network = keras.engine.Network(x, y, name='dense_network')
# test basic attributes
self.assertEqual(network.name, 'dense_network')
@@ -502,7 +494,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])
# test get_source_inputs
- self.assertListEqual(keras.engine.topology.get_source_inputs(c), [a, b])
+ self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b])
# serialization / deserialization
json_config = model.to_json()
@@ -762,7 +754,7 @@ class TopologyConstructionTest(test.TestCase):
if context.in_graph_mode():
x = keras.Input(shape=(32,))
y = MaskedLayer()(x) # pylint: disable=not-callable
- network = keras.engine.topology.Network(x, y)
+ network = keras.engine.Network(x, y)
# test callability on Input
x_2 = keras.Input(shape=(32,))
@@ -875,139 +867,12 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(np.min(preds), 0.) # At least one unit was dropped.
-class TestSaving(test.TestCase):
-
- def test_weight_loading(self):
- with self.test_session():
- a = keras.layers.Input(shape=(2,))
- x = keras.layers.Dense(3)(a)
- b = keras.layers.Dense(1)(x)
- model = keras.models.Model(a, b)
-
- x = np.random.random((3, 2))
- ref_y = model.predict(x)
- weights = model.get_weights()
- model.set_weights(weights)
- y = model.predict(x)
- self.assertAllClose(ref_y, y)
-
- with self.assertRaises(ValueError):
- model.set_weights(weights[1:])
- with self.assertRaises(ValueError):
- model.set_weights(weights[::-1])
-
- if h5py is None:
- return # Skip rest of test if H5py isn't available.
-
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
-
- h5_path = os.path.join(temp_dir, 'test.h5')
- model.save_weights(h5_path)
- model.load_weights(h5_path)
- y = model.predict(x)
- self.assertAllClose(ref_y, y)
-
- model.load_weights(h5_path, by_name=True)
- y = model.predict(x)
- self.assertAllClose(ref_y, y)
-
- def test_weight_preprocessing(self):
- input_dim = 3
- output_dim = 3
- size = 2
- cases = [
- [
- (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
- [np.random.random((2, 1)), np.random.random((2, 1))],
- (None, 3, 2),
- ],
- [
- (keras.layers.TimeDistributed(keras.layers.Dense(1))),
- [np.random.random((2, 1)), np.random.random((1,))],
- (None, 3, 2),
- ],
- [
- (keras.layers.Conv1D(output_dim, size, use_bias=False)),
- [np.random.random((output_dim, input_dim, size, 1))],
- (None, 4, input_dim),
- ],
- [
- (keras.layers.Conv2D(output_dim, size,
- use_bias=False, data_format='channels_first')),
- [np.random.random((output_dim, input_dim, size, size))],
- (None, input_dim, 4, 4),
- ],
- [
- (keras.layers.Conv2DTranspose(output_dim, size,
- use_bias=False,
- data_format='channels_first')),
- [np.random.random((output_dim, input_dim, size, size))],
- (None, input_dim, 4, 4),
- ],
- [
- (keras.layers.Conv2DTranspose(output_dim, size,
- use_bias=False,
- data_format='channels_last')),
- [np.random.random((size, size, input_dim, output_dim))],
- (None, 4, 4, input_dim),
- ],
- [
- (keras.layers.Conv3D(output_dim, size,
- use_bias=False, data_format='channels_first')),
- [np.random.random((output_dim, input_dim, size, size, size))],
- (None, input_dim, 4, 4, 4),
- ],
- [
- (keras.layers.GRU(output_dim)),
- [np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,)),
- np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,)),
- np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,))],
- (None, 4, input_dim),
- ],
- [
- (keras.layers.LSTM(output_dim)),
- [np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,)),
- np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,)),
- np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,)),
- np.random.random((input_dim, output_dim)),
- np.random.random((output_dim, output_dim)),
- np.random.random((output_dim,))],
- (None, 4, input_dim),
- ],
- ]
- for layer, weights, input_shape in cases:
- layer.build(input_shape)
- _ = keras.engine.topology.preprocess_weights_for_loading(
- layer, weights, original_keras_version='1')
-
- model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
- _ = keras.engine.topology.preprocess_weights_for_loading(
- model, model.weights, original_keras_version='1')
-
- x = keras.Input((2,))
- y = keras.layers.Dense(2)(x)
- model = keras.models.Model(x, y)
- _ = keras.engine.topology.preprocess_weights_for_loading(
- model, model.weights, original_keras_version='1')
-
-
class DeferredModeTest(test.TestCase):
def testDeferredTensorAttributes(self):
- x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x')
+ x = tf_base_layers._DeferredTensor(shape=(None, 2),
+ dtype='float32',
+ name='x')
self.assertEqual(str(x),
'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)')
self.assertEqual(repr(x),
@@ -1015,21 +880,21 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self):
- inputs = keras.engine.topology.Input(shape=(32,))
+ inputs = keras.engine.Input(shape=(32,))
if context.in_eager_mode():
- self.assertIsInstance(inputs, base_layers._DeferredTensor)
+ self.assertIsInstance(inputs, tf_base_layers._DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32')
self.assertEqual(inputs.shape.as_list(), [None, 32])
x = keras.layers.Dense(2)(inputs)
if context.in_eager_mode():
- self.assertIsInstance(x, base_layers._DeferredTensor)
+ self.assertIsInstance(x, tf_base_layers._DeferredTensor)
self.assertEqual(x.dtype.name, 'float32')
self.assertEqual(x.shape.as_list(), [None, 2])
outputs = keras.layers.Dense(4)(x)
- network = keras.engine.topology.Network(inputs, outputs)
- self.assertIsInstance(network, keras.engine.topology.Network)
+ network = keras.engine.Network(inputs, outputs)
+ self.assertIsInstance(network, keras.engine.Network)
if context.in_eager_mode():
# It should be possible to call such a network on EagerTensors.
@@ -1040,8 +905,8 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testMultiIONetworkbuilding(self):
- input_a = keras.engine.topology.Input(shape=(32,))
- input_b = keras.engine.topology.Input(shape=(16,))
+ input_a = keras.engine.Input(shape=(32,))
+ input_b = keras.engine.Input(shape=(16,))
a = keras.layers.Dense(16)(input_a)
class AddLayer(keras.layers.Layer):
@@ -1055,7 +920,7 @@ class DeferredModeTest(test.TestCase):
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = keras.layers.Dense(2)(c)
- network = keras.engine.topology.Network([input_a, input_b], [a, c])
+ network = keras.engine.Network([input_a, input_b], [a, c])
if context.in_eager_mode():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index d8ea2fe3db..57451ad470 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -31,8 +31,8 @@ from tensorflow.python.keras._impl.keras import losses
from tensorflow.python.keras._impl.keras import metrics as metrics_module
from tensorflow.python.keras._impl.keras import optimizers
from tensorflow.python.keras._impl.keras.engine import training_eager
-from tensorflow.python.keras._impl.keras.engine.topology import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import Network
+from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
+from tensorflow.python.keras._impl.keras.engine.network import Network
from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index 7cac17c51a..c40ee109aa 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index d2792b9636..d95a094245 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -26,7 +26,7 @@ from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index ca92899a45..006ecd3135 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -23,7 +23,7 @@ from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py
index df0efe6b8b..13d96e9392 100644
--- a/tensorflow/python/keras/_impl/keras/layers/local.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local.py
@@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index cdf2878e83..c660cbd449 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -21,8 +21,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras._impl.keras import backend as K
-from tensorflow.python.keras._impl.keras.engine.topology import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index 9010f49615..e309d160e5 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -22,7 +22,7 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index a81971d9ee..0264c7ae01 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 61f1a758e4..76ddd9299d 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.layers import utils as tf_layers_util
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py
index 8000eaabab..9602e7ba39 100644
--- a/tensorflow/python/keras/_impl/keras/models.py
+++ b/tensorflow/python/keras/_impl/keras/models.py
@@ -13,1305 +13,30 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
-"""Home of the Sequential model, and the `save_model`/`load_model` functions.
+"""Code for model cloning, plus model-related API entries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-import json
-import os
-
-import numpy as np
-
-from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras import backend as K
-from tensorflow.python.keras._impl.keras import layers as layer_module
-from tensorflow.python.keras._impl.keras import optimizers
-from tensorflow.python.keras._impl.keras.engine import topology
-from tensorflow.python.keras._impl.keras.engine.topology import Input
-from tensorflow.python.keras._impl.keras.engine.topology import InputLayer
-from tensorflow.python.keras._impl.keras.engine.topology import Layer
-from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer
-from tensorflow.python.keras._impl.keras.engine.training import Model
+from tensorflow.python.keras._impl.keras.engine import saving
+from tensorflow.python.keras._impl.keras.engine import sequential
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.engine.input_layer import Input
+from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
-from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util.tf_export import tf_export
-
-
-# pylint: disable=g-import-not-at-top
-try:
- import h5py
-except ImportError:
- h5py = None
-
-try:
- import yaml
-except ImportError:
- yaml = None
-# pylint: enable=g-import-not-at-top
-
-
-@tf_export('keras.models.save_model')
-def save_model(model, filepath, overwrite=True, include_optimizer=True):
- """Save a model to a HDF5 file.
-
- The saved model contains:
- - the model's configuration (topology)
- - the model's weights
- - the model's optimizer's state (if any)
-
- Thus the saved model can be reinstantiated in
- the exact same state, without any of the code
- used for model definition or training.
-
- Arguments:
- model: Keras model instance to be saved.
- filepath: String, path where to save the model.
- overwrite: Whether we should overwrite any existing
- model at the target location, or instead
- ask the user with a manual prompt.
- include_optimizer: If True, save optimizer's state together.
-
- Raises:
- ImportError: if h5py is not available.
- """
-
- if h5py is None:
- raise ImportError('`save_model` requires h5py.')
-
- def get_json_type(obj):
- """Serialize any object to a JSON-serializable structure.
-
- Arguments:
- obj: the object to serialize
-
- Returns:
- JSON-serializable structure representing `obj`.
-
- Raises:
- TypeError: if `obj` cannot be serialized.
- """
- # if obj is a serializable Keras class instance
- # e.g. optimizer, layer
- if hasattr(obj, 'get_config'):
- return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
-
- # if obj is any numpy type
- if type(obj).__module__ == np.__name__:
- if isinstance(obj, np.ndarray):
- return {'type': type(obj), 'value': obj.tolist()}
- else:
- return obj.item()
-
- # misc functions (e.g. loss function)
- if callable(obj):
- return obj.__name__
-
- # if obj is a python 'type'
- if type(obj).__name__ == type.__name__:
- return obj.__name__
-
- raise TypeError('Not JSON Serializable:', obj)
-
- from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
-
- # If file exists and should not be overwritten.
- if not overwrite and os.path.isfile(filepath):
- proceed = ask_to_proceed_with_overwrite(filepath)
- if not proceed:
- return
-
- with h5py.File(filepath, mode='w') as f:
- f.attrs['keras_version'] = str(keras_version).encode('utf8')
- f.attrs['backend'] = K.backend().encode('utf8')
- f.attrs['model_config'] = json.dumps(
- {
- 'class_name': model.__class__.__name__,
- 'config': model.get_config()
- },
- default=get_json_type).encode('utf8')
-
- model_weights_group = f.create_group('model_weights')
- model_layers = model.layers
- topology.save_weights_to_hdf5_group(model_weights_group, model_layers)
-
- if include_optimizer and hasattr(model, 'optimizer'):
- if isinstance(model.optimizer, optimizers.TFOptimizer):
- logging.warning(
- 'TensorFlow optimizers do not '
- 'make it possible to access '
- 'optimizer attributes or optimizer state '
- 'after instantiation. '
- 'As a result, we cannot save the optimizer '
-          'as part of the model save file. '
- 'You will have to compile your model again after loading it. '
- 'Prefer using a Keras optimizer instead '
- '(see keras.io/optimizers).')
- else:
- f.attrs['training_config'] = json.dumps(
- {
- 'optimizer_config': {
- 'class_name': model.optimizer.__class__.__name__,
- 'config': model.optimizer.get_config()
- },
- 'loss': model.loss,
- 'metrics': model.metrics,
- 'sample_weight_mode': model.sample_weight_mode,
- 'loss_weights': model.loss_weights,
- },
- default=get_json_type).encode('utf8')
-
- # Save optimizer weights.
- symbolic_weights = getattr(model.optimizer, 'weights')
- if symbolic_weights:
- optimizer_weights_group = f.create_group('optimizer_weights')
- weight_values = K.batch_get_value(symbolic_weights)
- weight_names = []
- for w, val in zip(symbolic_weights, weight_values):
- name = str(w.name)
- weight_names.append(name.encode('utf8'))
- optimizer_weights_group.attrs['weight_names'] = weight_names
- for name, val in zip(weight_names, weight_values):
- param_dset = optimizer_weights_group.create_dataset(
- name, val.shape, dtype=val.dtype)
- if not val.shape:
- # scalar
- param_dset[()] = val
- else:
- param_dset[:] = val
- f.flush()
-
-
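The nested `get_json_type` helper above drives `json.dumps` through its `default` hook. A minimal standalone sketch of the same rules (the helper name here is illustrative; the shipped code also lets the recursive `default` call serialize the raw `type` object rather than taking `__name__` directly):

```python
import json

import numpy as np


def _json_default(obj):
  """Mirrors the serialization rules of `get_json_type` (a sketch)."""
  if hasattr(obj, 'get_config'):  # Serializable Keras objects (layers, optimizers).
    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
  if type(obj).__module__ == np.__name__:  # Any numpy type.
    if isinstance(obj, np.ndarray):
      return {'type': type(obj).__name__, 'value': obj.tolist()}
    return obj.item()
  if callable(obj):  # E.g. a loss function: keep only its name.
    return obj.__name__
  raise TypeError('Not JSON Serializable:', obj)


print(json.dumps({'lr': np.float32(0.001), 'w': np.ones(2)},
                 default=_json_default))
# -> {"lr": 0.001000...,"w": {"type": "ndarray", "value": [1.0, 1.0]}}
```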
-@tf_export('keras.models.load_model')
-def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin
- """Loads a model saved via `save_model`.
-
- Arguments:
- filepath: String, path to the saved model.
- custom_objects: Optional dictionary mapping names
- (strings) to custom classes or functions to be
- considered during deserialization.
- compile: Boolean, whether to compile the model
- after loading.
-
- Returns:
- A Keras model instance. If an optimizer was found
- as part of the saved model, the model is already
- compiled. Otherwise, the model is uncompiled and
- a warning will be displayed. When `compile` is set
- to False, the compilation is omitted without any
- warning.
-
- Raises:
- ImportError: if h5py is not available.
- ValueError: In case of an invalid savefile.
- """
- if h5py is None:
- raise ImportError('`load_model` requires h5py.')
-
- if not custom_objects:
- custom_objects = {}
-
- def convert_custom_objects(obj):
- """Handles custom object lookup.
-
- Arguments:
- obj: object, dict, or list.
-
- Returns:
- The same structure, where occurrences
- of a custom object name have been replaced
- with the custom object.
- """
- if isinstance(obj, list):
- deserialized = []
- for value in obj:
- deserialized.append(convert_custom_objects(value))
- return deserialized
- if isinstance(obj, dict):
- deserialized = {}
- for key, value in obj.items():
- deserialized[key] = convert_custom_objects(value)
- return deserialized
- if obj in custom_objects:
- return custom_objects[obj]
- return obj
-
- with h5py.File(filepath, mode='r') as f:
- # instantiate model
- model_config = f.attrs.get('model_config')
- if model_config is None:
- raise ValueError('No model found in config file.')
- model_config = json.loads(model_config.decode('utf-8'))
- model = model_from_config(model_config, custom_objects=custom_objects)
-
- # set weights
- topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)
-
- # Early return if compilation is not required.
- if not compile:
- return model
-
- # instantiate optimizer
- training_config = f.attrs.get('training_config')
- if training_config is None:
- logging.warning('No training configuration found in save file: '
- 'the model was *not* compiled. Compile it manually.')
- return model
- training_config = json.loads(training_config.decode('utf-8'))
- optimizer_config = training_config['optimizer_config']
- optimizer = optimizers.deserialize(
- optimizer_config, custom_objects=custom_objects)
-
- # Recover loss functions and metrics.
- loss = convert_custom_objects(training_config['loss'])
- metrics = convert_custom_objects(training_config['metrics'])
- sample_weight_mode = training_config['sample_weight_mode']
- loss_weights = training_config['loss_weights']
-
- # Compile model.
- model.compile(
- optimizer=optimizer,
- loss=loss,
- metrics=metrics,
- loss_weights=loss_weights,
- sample_weight_mode=sample_weight_mode)
-
- # Set optimizer weights.
- if 'optimizer_weights' in f:
- # Build train function (to get weight updates).
- if isinstance(model, Sequential):
- model.model._make_train_function()
- else:
- model._make_train_function()
- optimizer_weights_group = f['optimizer_weights']
- optimizer_weight_names = [
- n.decode('utf8')
- for n in optimizer_weights_group.attrs['weight_names']
- ]
- optimizer_weight_values = [
- optimizer_weights_group[n] for n in optimizer_weight_names
- ]
- try:
- model.optimizer.set_weights(optimizer_weight_values)
- except ValueError:
- logging.warning('Error in loading the saved optimizer '
- 'state. As a result, your model is '
- 'starting with a freshly initialized '
- 'optimizer.')
- return model
-
-
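When the saved file references user-defined classes or functions, `load_model` resolves them through `custom_objects`, which `convert_custom_objects` above walks through the training config. A typical round trip (file path illustrative):

```python
import numpy as np

from tensorflow.python.keras._impl import keras


def custom_loss(y_true, y_pred):
  return keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1)

model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.compile(optimizer='sgd', loss=custom_loss)
model.train_on_batch(np.random.random((1, 3)), np.random.random((1, 2)))

keras.models.save_model(model, '/tmp/model.h5')
# Without `custom_objects`, deserializing `custom_loss` would fail.
restored = keras.models.load_model(
    '/tmp/model.h5', custom_objects={'custom_loss': custom_loss})
```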
-@tf_export('keras.models.model_from_config')
-def model_from_config(config, custom_objects=None):
- """Instantiates a Keras model from its config.
-
- Arguments:
- config: Configuration dictionary.
- custom_objects: Optional dictionary mapping names
- (strings) to custom classes or functions to be
- considered during deserialization.
-
- Returns:
- A Keras model instance (uncompiled).
-
- Raises:
- TypeError: if `config` is not a dictionary.
- """
- if isinstance(config, list):
- raise TypeError('`model_from_config` expects a dictionary, not a list. '
- 'Maybe you meant to use '
- '`Sequential.from_config(config)`?')
- return layer_module.deserialize(config, custom_objects=custom_objects)
-
-
-@tf_export('keras.models.model_from_yaml')
-def model_from_yaml(yaml_string, custom_objects=None):
- """Parses a yaml model configuration file and returns a model instance.
-
- Arguments:
- yaml_string: YAML string encoding a model configuration.
- custom_objects: Optional dictionary mapping names
- (strings) to custom classes or functions to be
- considered during deserialization.
-
- Returns:
- A Keras model instance (uncompiled).
-
- Raises:
- ImportError: if yaml module is not found.
- """
- if yaml is None:
-    raise ImportError('`model_from_yaml` requires the yaml module.')
- config = yaml.load(yaml_string)
- return layer_module.deserialize(config, custom_objects=custom_objects)
-
-
-@tf_export('keras.models.model_from_json')
-def model_from_json(json_string, custom_objects=None):
- """Parses a JSON model configuration file and returns a model instance.
-
- Arguments:
- json_string: JSON string encoding a model configuration.
- custom_objects: Optional dictionary mapping names
- (strings) to custom classes or functions to be
- considered during deserialization.
-
- Returns:
- A Keras model instance (uncompiled).
- """
- config = json.loads(json_string)
- return layer_module.deserialize(config, custom_objects=custom_objects)
-
-
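These three deserialization entry points are the inverses of the model's own serialization methods; for instance, `model_from_json` reverses `model.to_json()`:

```python
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(8,)))

json_string = model.to_json()  # Architecture only, no weights.
fresh = keras.models.model_from_json(json_string)  # Uncompiled clone.
```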
-@tf_export('keras.models.Sequential', 'keras.Sequential')
-class Sequential(Model):
- """Linear stack of layers.
-
- Arguments:
- layers: list of layers to add to the model.
-
- # Note
- The first layer passed to a Sequential model
-    should have a defined input shape. That is,
-    it should have received an `input_shape`
-    or `batch_input_shape` argument,
-    or, for some types of layers (recurrent, Dense...),
-    an `input_dim` argument.
-
- Example:
-
- ```python
- model = Sequential()
- # first layer must have a defined input shape
- model.add(Dense(32, input_dim=500))
- # afterwards, Keras does automatic shape inference
- model.add(Dense(32))
-
- # also possible (equivalent to the above):
- model = Sequential()
- model.add(Dense(32, input_shape=(500,)))
- model.add(Dense(32))
-
- # also possible (equivalent to the above):
- model = Sequential()
- # here the batch dimension is None,
- # which means any batch size will be accepted by the model.
- model.add(Dense(32, batch_input_shape=(None, 500)))
- model.add(Dense(32))
- ```
- """
-
- def __init__(self, layers=None, name=None):
- self._is_graph_network = True
- self._is_compiled = False
- self._layers = [] # Stack of layers.
- self.model = None # Internal Model instance.
- self.inputs = [] # List of input tensors
- self.outputs = [] # List of length 1: the output tensor (unique).
- self._trainable = True
- self._initial_weights = None
- self._input_layers = []
-
- # Model attributes.
- self._inbound_nodes = []
- self._outbound_nodes = []
- self.built = False
-
- # Set model name.
- if not name:
- prefix = 'sequential_'
- name = prefix + str(K.get_uid(prefix))
- self._name = name
-
- # Used by Layer base class.
- self._dtype = None
- self._activity_regularizer = None
-
- # The following properties are not actually used by Keras;
- # they exist for compatibility with TF's variable scoping mechanism.
- self._updates = []
- self._losses = []
- self._scope = None
- self._reuse = None
- self._base_name = name
- self._graph = ops.get_default_graph()
-
- # Add to the model any layers passed to the constructor.
- if layers:
- for layer in layers:
- self.add(layer)
-
- def add(self, layer):
- """Adds a layer instance on top of the layer stack.
-
- Arguments:
- layer: layer instance.
-
- Raises:
- TypeError: If `layer` is not a layer instance.
- ValueError: In case the `layer` argument does not
- know its input shape.
- ValueError: In case the `layer` argument has
- multiple output tensors, or is already connected
- somewhere else (forbidden in `Sequential` models).
- """
- if not isinstance(layer, (Layer, TFBaseLayer)):
- raise TypeError('The added layer must be '
- 'an instance of class Layer. '
- 'Found: ' + str(layer))
- if not self.outputs:
- # First layer in model: check that it is an input layer.
- if not isinstance(layer, InputLayer):
- # Create an input layer.
- # First, we need to infer its expected input shape and dtype.
- if isinstance(layer, (Model, Sequential)):
- # We were passed a model as first layer.
- # This requires a specific way to figure out the
- # input shape and dtype.
- if not layer.layers:
- raise ValueError('Cannot add an empty model '
- 'to a `Sequential` model.')
- # In case of nested models: recover the first layer
- # of the deepest model to infer input shape and dtype.
- first_layer = layer.layers[0]
- while isinstance(first_layer, (Model, Sequential)):
- first_layer = first_layer.layers[0]
- batch_shape = first_layer._batch_input_shape
- dtype = first_layer.dtype
- else:
- # We were passed a regular layer, and it should
- # know about its input shape. Otherwise, that's an error.
- if not hasattr(layer, '_batch_input_shape'):
- raise ValueError('The first layer in a '
- 'Sequential model must '
- 'get an `input_shape` argument.')
- batch_shape = layer._batch_input_shape
- dtype = layer.dtype
- # Instantiate the input layer.
- x = Input(
- batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input')
- # This will build the current layer
- # and create the node connecting the current layer
- # to the input layer we just created.
- layer(x)
-
- if len(layer._inbound_nodes[-1].output_tensors) != 1:
- raise ValueError('All layers in a Sequential model '
- 'should have a single output tensor. '
- 'For multi-output layers, '
- 'use the functional API.')
-
- self.outputs = [layer._inbound_nodes[-1].output_tensors[0]]
- self.inputs = topology.get_source_inputs(self.outputs[0])
-
- # We create an input node, which we will keep updated
- # as we add more layers
- topology.Node(
- outbound_layer=self,
- inbound_layers=[],
- node_indices=[],
- tensor_indices=[],
- input_tensors=self.inputs,
- output_tensors=self.outputs)
- else:
- output_tensor = layer(self.outputs[0])
- if isinstance(output_tensor, list):
- raise TypeError('All layers in a Sequential model '
- 'should have a single output tensor. '
- 'For multi-output layers, '
- 'use the functional API.')
- self.outputs = [output_tensor]
- # update self._inbound_nodes
- self._inbound_nodes[0].output_tensors = self.outputs
- self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])]
-
- self._layers.append(layer)
- self.built = False
-
- def pop(self):
- """Removes the last layer in the model.
-
- Raises:
- TypeError: if there are no layers in the model.
- """
- if not self.layers:
- raise TypeError('There are no layers in the model.')
-
- self.layers.pop()
- if not self.layers:
- self.outputs = []
- self._inbound_nodes = []
- self._outbound_nodes = []
- else:
- self.layers[-1]._outbound_nodes = []
- self.outputs = [self.layers[-1].output]
- # update self._inbound_nodes
- self._inbound_nodes[0].output_tensors = self.outputs
- self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])]
- self.built = False
-
- def get_layer(self, name=None, index=None):
- """Retrieve a layer that is part of the model.
-
- Returns a layer based on either its name (unique)
- or its index in the graph. Indices are based on
- order of horizontal graph traversal (bottom-up).
-
- Arguments:
- name: string, name of layer.
- index: integer, index of layer.
-
- Returns:
- A layer instance.
- """
- if not self.built:
- self.build()
- return self.model.get_layer(name, index)
-
- def call(self, inputs, **kwargs):
- if not self.built:
- self.build()
- return self.model.call(inputs, **kwargs)
-
- def build(self, input_shape=None):
- if not self.inputs or not self.outputs:
- raise TypeError('Sequential model cannot be built: model is empty.'
- ' Add some layers first.')
- # actually create the model
- self.model = Model(self.inputs, self.outputs[0], name=self.name + '_model')
- self.model.trainable = self.trainable
-
- # mirror model attributes
- self.supports_masking = self.model.supports_masking
- self._output_mask_cache = self.model._output_mask_cache
- self._output_tensor_cache = self.model._output_tensor_cache
- self._output_shape_cache = self.model._output_shape_cache
- self._input_layers = self.model._input_layers
- self._output_layers = self.model._output_layers
- self._input_coordinates = self.model._input_coordinates
- self._output_coordinates = self.model._output_coordinates
- self._nodes_by_depth = self.model._nodes_by_depth
- self._network_nodes = self.model._network_nodes
- self.output_names = self.model.output_names
- self.input_names = self.model.input_names
- self._feed_input_names = self.model._feed_input_names
- self._feed_inputs = self.model._feed_inputs
-
- # Make sure child model callbacks
- # will call the parent Sequential model.
- self.model.callback_model = self
-
- self.built = True
-
- @property
- def uses_learning_phase(self):
- if not self.built:
- self.build()
- return self.model.uses_learning_phase
-
- def _gather_list_attr(self, attr):
- all_attrs = []
- for layer in self.layers:
- all_attrs += getattr(layer, attr, [])
- return all_attrs
-
- @property
- def trainable(self):
- return self._trainable
-
- @trainable.setter
- def trainable(self, value):
- if self.model:
- self.model.trainable = value
- self._trainable = value
-
- @property
- def trainable_weights(self):
- if not self.trainable:
- return []
- return self._gather_list_attr('trainable_weights')
-
- @property
- def non_trainable_weights(self):
- weights = self._gather_list_attr('non_trainable_weights')
- if not self.trainable:
- trainable_weights = self._gather_list_attr('trainable_weights')
- return trainable_weights + weights
- return weights
-
- @property
- def regularizers(self):
- if not self.built:
- self.build()
- return self.model.regularizers
-
- def get_weights(self):
- """Retrieves the weights of the model.
-
- Returns:
- A flat list of Numpy arrays
- (one array per model weight).
- """
- if not self.built:
- self.build()
- return self.model.get_weights()
-
- def set_weights(self, weights):
- """Sets the weights of the model.
-
- Arguments:
- weights: Should be a list
- of Numpy arrays with shapes and types matching
- the output of `model.get_weights()`.
- """
- if not self.built:
- self.build()
- self.model.set_weights(weights)
-
- def load_weights(self, filepath, by_name=False):
- if h5py is None:
- raise ImportError('`load_weights` requires h5py.')
- f = h5py.File(filepath, mode='r')
- if 'layer_names' not in f.attrs and 'model_weights' in f:
- f = f['model_weights']
- layers = self.layers
- if by_name:
- topology.load_weights_from_hdf5_group_by_name(f, layers)
- else:
- topology.load_weights_from_hdf5_group(f, layers)
- if hasattr(f, 'close'):
- f.close()
-
- def save_weights(self, filepath, overwrite=True):
- if h5py is None:
- raise ImportError('`save_weights` requires h5py.')
- # If file exists and should not be overwritten:
- if not overwrite and os.path.isfile(filepath):
- proceed = ask_to_proceed_with_overwrite(filepath)
- if not proceed:
- return
- layers = self.layers
- f = h5py.File(filepath, 'w')
- topology.save_weights_to_hdf5_group(f, layers)
- f.flush()
- f.close()
-
- def compile(self,
- optimizer,
- loss,
- metrics=None,
- sample_weight_mode=None,
- weighted_metrics=None,
- target_tensors=None,
- **kwargs):
- """Configures the model for training.
-
- Arguments:
- optimizer: String (name of optimizer) or optimizer object.
- See [optimizers](/optimizers).
- loss: String (name of objective function) or objective function.
- See [losses](/losses).
- If the model has multiple outputs, you can use a different loss
- on each output by passing a dictionary or a list of losses.
- The loss value that will be minimized by the model
- will then be the sum of all individual losses.
- metrics: List of metrics to be evaluated by the model
- during training and testing.
- Typically you will use `metrics=['accuracy']`.
- To specify different metrics for different outputs of a
- multi-output model, you could also pass a dictionary,
- such as `metrics={'output_a': 'accuracy'}`.
- sample_weight_mode: If you need to do timestep-wise
- sample weighting (2D weights), set this to `"temporal"`.
- `None` defaults to sample-wise weights (1D).
- If the model has multiple outputs, you can use a different
- `sample_weight_mode` on each output by passing a
- dictionary or a list of modes.
- weighted_metrics: list of metrics to be evaluated and weighted
- by `sample_weight` or `class_weight` during training and testing.
- target_tensors: By default, Keras will create a placeholder for the
- model's target, which will be fed with the target data during
- training. If instead you would like to use your own
-        target tensor (in turn, Keras will not expect external
-        Numpy data for this target at training time), you
-        can specify it via the `target_tensors` argument.
- It should be a single tensor
- (for a single-output `Sequential` model).
- **kwargs: These arguments are passed into `tf.Session.run`.
-
- Example:
- ```python
- model = Sequential()
- model.add(Dense(32, input_shape=(500,)))
- model.add(Dense(10, activation='softmax'))
- model.compile(optimizer='rmsprop',
- loss='categorical_crossentropy',
- metrics=['accuracy'])
- ```
- """
- # create the underlying model
- self.build()
- # call compile method of Model class
- self.model.compile(
- optimizer,
- loss,
- metrics=metrics,
- sample_weight_mode=sample_weight_mode,
- weighted_metrics=weighted_metrics,
- target_tensors=target_tensors,
- **kwargs)
- self.optimizer = self.model.optimizer
- self.loss = self.model.loss
- self.metrics = self.model.metrics
- self.loss_weights = self.model.loss_weights
- self.sample_weight_mode = self.model.sample_weight_mode
- self.weighted_metrics = self.model.weighted_metrics
- self.targets = self.model.targets
- self.metrics_tensors = self.model.metrics_tensors
- self.metrics_names = self.model.metrics_names
- self.sample_weights = self.model.sample_weights
- self.total_loss = self.model.total_loss
-
- def fit(self,
- x=None,
- y=None,
- batch_size=None,
- epochs=1,
- verbose=1,
- callbacks=None,
- validation_split=0.,
- validation_data=None,
- shuffle=True,
- class_weight=None,
- sample_weight=None,
- initial_epoch=0,
- steps_per_epoch=None,
- validation_steps=None,
- **kwargs):
- """Trains the model for a fixed number of epochs.
-
- Arguments:
- x: Numpy array of training data.
- If the input layer in the model is named, you can also pass a
- dictionary mapping the input name to a Numpy array.
- `x` can be `None` (default) if feeding from
- TensorFlow data tensors.
- y: Numpy array of target (label) data.
- If the output layer in the model is named, you can also pass a
- dictionary mapping the output name to a Numpy array.
- `y` can be `None` (default) if feeding from
- TensorFlow data tensors.
- batch_size: Integer or `None`.
- Number of samples per gradient update.
- If unspecified, it will default to 32.
- epochs: Integer. Number of epochs to train the model.
- An epoch is an iteration over the entire `x` and `y`
- data provided.
- Note that in conjunction with `initial_epoch`,
- `epochs` is to be understood as "final epoch".
- The model is not trained for a number of iterations
- given by `epochs`, but merely until the epoch
- of index `epochs` is reached.
- verbose: 0, 1, or 2. Verbosity mode.
- 0 = silent, 1 = progress bar, 2 = one line per epoch.
- callbacks: List of `keras.callbacks.Callback` instances.
- List of callbacks to apply during training.
- See [callbacks](/callbacks).
- validation_split: Float between 0 and 1:
- Fraction of the training data to be used as validation data.
- The model will set apart this fraction of the training data,
- will not train on it, and will evaluate
- the loss and any model metrics
- on this data at the end of each epoch.
- The validation data is selected from the last samples
- in the `x` and `y` data provided, before shuffling.
- validation_data: tuple `(x_val, y_val)` or tuple
- `(x_val, y_val, val_sample_weights)` on which to evaluate
- the loss and any model metrics at the end of each epoch.
- The model will not be trained on this data.
- This will override `validation_split`.
- shuffle: Boolean (whether to shuffle the training data
- before each epoch) or str (for 'batch').
- 'batch' is a special option for dealing with the
- limitations of HDF5 data; it shuffles in batch-sized chunks.
- Has no effect when `steps_per_epoch` is not `None`.
- class_weight: Optional dictionary mapping class indices (integers)
- to a weight (float) value, used for weighting the loss function
- (during training only).
- This can be useful to tell the model to
- "pay more attention" to samples from
- an under-represented class.
- sample_weight: Optional Numpy array of weights for
- the training samples, used for weighting the loss function
- (during training only). You can either pass a flat (1D)
- Numpy array with the same length as the input samples
- (1:1 mapping between weights and samples),
- or in the case of temporal data,
- you can pass a 2D array with shape
- `(samples, sequence_length)`,
- to apply a different weight to every timestep of every sample.
- In this case you should make sure to specify
- `sample_weight_mode="temporal"` in `compile()`.
- initial_epoch: Epoch at which to start training
- (useful for resuming a previous training run).
- steps_per_epoch: Total number of steps (batches of samples)
- before declaring one epoch finished and starting the
- next epoch. When training with input tensors such as
- TensorFlow data tensors, the default `None` is equal to
- the number of unique samples in your dataset divided by
- the batch size, or 1 if that cannot be determined.
- validation_steps: Only relevant if `steps_per_epoch`
- is specified. Total number of steps (batches of samples)
- to validate before stopping.
- **kwargs: Used for backwards compatibility support.
-
- Returns:
- A `History` object. Its `History.history` attribute is
- a record of training loss values and metrics values
- at successive epochs, as well as validation loss values
- and validation metrics values (if applicable).
-
- Raises:
- RuntimeError: If the model was never compiled.
- ValueError: In case of mismatch between the provided input data
- and what the model expects.
- """
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.fit(
- x,
- y,
- batch_size=batch_size,
- epochs=epochs,
- verbose=verbose,
- callbacks=callbacks,
- validation_split=validation_split,
- validation_data=validation_data,
- shuffle=shuffle,
- class_weight=class_weight,
- sample_weight=sample_weight,
- initial_epoch=initial_epoch,
- steps_per_epoch=steps_per_epoch,
- validation_steps=validation_steps)
-
- def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
- """Computes the loss on some input data, batch by batch.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
- y: labels, as a Numpy array.
- batch_size: integer. Number of samples per gradient update.
- verbose: verbosity mode, 0 or 1.
- sample_weight: sample weights, as a Numpy array.
-
- Returns:
- Scalar test loss (if the model has no metrics)
- or list of scalars (if the model computes other metrics).
- The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
-
- Raises:
- RuntimeError: if the model was never compiled.
- """
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.evaluate(
- x,
- y,
- batch_size=batch_size,
- verbose=verbose,
- sample_weight=sample_weight)
-
- def predict(self, x, batch_size=32, verbose=0):
- """Generates output predictions for the input samples.
-
- The input samples are processed batch by batch.
-
- Arguments:
- x: the input data, as a Numpy array.
- batch_size: integer.
- verbose: verbosity mode, 0 or 1.
-
- Returns:
- A Numpy array of predictions.
- """
- if not self.built:
- self.build()
- return self.model.predict(x, batch_size=batch_size, verbose=verbose)
-
- def predict_on_batch(self, x):
- """Returns predictions for a single batch of samples.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
-
- Returns:
- A Numpy array of predictions.
- """
- if not self.built:
- self.build()
- return self.model.predict_on_batch(x)
-
- def train_on_batch(self, x, y, class_weight=None, sample_weight=None):
- """Single gradient update over one batch of samples.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
- y: labels, as a Numpy array.
- class_weight: dictionary mapping classes to a weight value,
- used for scaling the loss function (during training only).
- sample_weight: sample weights, as a Numpy array.
-
- Returns:
- Scalar training loss (if the model has no metrics)
- or list of scalars (if the model computes other metrics).
- The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
-
- Raises:
- RuntimeError: if the model was never compiled.
- """
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.train_on_batch(
- x, y, sample_weight=sample_weight, class_weight=class_weight)
-
- def test_on_batch(self, x, y, sample_weight=None):
- """Evaluates the model over a single batch of samples.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
- y: labels, as a Numpy array.
- sample_weight: sample weights, as a Numpy array.
-
- Returns:
- Scalar test loss (if the model has no metrics)
- or list of scalars (if the model computes other metrics).
- The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
-
- Raises:
- RuntimeError: if the model was never compiled.
- """
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.test_on_batch(x, y, sample_weight=sample_weight)
-
- def predict_proba(self, x, batch_size=32, verbose=0):
- """Generates class probability predictions for the input samples.
-
- The input samples are processed batch by batch.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
- batch_size: integer.
- verbose: verbosity mode, 0 or 1.
-
- Returns:
- A Numpy array of probability predictions.
- """
- preds = self.predict(x, batch_size, verbose)
- if preds.min() < 0. or preds.max() > 1.:
- logging.warning('Network returning invalid probability values. '
- 'The last layer might not normalize predictions '
- 'into probabilities '
- '(like softmax or sigmoid would).')
- return preds
-
- def predict_classes(self, x, batch_size=32, verbose=0):
- """Generate class predictions for the input samples.
-
- The input samples are processed batch by batch.
-
- Arguments:
- x: input data, as a Numpy array or list of Numpy arrays
- (if the model has multiple inputs).
- batch_size: integer.
- verbose: verbosity mode, 0 or 1.
-
- Returns:
-      A Numpy array of class predictions.
- """
- proba = self.predict(x, batch_size=batch_size, verbose=verbose)
- if proba.shape[-1] > 1:
- return proba.argmax(axis=-1)
- else:
- return (proba > 0.5).astype('int32')
-
- def fit_generator(self,
- generator,
- steps_per_epoch=None,
- epochs=1,
- verbose=1,
- callbacks=None,
- validation_data=None,
- validation_steps=None,
- class_weight=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- shuffle=True,
- initial_epoch=0,
- **kwargs):
- """Fits the model on data generated batch-by-batch by a Python generator.
-
- The generator is run in parallel to the model, for efficiency.
- For instance, this allows you to do real-time data augmentation
- on images on CPU in parallel to training your model on GPU.
-
- Arguments:
- generator: A generator.
- The output of the generator must be either
- - a tuple (inputs, targets)
- - a tuple (inputs, targets, sample_weights).
- All arrays should contain the same number of samples.
- The generator is expected to loop over its data
- indefinitely. An epoch finishes when `steps_per_epoch`
- batches have been seen by the model.
- steps_per_epoch: Total number of steps (batches of samples)
- to yield from `generator` before declaring one epoch
- finished and starting the next epoch. It should typically
- be equal to the number of samples of your dataset
- divided by the batch size.
- Optional for `Sequence`: if unspecified, will use
-          `len(generator)` as the number of steps.
-        epochs: Integer, total number of iterations on the data.
-          Note that in conjunction with `initial_epoch`, `epochs`
-          is to be understood as "final epoch": the model is not
-          trained for a number of iterations given by `epochs`, but
-          merely until the epoch of index `epochs` is reached.
- verbose: Verbosity mode, 0, 1, or 2.
- callbacks: List of callbacks to be called during training.
- validation_data: This can be either
- - A generator for the validation data
- - A tuple (inputs, targets)
- - A tuple (inputs, targets, sample_weights).
- validation_steps: Only relevant if `validation_data`
- is a generator.
- Number of steps to yield from validation generator
- at the end of every epoch. It should typically
- be equal to the number of samples of your
- validation dataset divided by the batch size.
- Optional for `Sequence`: if unspecified, will use
-          `len(validation_data)` as the number of steps.
- class_weight: Dictionary mapping class indices to a weight
- for the class.
-        max_queue_size: Maximum size for the generator queue.
-        workers: Maximum number of processes to spin up.
-        use_multiprocessing: If True, use process-based threading.
-          Note that because this implementation
-          relies on multiprocessing, you should not pass
-          non-picklable arguments to the generator,
-          as they can't be passed easily
-          to child processes.
- shuffle: Whether to shuffle the order of the batches at
- the beginning of each epoch. Only used with instances
- of `Sequence` (keras.utils.Sequence).
- initial_epoch: Epoch at which to start training
- (useful for resuming a previous training run)
- **kwargs: support for legacy arguments.
-
- Returns:
- A `History` object.
-
- Raises:
- RuntimeError: if the model was never compiled.
- ValueError: In case the generator yields
- data in an invalid format.
-
- Example:
-
- ```python
-      def generate_arrays_from_file(path):
-        while True:
-          with open(path) as f:
-            for line in f:
-              # Create Numpy arrays of input data
-              # and labels from each line in the file.
-              x, y = process_line(line)
-              yield (x, y)
-
- model.fit_generator(generate_arrays_from_file('/my_file.txt'),
- steps_per_epoch=1000, epochs=10)
- ```
- """
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.fit_generator(
- generator,
- steps_per_epoch,
- epochs,
- verbose=verbose,
- callbacks=callbacks,
- validation_data=validation_data,
- validation_steps=validation_steps,
- class_weight=class_weight,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- shuffle=shuffle,
- initial_epoch=initial_epoch)
-
- def evaluate_generator(self,
- generator,
- steps=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- **kwargs):
- """Evaluates the model on a data generator.
-
- The generator should return the same kind of data
- as accepted by `test_on_batch`.
-
- Arguments:
- generator: Generator yielding tuples (inputs, targets)
- or (inputs, targets, sample_weights)
- steps: Total number of steps (batches of samples)
- to yield from `generator` before stopping.
- Optional for `Sequence`: if unspecified, will use
-          `len(generator)` as the number of steps.
-      max_queue_size: maximum size for the generator queue.
-      workers: maximum number of processes to spin up.
-      use_multiprocessing: if True, use process-based threading.
-        Note that because this implementation
-        relies on multiprocessing, you should not pass
-        non-picklable arguments to the generator,
-        as they can't be passed easily to child processes.
- **kwargs: support for legacy arguments.
-
- Returns:
- Scalar test loss (if the model has no metrics)
- or list of scalars (if the model computes other metrics).
- The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
-
- Raises:
- RuntimeError: if the model was never compiled.
- ValueError: In case the generator yields
- data in an invalid format.
- """
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
- if not self.built:
- raise RuntimeError('The model needs to be compiled before being used.')
- return self.model.evaluate_generator(
- generator,
- steps,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing)
-
- def predict_generator(self,
- generator,
- steps=None,
- max_queue_size=10,
- workers=1,
- use_multiprocessing=False,
- verbose=0,
- **kwargs):
- """Generates predictions for the input samples from a data generator.
-
- The generator should return the same kind of data as accepted by
- `predict_on_batch`.
-
- Arguments:
- generator: generator yielding batches of input samples.
- steps: Total number of steps (batches of samples)
- to yield from `generator` before stopping.
- Optional for `Sequence`: if unspecified, will use
-          `len(generator)` as the number of steps.
-      max_queue_size: maximum size for the generator queue.
-      workers: maximum number of processes to spin up.
-      use_multiprocessing: if True, use process-based threading.
-        Note that because this implementation
-        relies on multiprocessing, you should not pass
-        non-picklable arguments to the generator,
-        as they can't be passed easily to child processes.
- verbose: verbosity mode, 0 or 1.
- **kwargs: support for legacy arguments.
-
- Returns:
- A Numpy array of predictions.
-
- Raises:
- ValueError: In case the generator yields
- data in an invalid format.
- """
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
- if not self.built:
- self.build()
- return self.model.predict_generator(
- generator,
- steps,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- verbose=verbose)
-
-  def get_config(self):
- config = []
- for layer in self.layers:
- config.append({
- 'class_name': layer.__class__.__name__,
- 'config': layer.get_config()
- })
- return copy.deepcopy(config)
-
-  @classmethod
- def from_config(cls, config, custom_objects=None):
- model = cls()
- for conf in config:
- layer = layer_module.deserialize(conf, custom_objects=custom_objects)
- model.add(layer)
- return model
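`get_config`/`from_config` round-trip a `Sequential` model through a plain list of per-layer config dicts (the list form that `model_from_config` above explicitly rejects). For example:

```python
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(8,)))

config = model.get_config()  # [{'class_name': 'Dense', 'config': {...}}]
clone = keras.models.Sequential.from_config(config)  # Same architecture, new weights.
```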
+# API entries importable from `keras.models`:
+Model = training.Model # pylint: disable=invalid-name
+Sequential = sequential.Sequential # pylint: disable=invalid-name
+save_model = saving.save_model
+load_model = saving.load_model
+model_from_config = saving.model_from_config
+model_from_yaml = saving.model_from_yaml
+model_from_json = saving.model_from_json
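After this refactor the public surface of `keras.models` is unchanged; the entries are simply re-exported from the new `saving`, `sequential`, and `training` modules. Existing call sites keep working, e.g. (file path illustrative):

```python
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras import models

model = models.Sequential()
model.add(layers.Dense(2, input_shape=(3,)))
model.compile(optimizer='sgd', loss='mse')

models.save_model(model, '/tmp/model.h5')   # Now lives in engine/saving.py.
restored = models.load_model('/tmp/model.h5')
```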
def _clone_functional_model(model, input_tensors=None):
@@ -1365,7 +90,7 @@ def _clone_functional_model(model, input_tensors=None):
else:
# Make sure that all input tensors come from a Keras layer.
# If tensor comes from an input layer: cache the input layer.
- input_tensors = topology.to_list(input_tensors)
+ input_tensors = generic_utils.to_list(input_tensors)
input_tensors_ = []
for i, x in enumerate(input_tensors):
if not K.is_keras_tensor(x):
@@ -1402,7 +127,7 @@ def _clone_functional_model(model, input_tensors=None):
# Reuse previously cloned layer.
layer = layer_map[layer]
# Don't call InputLayer multiple times.
- if isinstance(layer, topology.InputLayer):
+ if isinstance(layer, InputLayer):
continue
# Gather inputs to call the new layer.
@@ -1427,8 +152,9 @@ def _clone_functional_model(model, input_tensors=None):
if has_arg(layer.call, 'mask'):
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
- output_tensors = topology.to_list(layer(computed_tensor, **kwargs))
- output_masks = topology.to_list(
+ output_tensors = generic_utils.to_list(layer(computed_tensor,
+ **kwargs))
+ output_masks = generic_utils.to_list(
layer.compute_mask(computed_tensor, computed_mask))
computed_tensors = [computed_tensor]
computed_masks = [computed_mask]
@@ -1438,8 +164,9 @@ def _clone_functional_model(model, input_tensors=None):
if has_arg(layer.call, 'mask'):
if 'mask' not in kwargs:
kwargs['mask'] = computed_masks
- output_tensors = topology.to_list(layer(computed_tensors, **kwargs))
- output_masks = topology.to_list(
+ output_tensors = generic_utils.to_list(layer(computed_tensors,
+ **kwargs))
+ output_masks = generic_utils.to_list(
layer.compute_mask(computed_tensors, computed_masks))
# Update tensor_map.
for x, y, mask in zip(reference_output_tensors, output_tensors,
@@ -1489,14 +216,14 @@ def _clone_sequential_model(model, input_tensors=None):
if input_tensors is None:
return Sequential(layers=layers, name=model.name)
else:
- if len(topology.to_list(input_tensors)) != 1:
+ if len(generic_utils.to_list(input_tensors)) != 1:
raise ValueError('To clone a `Sequential` model, we expect '
                     'at most one tensor '
'as part of `input_tensors`.')
- x = topology.to_list(input_tensors)[0]
+ x = generic_utils.to_list(input_tensors)[0]
if K.is_keras_tensor(x):
origin_layer = x._keras_history[0]
- if isinstance(origin_layer, topology.InputLayer):
+ if isinstance(origin_layer, InputLayer):
return Sequential(layers=[origin_layer] + layers, name=model.name)
else:
raise ValueError('Cannot clone a `Sequential` model on top '
diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py
index 04017e4b28..5978ddd987 100644
--- a/tensorflow/python/keras/_impl/keras/models_test.py
+++ b/tensorflow/python/keras/_impl/keras/models_test.py
@@ -12,362 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for training routines."""
+"""Tests for `models.py` (model cloning, mainly)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-import shutil
-import tempfile
-
import numpy as np
from tensorflow.python.keras._impl import keras
from tensorflow.python.platform import test
-from tensorflow.python.training import training as training_module
-
-try:
- import h5py # pylint:disable=g-import-not-at-top
-except ImportError:
- h5py = None
-
-
-class TestModelSaving(test.TestCase):
-
- def test_sequential_model_saving(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model.compile(loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(lr=0.0001),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
-
- out = model.predict(x)
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
-
- new_model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- out2 = new_model.predict(x)
- self.assertAllClose(out, out2, atol=1e-05)
-
- # test that new updates are the same with both models
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- new_model.train_on_batch(x, y)
- out = model.predict(x)
- out2 = new_model.predict(x)
- self.assertAllClose(out, out2, atol=1e-05)
-
- def test_sequential_model_saving_2(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- # test with custom optimizer, loss
-
- class CustomOp(keras.optimizers.RMSprop):
- pass
-
- def custom_loss(y_true, y_pred):
- return keras.losses.mse(y_true, y_pred)
-
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.Dense(3))
- model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
-
- x = np.random.random((1, 3))
- y = np.random.random((1, 3))
- model.train_on_batch(x, y)
-
- out = model.predict(x)
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
-
- model = keras.models.load_model(
- fname,
- custom_objects={'CustomOp': CustomOp,
- 'custom_loss': custom_loss})
- os.close(fd)
- os.remove(fname)
-
- out2 = model.predict(x)
- self.assertAllClose(out, out2, atol=1e-05)
-
- def test_functional_model_saving(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- inputs = keras.layers.Input(shape=(3,))
- x = keras.layers.Dense(2)(inputs)
- output = keras.layers.Dense(3)(x)
-
- model = keras.models.Model(inputs, output)
- model.compile(loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(lr=0.0001),
- metrics=[keras.metrics.categorical_accuracy])
- x = np.random.random((1, 3))
- y = np.random.random((1, 3))
- model.train_on_batch(x, y)
-
- out = model.predict(x)
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
-
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- out2 = model.predict(x)
- self.assertAllClose(out, out2, atol=1e-05)
-
- def test_saving_without_compilation(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.Dense(3))
- model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
-
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- def test_saving_with_tf_optimizer(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.Dense(3))
- model.compile(loss='mse',
- optimizer=training_module.AdadeltaOptimizer(0.1),
- metrics=['acc'])
-
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- def test_saving_right_after_compilation(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.Dense(3))
- model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
- model.model._make_train_function()
-
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- def test_saving_lambda_numpy_array_arguments(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
-
- mean = np.random.random((4, 2, 3))
- std = np.abs(np.random.random((4, 2, 3))) + 1e-5
- inputs = keras.layers.Input(shape=(4, 2, 3))
- output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
- arguments={'mu': mean, 'std': std})(inputs)
- model = keras.models.Model(inputs, output)
- model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
-
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
-
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
-
- self.assertAllClose(mean, model.layers[1].arguments['mu'])
- self.assertAllClose(std, model.layers[1].arguments['std'])
-
-
-class TestSequential(test.TestCase):
- """Most Sequential model API tests are covered in `training_test.py`.
- """
-
- def test_basic_methods(self):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_dim=2))
- model.add(keras.layers.Dropout(0.3, name='dp'))
- model.add(keras.layers.Dense(2, kernel_regularizer='l2',
- kernel_constraint='max_norm'))
- model.build()
- self.assertEqual(model.state_updates, model.model.state_updates)
- self.assertEqual(model.get_layer(name='dp').name, 'dp')
-
- def test_sequential_pop(self):
- num_hidden = 5
- input_dim = 3
- batch_size = 5
- num_classes = 2
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
- model.add(keras.layers.Dense(num_classes))
- model.compile(loss='mse', optimizer='sgd')
- x = np.random.random((batch_size, input_dim))
- y = np.random.random((batch_size, num_classes))
- model.fit(x, y, epochs=1)
- model.pop()
- self.assertEqual(len(model.layers), 1)
- self.assertEqual(model.output_shape, (None, num_hidden))
- model.compile(loss='mse', optimizer='sgd')
- y = np.random.random((batch_size, num_hidden))
- model.fit(x, y, epochs=1)
-
- # Test popping single-layer model
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
- model.pop()
- self.assertEqual(len(model.layers), 0)
- self.assertEqual(len(model.outputs), 0)
-
- # Invalid use case
- model = keras.models.Sequential()
- with self.assertRaises(TypeError):
- model.pop()
-
- def test_sequential_weight_loading(self):
- if h5py is None:
- return
-
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- h5_path = os.path.join(temp_dir, 'test.h5')
-
- num_hidden = 5
- input_dim = 3
- batch_size = 5
- num_classes = 2
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
- model.add(keras.layers.Dense(num_classes))
-
- x = np.random.random((batch_size, input_dim))
- ref_y = model.predict(x)
-
- model.save_weights(h5_path)
-
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
- model.add(keras.layers.Dense(num_classes))
- model.load_weights(h5_path)
- y = model.predict(x)
-
- self.assertAllClose(y, ref_y)
-
- def test_invalid_use_cases(self):
- with self.test_session():
- # Added objects must be layer instances
- with self.assertRaises(TypeError):
- model = keras.models.Sequential()
- model.add(None)
-
-      # Added layers must have an input shape
- with self.assertRaises(ValueError):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1))
-
- # Added layers cannot have multiple outputs
- class MyLayer(keras.layers.Layer):
-
- def call(self, inputs):
- return [3 * inputs, 2 * inputs]
-
- def compute_output_shape(self, input_shape):
- return [input_shape, input_shape]
-
- with self.assertRaises(ValueError):
- model = keras.models.Sequential()
- model.add(MyLayer(input_shape=(3,)))
- with self.assertRaises(TypeError):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_dim=1))
- model.add(MyLayer())
-
- # Building empty model
- model = keras.models.Sequential()
- with self.assertRaises(TypeError):
- model.build()
-
- def test_nested_sequential_trainability(self):
- input_dim = 20
- num_units = 10
- num_classes = 2
-
- inner_model = keras.models.Sequential()
- inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,)))
-
- model = keras.models.Sequential()
- model.add(inner_model)
- model.add(keras.layers.Dense(num_classes))
-
- self.assertEqual(len(model.trainable_weights), 4)
- inner_model.trainable = False
- self.assertEqual(len(model.trainable_weights), 2)
- inner_model.trainable = True
- self.assertEqual(len(model.trainable_weights), 4)
-
- def test_sequential_update_disabling(self):
- val_a = np.random.random((10, 4))
- val_out = np.random.random((10, 4))
-
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.BatchNormalization(input_shape=(4,)))
-
- model.trainable = False
- assert not model.updates
-
- model.compile('sgd', 'mse')
- assert not model.updates
- assert not model.model.updates
-
- x1 = model.predict(val_a)
- model.train_on_batch(val_a, val_out)
- x2 = model.predict(val_a)
- self.assertAllClose(x1, x2, atol=1e-7)
-
- model.trainable = True
- model.compile('sgd', 'mse')
- assert model.updates
- assert model.model.updates
-
- model.train_on_batch(val_a, val_out)
- x2 = model.predict(val_a)
- assert np.abs(np.sum(x1 - x2)) > 1e-5
class TestModelCloning(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index 462d600bf8..5196bf1740 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -509,3 +509,20 @@ def slice_arrays(arrays, start=None, stop=None):
return arrays[start:stop]
else:
return [None]
+
+
+def to_list(x):
+ """Normalizes a list/tensor into a list.
+
+ If a tensor is passed, we return
+ a list of size 1 containing the tensor.
+
+ Arguments:
+ x: target object to be normalized.
+
+ Returns:
+ A list.
+ """
+ if isinstance(x, list):
+ return x
+ return [x]
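A quick illustration of the contract (standalone re-definition so the snippet runs anywhere):

```python
def to_list(x):
  if isinstance(x, list):
    return x
  return [x]

assert to_list([1, 2]) == [1, 2]        # Lists pass through unchanged.
assert to_list('tensor') == ['tensor']  # Anything else is wrapped.
```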
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index fedbf9e696..5e8937ad2c 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -326,6 +326,18 @@ class DynamicPartitionTest(test.TestCase):
with self.assertRaises(ValueError):
data_flow_ops.dynamic_partition(data, indices, num_partitions=4)
+ # see https://github.com/tensorflow/tensorflow/issues/17106
+ def testCUBBug(self):
+ x = constant_op.constant(np.random.randn(3072))
+ inds = [0]*189 + [1]*184 + [2]*184 + [3]*191 + [4]*192 + [5]*195 + [6]*195
+ inds += [7]*195 + [8]*188 + [9]*195 + [10]*188 + [11]*202 + [12]*194
+ inds += [13]*194 + [14]*194 + [15]*192
+ self.assertEqual(len(inds), x.shape[0])
+ partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
+ with self.test_session() as sess:
+ res = sess.run(partitioned)
+ self.assertEqual(res[-1].shape[0], 192)
+
if __name__ == "__main__":
test.main()
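For context, `dynamic_partition` scatters entries of `data` into `num_partitions` output tensors according to `partitions`; the regression test above checks the size of the last partition on a large input (the GPU path uses CUB). A small example of the semantics:

```python
import tensorflow as tf

data = tf.constant([10., 20., 30., 40.])
partitions = [0, 1, 1, 0]
out = tf.dynamic_partition(data, partitions, num_partitions=2)

with tf.Session() as sess:
  print(sess.run(out))  # [array([10., 40.], ...), array([20., 30.], ...)]
```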
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 14824962ea..96f5f81c1f 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -134,7 +134,10 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
input = ops.convert_to_tensor(input)
in_device = input.device
# TODO(ashankar): Does 'identity' need to invoke execution callbacks?
- if context.context().device_name != in_device:
+ context_device = context.context().device_name
+ if not context_device:
+ context_device = "/job:localhost/replica:0/task:0/device:CPU:0"
+ if context_device != in_device:
return input._copy() # pylint: disable=protected-access
return input
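A sketch of the new guard: an unset (empty) eager-context device name is treated as the local CPU before comparing against the input tensor's device, so `identity` no longer spuriously copies tensors already on the default device. The helper name below is hypothetical; the default string is taken from the patch:

```python
_DEFAULT_DEVICE = "/job:localhost/replica:0/task:0/device:CPU:0"

def _identity_needs_copy(context_device, input_device):
  # An empty context device name falls back to the local CPU.
  return (context_device or _DEFAULT_DEVICE) != input_device

assert not _identity_needs_copy("", _DEFAULT_DEVICE)
```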
@@ -401,8 +404,11 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
else:
input_tensor = ops.convert_to_tensor(input)
input_shape = input_tensor.get_shape()
- if optimize and input_shape.is_fully_defined():
- return constant(input_shape.num_elements(), out_type, name=name)
+ if optimize:
+ if input_shape.is_fully_defined():
+ return constant(input_shape.num_elements(), out_type, name=name)
+ if input_shape.dims and any(dim == 0 for dim in input_shape.dims):
+ return constant(0, out_type, name=name)
return gen_array_ops.size(input, name=name, out_type=out_type)
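With this change, `tf.size` constant-folds to 0 whenever any statically known dimension is zero, even when the full shape is not defined. For example (graph mode):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[0, None])  # Second dim unknown.
s = tf.size(x)
# `s` is now a constant 0 rather than a runtime `Size` op.
```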
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9e7f37d80f..69afa618e2 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -35,6 +35,12 @@ def _safe_shape_div(x, y):
return x // math_ops.maximum(y, 1)
+@ops.RegisterGradient("ArgMax")
+def _ArgMaxGrad(op, grad):
+ del op, grad
+ return [None, None]
+
+
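Registering a `[None, None]` gradient marks `ArgMax` as non-differentiable with respect to both of its inputs, so differentiating through it now yields `None` instead of raising `LookupError`. For instance:

```python
import tensorflow as tf

x = tf.constant([1.0, 3.0, 2.0])
idx = tf.argmax(x)
print(tf.gradients(idx, x))  # [None] -- previously: LookupError, no gradient defined.
```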
@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
"""Gradient for Sum."""
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 678d6322aa..454cc3add5 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -98,6 +98,9 @@ class _RefVariableProcessor(_OptimizableVariable):
def __init__(self, v):
self._v = v
+ def __str__(self):
+ return "<_RefVariableProcessor(%s)>" % self._v
+
def target(self):
return self._v._ref() # pylint: disable=protected-access
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 3888e9bba4..9afd1e6643 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -50,6 +50,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
@@ -196,8 +197,8 @@ class BaseSaverBuilder(object):
# Copy the restored tensor to the variable's device.
with ops.device(self._var_device):
restored_tensor = array_ops.identity(restored_tensor)
- return resource_variable_ops.shape_safe_assign_variable_handle(
- self.handle_op, self._var_shape, restored_tensor)
+ return resource_variable_ops.shape_safe_assign_variable_handle(
+ self.handle_op, self._var_shape, restored_tensor)
def __init__(self, write_version=saver_pb2.SaverDef.V2):
self._write_version = write_version
@@ -577,6 +578,11 @@ class BaseSaverBuilder(object):
names_to_saveables[name].append(var)
else:
names_to_saveables[name] = [var]
+ elif (isinstance(var, checkpointable.CheckpointableBase)
+ and not isinstance(var, variables.Variable)):
+ names_to_saveables.update(
+ BaseSaverBuilder.OpListToDict(
+ list(var._gather_saveables_for_checkpoint().values())))
else:
if context.in_graph_mode():
if convert_variable_to_tensor:
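(The new elif branch lets BaseSaverBuilder.OpListToDict flatten checkpointable objects that are not themselves Variables into the name-to-saveable map, so such objects can be passed straight to tf.train.Saver; the tests added below exercise this end to end. A minimal hedged sketch of such an object — the class and variable names here are illustrative, not from the commit.)

import tensorflow as tf
from tensorflow.python.training import checkpointable

class Stash(checkpointable.CheckpointableBase):
  """Owns one variable and exposes it to the Saver."""

  def __init__(self):
    self.value = tf.get_variable(
        "stash_value", initializer=3., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    return {checkpointable.VARIABLE_VALUE_KEY: self.value}

  @property
  def name(self):  # The Saver sorts var_list entries by name.
    return self.value.name

saver = tf.train.Saver(var_list=[Stash()])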
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index c5a6f49df5..f00f98db00 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -66,6 +66,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
+from tensorflow.python.training import checkpointable
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
@@ -2660,5 +2661,92 @@ class ScopedGraphTest(test.TestCase):
self.assertEqual(2.0, var_dict2["variable2:0"].eval())
+class _OwnsAVariableSimple(checkpointable.CheckpointableBase):
+ """A Checkpointable object which can be saved using a tf.train.Saver."""
+
+ def __init__(self):
+ self.non_dep_variable = variable_scope.get_variable(
+ name="non_dep_variable", initializer=6., use_resource=True)
+
+ def _gather_saveables_for_checkpoint(self):
+ return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable}
+
+ # The Saver sorts by name before parsing, so we need a name property.
+ @property
+ def name(self):
+ return self.non_dep_variable.name
+
+
+class _MirroringSaveable(
+ saver_module.BaseSaverBuilder.ResourceVariableSaveable):
+
+ def __init__(self, primary_variable, mirrored_variable):
+ self._primary_variable = primary_variable
+ self._mirrored_variable = mirrored_variable
+ super(_MirroringSaveable, self).__init__(
+ self._primary_variable, "", self._primary_variable.name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group(
+ self._primary_variable.assign(tensor),
+ self._mirrored_variable.assign(tensor))
+
+
+class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
+ """A Checkpointable object which returns a more complex SaveableObject."""
+
+ def __init__(self):
+ self.non_dep_variable = variable_scope.get_variable(
+ name="non_dep_variable", initializer=6., use_resource=True)
+ self.mirrored = variable_scope.get_variable(
+ name="mirrored", initializer=15., use_resource=True)
+
+ def _gather_saveables_for_checkpoint(self):
+ saveable = _MirroringSaveable(
+ primary_variable=self.non_dep_variable,
+ mirrored_variable=self.mirrored)
+ return {checkpointable.VARIABLE_VALUE_KEY: saveable}
+
+ # The Saver sorts by name before parsing, so we need a name property.
+ @property
+ def name(self):
+ return self.non_dep_variable.name
+
+
+@test_util.with_c_api
+class CheckpointableCompatibilityTests(test.TestCase):
+
+ # TODO(allenl): Track down python3 reference cycles in these tests.
+ @test_util.run_in_graph_and_eager_modes()
+ def testNotSaveableButIsCheckpointable(self):
+ v = _OwnsAVariableSimple()
+ saver = saver_module.Saver(var_list=[v])
+ test_dir = self.get_temp_dir()
+ prefix = os.path.join(test_dir, "ckpt")
+ self.evaluate(v.non_dep_variable.assign(42.))
+ with self.test_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testMoreComplexSaveableReturned(self):
+ v = _OwnsMirroredVariables()
+ saver = saver_module.Saver(var_list=[v])
+ test_dir = self.get_temp_dir()
+ prefix = os.path.join(test_dir, "ckpt")
+ self.evaluate(v.non_dep_variable.assign(42.))
+ with self.test_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ self.evaluate(v.mirrored.assign(44.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ self.assertEqual(42., self.evaluate(v.mirrored))
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 04724e3a1a..241db8956a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -1,8 +1,8 @@
path: "tensorflow.keras.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Network\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.network.Network\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index c94bd2faa4..9673a508d6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -1,9 +1,9 @@
path: "tensorflow.keras.Sequential"
tf_class {
- is_instance: "<class \'tensorflow.python.keras._impl.keras.models.Sequential\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Network\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.network.Network\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
index f4ab075959..041acf29ff 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Activation"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Activation\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
index eb558cddaf..48143b2cd6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ActivityRegularization"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.ActivityRegularization\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
index 770a107b66..11f78fed97 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Add"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Add\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 0ce42b706e..84eb825632 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.AlphaDropout"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.noise.AlphaDropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index d6c98fa225..ab377a248f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 754fd310c6..c2edd79f52 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index 9b62880c79..f3f37eed99 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
index b371ad148c..31d1d1c049 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Average"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Average\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index 3e2aba55fd..6582e1b18e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index fb37308cce..12f66095d2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index 813470ffc7..3a45fa180e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.AveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.AveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
index e251ac18e5..a0f272c178 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.BatchNormalization"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.normalization.BatchNormalization\'>"
is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 699208a0b9..9c7d3154ad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Bidirectional"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.wrappers.Bidirectional\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.wrappers.Wrapper\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
index ff08def0a0..949b225e54 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Concatenate"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Concatenate\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index 6db22ca032..a736c84a10 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvLSTM2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvRecurrent2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
index 577f206e35..95f9afed28 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 72924c32b4..38ba15400a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
index 16be08d9b2..bc84e2a97e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index 11e05f884d..0802578c22 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
index 72b72d6b3b..8ad4646c74 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
index ee93247f63..110e267b75 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index e5023287e5..24cfc83af6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
index ba38cb7121..c56e89187f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 58724a1e16..3674f2746c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
index 98d52c430c..5a8f9d7702 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
index 33b6ebe1af..caa748be81 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Cropping1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Cropping1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
index 4b241ebb0f..97bd4a265a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Cropping2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Cropping2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
index 1856a9ee21..20c43eeed1 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Cropping3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Cropping3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
index a8c37af31f..256f0e4bdf 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dense"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Dense\'>"
is_instance: "<class \'tensorflow.python.layers.core.Dense\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
index 07d3f023e5..d1e53f900c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dot"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Dot\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
index e2e21b5f12..b010ff6805 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dropout"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Dropout\'>"
is_instance: "<class \'tensorflow.python.layers.core.Dropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
index 92b9760d53..fffd3854bb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ELU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.ELU\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
index 83c528b401..1155fe03fc 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Embedding"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.embeddings.Embedding\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
index 7360975288..5e4bebb15b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Flatten"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index b329f1c46b..cb9bb3d821 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.GRUCell"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRUCell\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index c741d4d6e6..9a36e80649 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GRU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRU\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index 57596badf1..eb32238e15 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.GaussianDropout"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.noise.GaussianDropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 3829353cc3..37fc8e29ae 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.GaussianNoise"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.noise.GaussianNoise\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index e53e78a977..490816458b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index 48fcd1044e..ab49f67f33 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index 66c06ed472..3d7cb3ba49 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 4f2420f74a..c99ddab4f3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling1D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index 7912a6d933..290d2eaebe 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index d5b2d2c274..cf63069641 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalAveragePooling3D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index d88ff17eb6..2dadc67c09 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index c8cc5a0ddf..1a1a1dcf64 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index 7956c5a340..44898e23ad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 0a7e16413d..941d867d24 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index 6c8a58a996..9a5a6325f8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index 7678ce8aab..7a0c1932f6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.GlobalMaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling._GlobalPooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
index 1e9370b02f..f679c1d006 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.InputLayer"
tf_class {
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.InputLayer\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.input_layer.InputLayer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 3b171b137a..ad1e7f2cad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.LSTMCell"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTMCell\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 29d9cf78ab..6dad4b4897 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.LSTM"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTM\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
index ca01449299..fa45d8c902 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Lambda"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Lambda\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
index c52ad72754..023d6c0d69 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.Layer"
tf_class {
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 8134fb7386..e429fced77 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.LeakyReLU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.LeakyReLU\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index c5d4523009..462568124f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.LocallyConnected1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.local.LocallyConnected1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index bcbed9241b..11bf6a2b42 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.LocallyConnected2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.local.LocallyConnected2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
index 244e79b4ff..a932448891 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Masking"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Masking\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 56cbf5df78..6ff2adddac 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
index 33c2d30e86..2957673d4d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
index 94f91059b7..2191c10b73 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 247230a6d6..af750ac1b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling1D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index 8d61b67e7c..9046061510 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling2D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index ad2e308020..a40666807b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.pooling.MaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling.MaxPooling3D\'>"
is_instance: "<class \'tensorflow.python.layers.pooling._Pooling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
index ff0db15f19..65378cef42 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Maximum"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Maximum\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
index 1d3f33f045..b037559e02 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Multiply"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge.Multiply\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.merge._Merge\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
index c86bc49b22..b3a7f47fa5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.PReLU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.PReLU\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
index 2043e1a126..b2f22f7da3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Permute"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Permute\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
index ad539a7c4c..792eacf90d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.RNN"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
index 4b0e98520a..5b79a021ca 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.RepeatVector"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.RepeatVector\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
index 34bc71af8a..99c64505ee 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Reshape"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Reshape\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index dd67b76523..d5873ccf76 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index 5d898fb2bd..76b4c10a46 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index bf62c095e7..40cd87de5f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index c758d87993..c44c0da148 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 6e3cde3e3e..bd70c31c38 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.SimpleRNNCell"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNNCell\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index 6fafc77b94..de717976cf 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SimpleRNN"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNN\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
index ee4b2fa39e..a93b7b8f6e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Softmax"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.Softmax\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index e4727072e3..4dc24b195e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.SpatialDropout1D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Dropout\'>"
is_instance: "<class \'tensorflow.python.layers.core.Dropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index c5ff704311..a3bb1cc414 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.SpatialDropout2D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Dropout\'>"
is_instance: "<class \'tensorflow.python.layers.core.Dropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index 476a7f362c..f9a78106fa 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.SpatialDropout3D\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Dropout\'>"
is_instance: "<class \'tensorflow.python.layers.core.Dropout\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 3dde1e5769..5aa21f4022 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.StackedRNNCells"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.StackedRNNCells\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index ef31c5443e..88e8a46572 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ThresholdedReLU"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.ThresholdedReLU\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
index 1e176d8d4b..f2a7673998 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.TimeDistributed"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.wrappers.TimeDistributed\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.wrappers.Wrapper\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index a81b83be49..4db82ddfa9 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.UpSampling1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.UpSampling1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index 5403279d45..61e65ad56d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.UpSampling2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.UpSampling2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 96c337caf2..3d9402db4e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.UpSampling3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.UpSampling3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
index ea3bb2f8f5..0223799ed4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.Wrapper"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.wrappers.Wrapper\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index b81a4b1c50..2e4429833a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ZeroPadding1D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.ZeroPadding1D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index 1a26f2f3c9..26cf7b9e49 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ZeroPadding2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.ZeroPadding2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 310277fe67..64d35d9447 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.ZeroPadding3D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.ZeroPadding3D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index 88eb237cec..18be9c9701 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -1,8 +1,8 @@
path: "tensorflow.keras.models.Model"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Network\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.network.Network\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 34f10f01ad..b934632922 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -1,9 +1,9 @@
path: "tensorflow.keras.models.Sequential"
tf_class {
- is_instance: "<class \'tensorflow.python.keras._impl.keras.models.Sequential\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.sequential.Sequential\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Network\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.network.Network\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
member {
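
All of the golden-file churn above is mechanical fallout from a single refactor: tensorflow/python/keras/_impl/keras/engine/topology.py was split up, so Layer now lives in engine/base_layer.py, Network in engine/network.py, and Sequential in engine/sequential.py, while the public names (tf.keras.layers.Layer, tf.keras.models.Model, tf.keras.models.Sequential, ...) are unchanged. The goldens record one is_instance entry per class in a public class's method-resolution order, which is why exactly the topology.* lines change in every file. A minimal sketch of where those entries come from — not the actual goldens generator, just an illustration, and it assumes a TensorFlow build that already contains this refactor:

    # Hypothetical illustration, not the real tensorflow/tools/api generator:
    # emit one is_instance entry per class in the MRO, golden-file style.
    import tensorflow as tf

    for cls in tf.keras.layers.Layer.__mro__:
        # repr(cls) gives "<class 'tensorflow.python.keras._impl...Layer'>",
        # which is exactly the string the .pbtxt goldens pin down.
        print('is_instance: "%r"' % (cls,))

Because the goldens pin private module paths via the MRO, a pure code move inside keras._impl produces this kind of bulk update even though no user-visible API changed.
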
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d5c61baa8b..85f423f236 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -667,15 +667,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "cub_archive",
urls = [
- "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip",
- "https://github.com/NVlabs/cub/archive/1.7.4.zip",
+ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
+ "https://github.com/NVlabs/cub/archive/1.8.0.zip",
],
- sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31",
- strip_prefix = "cub-1.7.4",
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
build_file = str(Label("//third_party:cub.BUILD")),
- # TODO: remove the patch when upstream fix is accepted and released.
- # PR with a fix: https://github.com/NVlabs/cub/pull/125
- patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")),
)
tf_http_archive(
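
The workspace.bzl hunk above bumps cub from 1.7.4 to 1.8.0. Note that urls, sha256 and strip_prefix move together, and patch_file goes away entirely: per the TODO being deleted, the clang fix that the patch carried (https://github.com/NVlabs/cub/pull/125) made it into the 1.8.0 release, so the local workaround (removed below) is no longer needed. For orientation, a rough Python sketch of what a tf_http_archive-style fetch effectively does — a hypothetical helper, not Bazel's actual implementation:

    import hashlib
    import io
    import os
    import urllib.request
    import zipfile

    def fetch_archive(urls, sha256, strip_prefix, dest="."):
        """Download from the first reachable mirror, verify, and unpack."""
        data = None
        for url in urls:  # mirror.bazel.build is listed before upstream GitHub
            try:
                data = urllib.request.urlopen(url).read()
                break
            except OSError:
                continue
        if data is None:
            raise RuntimeError("all mirrors failed: %s" % (urls,))
        digest = hashlib.sha256(data).hexdigest()
        if digest != sha256:
            raise RuntimeError("sha256 mismatch: expected %s, got %s"
                               % (sha256, digest))
        prefix = strip_prefix.rstrip("/") + "/"
        with zipfile.ZipFile(io.BytesIO(data)) as zf:
            for member in zf.namelist():
                # skip directories and anything outside the pinned prefix
                if not member.startswith(prefix) or member.endswith("/"):
                    continue
                # drop the leading "cub-1.8.0/" directory, as strip_prefix does
                target = os.path.join(dest, member[len(prefix):])
                os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
                with open(target, "wb") as out:
                    out.write(zf.read(member))

    fetch_archive(
        urls=[
            "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
            "https://github.com/NVlabs/cub/archive/1.8.0.zip",
        ],
        sha256="6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
        strip_prefix="cub-1.8.0",
        dest="external_cub_archive",
    )

Verifying the checksum against the bytes actually downloaded is what makes the mirror list safe: either source may serve the archive, but both must hash to the pinned value — which is also why sha256 has to change in lockstep with every version bump.
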
diff --git a/third_party/cub/BUILD b/third_party/cub/BUILD
deleted file mode 100644
index e69de29bb2..0000000000
--- a/third_party/cub/BUILD
+++ /dev/null
diff --git a/third_party/cub/fix_compilation_in_clang.patch b/third_party/cub/fix_compilation_in_clang.patch
deleted file mode 100644
index 384e674f20..0000000000
--- a/third_party/cub/fix_compilation_in_clang.patch
+++ /dev/null
@@ -1,23 +0,0 @@
-From 565b77f7c82048871a4d5e3e506dc663d53cd469 Mon Sep 17 00:00:00 2001
-From: Ilya Biryukov <ibiryukov@google.com>
-Date: Fri, 26 Jan 2018 18:46:06 +0100
-Subject: [PATCH] Added missing 'template' keyword.
-
-To unbreak compilation with clang.
----
- cub/device/dispatch/dispatch_radix_sort.cuh | 2 +-
- 1 file changed, 1 insertion(+), 1 deletion(-)
-
-diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh
-index 7fbc621f..f622e212 100644
---- a/cub/device/dispatch/dispatch_radix_sort.cuh
-+++ b/cub/device/dispatch/dispatch_radix_sort.cuh
-@@ -104,7 +104,7 @@ __global__ void DeviceRadixSortUpsweepKernel(
- CTA_SYNC();
-
- // Write out digit counts (striped)
-- upsweep.ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
-+ upsweep.template ExtractCounts<IS_DESCENDING>(d_spine, gridDim.x, blockIdx.x);
- }
-
-