diff options
-rw-r--r-- | tensorflow/contrib/ffmpeg/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/ffmpeg/decode_audio_op.cc | 230 | ||||
-rw-r--r-- | tensorflow/contrib/ffmpeg/decode_audio_op_test.py | 89 | ||||
-rw-r--r-- | tensorflow/contrib/ffmpeg/encode_audio_op.cc | 149 | ||||
-rw-r--r-- | tensorflow/contrib/ffmpeg/encode_audio_op_test.py | 59 | ||||
-rw-r--r-- | tensorflow/contrib/ffmpeg/ffmpeg_ops.py | 49 |
6 files changed, 477 insertions, 103 deletions
diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index e495ab4880..15224fbbcb 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -87,6 +87,8 @@ tf_py_test( srcs = ["decode_audio_op_test.py"], additional_deps = [ ":ffmpeg_ops_py", + "@six_archive//:six", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", ], @@ -102,6 +104,8 @@ tf_py_test( srcs = ["encode_audio_op_test.py"], additional_deps = [ ":ffmpeg_ops_py", + "@six_archive//:six", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", ], diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index a6a945094b..4b1c8a337e 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -60,8 +60,175 @@ class FileDeleter { const string filename_; }; +/* + * Decoding implementation, shared across V1 and V2 ops. Creates a new + * output in the context. + */ +void Decode(OpKernelContext* context, + const tensorflow::StringPiece& file_contents, + const string& file_format, const int32 samples_per_second, + const int32 channel_count) { + // Write the input data to a temp file. + const string temp_filename = GetTempFilename(file_format); + OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); + FileDeleter deleter(temp_filename); + + // Run FFmpeg on the data and verify results. + std::vector<float> output_samples; + Status result = + ffmpeg::ReadAudioFile(temp_filename, file_format, samples_per_second, + channel_count, &output_samples); + if (result.code() == error::Code::NOT_FOUND) { + OP_REQUIRES( + context, result.ok(), + errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " + "can be found at http://www.ffmpeg.org.")); + } else if (result.code() == error::UNKNOWN) { + LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() + << "'. Returning empty tensor."; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({0, 0}), &output)); + return; + } else { + OP_REQUIRES_OK(context, result); + } + OP_REQUIRES(context, !output_samples.empty(), + errors::Unknown("No output created by FFmpeg.")); + OP_REQUIRES( + context, output_samples.size() % channel_count == 0, + errors::Unknown("FFmpeg created non-integer number of audio frames.")); + + // Copy the output data to the output Tensor. + Tensor* output = nullptr; + const int64 frame_count = output_samples.size() / channel_count; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({frame_count, channel_count}), &output)); + auto matrix = output->tensor<float, 2>(); + for (int32 frame = 0; frame < frame_count; ++frame) { + for (int32 channel = 0; channel < channel_count; ++channel) { + matrix(frame, channel) = output_samples[frame * channel_count + channel]; + } + } +} + } // namespace +/* + * Supersedes `DecodeAudioOp`. Allows all parameters to be inputs + * instead of attributes, so that they can be given as tensors rather + * than constants only. + */ +class DecodeAudioOpV2 : public OpKernel { + public: + explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() == 4, + errors::InvalidArgument("DecodeAudio requires exactly four inputs.")); + + const Tensor& contents_tensor = context->input(0); + const Tensor& file_format_tensor = context->input(1); + const Tensor& samples_per_second_tensor = context->input(2); + const Tensor& channel_count_tensor = context->input(3); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()), + errors::InvalidArgument( + "contents must be a rank-0 tensor but got shape ", + contents_tensor.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(file_format_tensor.shape()), + errors::InvalidArgument( + "file_format must be a rank-0 tensor but got shape ", + file_format_tensor.shape().DebugString())); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(samples_per_second_tensor.shape()), + errors::InvalidArgument( + "samples_per_second must be a rank-0 tensor but got shape ", + samples_per_second_tensor.shape().DebugString())); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(channel_count_tensor.shape()), + errors::InvalidArgument( + "channel_count must be a rank-0 tensor but got shape ", + channel_count_tensor.shape().DebugString())); + + const tensorflow::StringPiece contents = contents_tensor.scalar<string>()(); + const string file_format = + str_util::Lowercase(file_format_tensor.scalar<string>()()); + const int32 samples_per_second = + samples_per_second_tensor.scalar<int32>()(); + const int32 channel_count = channel_count_tensor.scalar<int32>()(); + + const std::set<string> valid_file_formats( + kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); + OP_REQUIRES( + context, valid_file_formats.count(file_format) == 1, + errors::InvalidArgument("file_format must be one of {", + str_util::Join(valid_file_formats, ", "), + "}, but was: \"", file_format, "\"")); + OP_REQUIRES(context, samples_per_second > 0, + errors::InvalidArgument( + "samples_per_second must be positive, but got: ", + samples_per_second)); + OP_REQUIRES( + context, channel_count > 0, + errors::InvalidArgument("channel_count must be positive, but got: ", + channel_count)); + + Decode(context, contents, file_format, samples_per_second, channel_count); + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeAudioV2").Device(DEVICE_CPU), + DecodeAudioOpV2); + +REGISTER_OP("DecodeAudioV2") + .Input("contents: string") + .Input("file_format: string") + .Input("samples_per_second: int32") + .Input("channel_count: int32") + .Output("sampled_audio: float") + .SetShapeFn([](shape_inference::InferenceContext* c) { + const Tensor* channels_tensor = c->input_tensor(3); + if (channels_tensor == nullptr) { + c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim())); + return Status::OK(); + } + const int32 channels = channels_tensor->scalar<int32>()(); + if (channels <= 0) { + return errors::InvalidArgument( + "channel_count must be positive, but got: ", channels); + } + c->set_output(0, c->Matrix(c->UnknownDim(), channels)); + return Status::OK(); + }) + .Doc(R"doc( +Processes the contents of an audio file into a tensor using FFmpeg to decode +the file. + +One row of the tensor is created for each channel in the audio file. Each +channel contains audio samples starting at the beginning of the audio and +having `1/samples_per_second` time between them. If the `channel_count` is +different from the contents of the file, channels will be merged or created. + +contents: The binary audio file contents, as a string or rank-0 string + tensor. +file_format: A string or rank-0 string tensor describing the audio file + format. This must be one of: "mp3", "mp4", "ogg", "wav". +samples_per_second: The number of samples per second that the audio + should have, as an `int` or rank-0 `int32` tensor. This value must + be positive. +channel_count: The number of channels of audio to read, as an int rank-0 + int32 tensor. Must be a positive integer. +sampled_audio: A rank-2 tensor containing all tracks of the audio. + Dimension 0 is time and dimension 1 is the channel. If ffmpeg fails + to decode the audio then an empty tensor will be returned. +)doc"); + +/* + * Deprecated in favor of DecodeAudioOpV2. + */ class DecodeAudioOp : public OpKernel { public: explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) { @@ -69,15 +236,11 @@ class DecodeAudioOp : public OpKernel { file_format_ = str_util::Lowercase(file_format_); const std::set<string> valid_file_formats( kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); - OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1, - errors::InvalidArgument( - "file_format arg must be in {", - str_util::Join(valid_file_formats, ", "), "}.")); - - OP_REQUIRES_OK( - context, context->GetAttr("samples_per_second", &samples_per_second_)); - OP_REQUIRES(context, samples_per_second_ > 0, - errors::InvalidArgument("samples_per_second must be > 0.")); + OP_REQUIRES( + context, valid_file_formats.count(file_format_) == 1, + errors::InvalidArgument("file_format must be one of {", + str_util::Join(valid_file_formats, ", "), + "}, but was: \"", file_format_, "\"")); OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_)); OP_REQUIRES(context, channel_count_ > 0, @@ -95,51 +258,9 @@ class DecodeAudioOp : public OpKernel { errors::InvalidArgument("contents must be scalar but got shape ", contents.shape().DebugString())); - // Write the input data to a temp file. const tensorflow::StringPiece file_contents = contents.scalar<string>()(); - const string input_filename = GetTempFilename(file_format_); - OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents)); - FileDeleter deleter(input_filename); - - // Run FFmpeg on the data and verify results. - std::vector<float> output_samples; - Status result = - ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_, - channel_count_, &output_samples); - if (result.code() == error::Code::NOT_FOUND) { - OP_REQUIRES( - context, result.ok(), - errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " - "can be found at http://www.ffmpeg.org.")); - } else if (result.code() == error::UNKNOWN) { - LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() - << "'. Returning empty tensor."; - Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({0, 0}), &output)); - return; - } else { - OP_REQUIRES_OK(context, result); - } - OP_REQUIRES(context, !output_samples.empty(), - errors::Unknown("No output created by FFmpeg.")); - OP_REQUIRES( - context, output_samples.size() % channel_count_ == 0, - errors::Unknown("FFmpeg created non-integer number of audio frames.")); - - // Copy the output data to the output Tensor. - Tensor* output = nullptr; - const int64 frame_count = output_samples.size() / channel_count_; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({frame_count, channel_count_}), &output)); - auto matrix = output->tensor<float, 2>(); - for (int32 frame = 0; frame < frame_count; ++frame) { - for (int32 channel = 0; channel < channel_count_; ++channel) { - matrix(frame, channel) = - output_samples[frame * channel_count_ + channel]; - } - } + Decode(context, file_contents, file_format_, samples_per_second_, + channel_count_); } private: @@ -178,8 +299,7 @@ contents: The binary audio file contents. sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0 is time and dimension 1 is the channel. If ffmpeg fails to decode the audio then an empty tensor will be returned. -file_format: A string describing the audio file format. This can be "wav" or - "mp3". +file_format: A string describing the audio file format. This can be "mp3", "mp4", "ogg", or "wav". samples_per_second: The number of samples per second that the audio should have. channel_count: The number of channels of audio to read. )doc"); diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py index e4ed46b1e2..0d7c9cb99e 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py @@ -20,7 +20,11 @@ from __future__ import print_function import os.path +import six + from tensorflow.contrib import ffmpeg +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -28,7 +32,8 @@ from tensorflow.python.platform import test class DecodeAudioOpTest(test.TestCase): def _loadFileAndTest(self, filename, file_format, duration_sec, - samples_per_second, channel_count): + samples_per_second, channel_count, + samples_per_second_tensor=None, feed_dict=None): """Loads an audio file and validates the output tensor. Args: @@ -37,7 +42,16 @@ class DecodeAudioOpTest(test.TestCase): duration_sec: The duration of the audio contained in the file in seconds. samples_per_second: The desired sample rate in the output tensor. channel_count: The desired channel count in the output tensor. + samples_per_second_tensor: The value to pass to the corresponding + parameter in the instantiated `decode_audio` op. If not + provided, will default to a constant value of + `samples_per_second`. Useful for providing a placeholder. + feed_dict: Used when evaluating the `decode_audio` op. If not + provided, will be empty. Useful when providing a placeholder for + `samples_per_second_tensor`. """ + if samples_per_second_tensor is None: + samples_per_second_tensor = samples_per_second with self.test_session(): path = os.path.join(resource_loader.get_data_files_path(), 'testdata', filename) @@ -47,9 +61,9 @@ class DecodeAudioOpTest(test.TestCase): audio_op = ffmpeg.decode_audio( contents, file_format=file_format, - samples_per_second=samples_per_second, + samples_per_second=samples_per_second_tensor, channel_count=channel_count) - audio = audio_op.eval() + audio = audio_op.eval(feed_dict=feed_dict or {}) self.assertEqual(len(audio.shape), 2) self.assertNear( duration_sec * samples_per_second, @@ -104,6 +118,75 @@ class DecodeAudioOpTest(test.TestCase): audio = audio_op.eval() self.assertEqual(audio.shape, (0, 0)) + def testSampleRatePlaceholder(self): + placeholder = array_ops.placeholder(dtypes.int32) + self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1, + samples_per_second_tensor=placeholder, + feed_dict={placeholder: 20000}) + + def testSampleRateBadType(self): + placeholder = array_ops.placeholder(dtypes.float32) + with self.assertRaises(TypeError): + self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, + samples_per_second_tensor=placeholder, + feed_dict={placeholder: 20000.0}) + + def testSampleRateBadValue_Zero(self): + placeholder = array_ops.placeholder(dtypes.int32) + with six.assertRaisesRegex(self, Exception, + r'samples_per_second must be positive'): + self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, + samples_per_second_tensor=placeholder, + feed_dict={placeholder: 0}) + + def testSampleRateBadValue_Negative(self): + placeholder = array_ops.placeholder(dtypes.int32) + with six.assertRaisesRegex(self, Exception, + r'samples_per_second must be positive'): + self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, + samples_per_second_tensor=placeholder, + feed_dict={placeholder: -2}) + + def testInvalidFileFormat(self): + with six.assertRaisesRegex(self, Exception, + r'file_format must be one of'): + self._loadFileAndTest('mono_16khz.mp3', 'docx', 0.57, 20000, 1) + + def testStaticShapeInference_ConstantChannelCount(self): + with self.test_session(): + audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~', + file_format='wav', + samples_per_second=44100, + channel_count=2) + self.assertEqual([None, 2], audio_op.shape.as_list()) + + def testStaticShapeInference_NonConstantChannelCount(self): + with self.test_session(): + channel_count = array_ops.placeholder(dtypes.int32) + audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~', + file_format='wav', + samples_per_second=44100, + channel_count=channel_count) + self.assertEqual([None, None], audio_op.shape.as_list()) + + def testStaticShapeInference_ZeroChannelCountInvalid(self): + with self.test_session(): + with six.assertRaisesRegex(self, Exception, + r'channel_count must be positive'): + ffmpeg.decode_audio(b'~~~ wave ~~~', + file_format='wav', + samples_per_second=44100, + channel_count=0) + + def testStaticShapeInference_NegativeChannelCountInvalid(self): + with self.test_session(): + with six.assertRaisesRegex(self, Exception, + r'channel_count must be positive'): + ffmpeg.decode_audio(b'~~~ wave ~~~', + file_format='wav', + samples_per_second=44100, + channel_count=-2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc index bd3d6ae699..c00cccd846 100644 --- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc @@ -22,7 +22,137 @@ namespace tensorflow { namespace ffmpeg { +namespace { +/* + * Encoding implementation, shared across V1 and V2 ops. Creates a new + * output in the context. + */ +void Encode(OpKernelContext* context, const Tensor& contents, + const string& file_format, const int32 bits_per_second, + const int32 samples_per_second) { + std::vector<float> samples; + samples.reserve(contents.NumElements()); + for (int32 i = 0; i < contents.NumElements(); ++i) { + samples.push_back(contents.flat<float>()(i)); + } + const int32 channel_count = contents.dim_size(1); + string encoded_audio; + OP_REQUIRES_OK( + context, CreateAudioFile(file_format, bits_per_second, samples_per_second, + channel_count, samples, &encoded_audio)); + + // Copy the encoded audio file to the output tensor. + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output)); + output->scalar<string>()() = encoded_audio; +} + +} // namespace + +/* + * Supersedes `EncodeAudioOp`. Allows all parameters to be inputs + * instead of attributes, so that the sample rate (and, probably less + * usefully, the output file format) can be given as tensors rather than + * constants only. + */ +class EncodeAudioOpV2 : public OpKernel { + public: + explicit EncodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() == 4, + errors::InvalidArgument("EncodeAudio requires exactly four inputs.")); + + const Tensor& contents = context->input(0); + const Tensor& file_format_tensor = context->input(1); + const Tensor& samples_per_second_tensor = context->input(2); + const Tensor& bits_per_second_tensor = context->input(3); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()), + errors::InvalidArgument( + "sampled_audio must be a rank-2 tensor but got shape ", + contents.shape().DebugString())); + OP_REQUIRES( + context, contents.NumElements() <= std::numeric_limits<int32>::max(), + errors::InvalidArgument( + "sampled_audio cannot have more than 2^31 entries. Shape = ", + contents.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(file_format_tensor.shape()), + errors::InvalidArgument( + "file_format must be a rank-0 tensor but got shape ", + file_format_tensor.shape().DebugString())); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(samples_per_second_tensor.shape()), + errors::InvalidArgument( + "samples_per_second must be a rank-0 tensor but got shape ", + samples_per_second_tensor.shape().DebugString())); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(bits_per_second_tensor.shape()), + errors::InvalidArgument( + "bits_per_second must be a rank-0 tensor but got shape ", + bits_per_second_tensor.shape().DebugString())); + + const string file_format = + str_util::Lowercase(file_format_tensor.scalar<string>()()); + const int32 samples_per_second = + samples_per_second_tensor.scalar<int32>()(); + const int32 bits_per_second = bits_per_second_tensor.scalar<int32>()(); + + OP_REQUIRES(context, file_format == "wav", + errors::InvalidArgument( + "file_format must be \"wav\", but got: ", file_format)); + OP_REQUIRES(context, samples_per_second > 0, + errors::InvalidArgument( + "samples_per_second must be positive, but got: ", + samples_per_second)); + OP_REQUIRES( + context, bits_per_second > 0, + errors::InvalidArgument("bits_per_second must be positive, but got: ", + bits_per_second)); + + Encode(context, contents, file_format, bits_per_second, samples_per_second); + } +}; + +REGISTER_KERNEL_BUILDER(Name("EncodeAudioV2").Device(DEVICE_CPU), + EncodeAudioOpV2); + +REGISTER_OP("EncodeAudioV2") + .Input("sampled_audio: float") + .Input("file_format: string") + .Input("samples_per_second: int32") + .Input("bits_per_second: int32") + .Output("contents: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Processes a `Tensor` containing sampled audio with the number of channels +and length of the audio specified by the dimensions of the `Tensor`. The +audio is converted into a string that, when saved to disk, will be equivalent +to the audio in the specified audio format. + +The input audio has one row of the tensor for each channel in the audio file. +Each channel contains audio samples starting at the beginning of the audio and +having `1/samples_per_second` time between them. The output file will contain +all of the audio channels contained in the tensor. + +sampled_audio: A rank-2 float tensor containing all tracks of the audio. + Dimension 0 is time and dimension 1 is the channel. +file_format: A string or rank-0 string tensor describing the audio file + format. This value must be `"wav"`. +samples_per_second: The number of samples per second that the audio should + have, as an int or rank-0 `int32` tensor. This value must be + positive. +bits_per_second: The approximate bitrate of the encoded audio file, as + an int or rank-0 `int32` tensor. This is ignored by the "wav" file + format. +contents: The binary audio file contents, as a rank-0 string tensor. +)doc"); + +/* + * Deprecated in favor of EncodeAudioOpV2. + */ class EncodeAudioOp : public OpKernel { public: explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) { @@ -55,23 +185,8 @@ class EncodeAudioOp : public OpKernel { "sampled_audio cannot have more than 2^31 entries. Shape = ", contents.shape().DebugString())); - // Create the encoded audio file. - std::vector<float> samples; - samples.reserve(contents.NumElements()); - for (int32 i = 0; i < contents.NumElements(); ++i) { - samples.push_back(contents.flat<float>()(i)); - } - const int32 channel_count = contents.dim_size(1); - string encoded_audio; - OP_REQUIRES_OK(context, CreateAudioFile(file_format_, bits_per_second_, - samples_per_second_, channel_count, - samples, &encoded_audio)); - - // Copy the encoded audio file to the output tensor. - Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape(), &output)); - output->scalar<string>()() = encoded_audio; + Encode(context, contents, file_format_, bits_per_second_, + samples_per_second_); } private: diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py index 18d992911d..870290dc10 100644 --- a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py +++ b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py @@ -20,13 +20,24 @@ from __future__ import print_function import os.path +import six + from tensorflow.contrib import ffmpeg +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test class EncodeAudioOpTest(test.TestCase): + def setUp(self): + super(EncodeAudioOpTest, self).setUp() + path = os.path.join(resource_loader.get_data_files_path(), + 'testdata/mono_10khz.wav') + with open(path, 'rb') as f: + self._contents = f.read() + def _compareWavFiles(self, original, encoded): """Compares the important bits of two WAV files. @@ -51,20 +62,54 @@ class EncodeAudioOpTest(test.TestCase): def testRoundTrip(self): """Reads a wav file, writes it, and compares them.""" with self.test_session(): - path = os.path.join(resource_loader.get_data_files_path(), - 'testdata/mono_10khz.wav') - with open(path, 'rb') as f: - original_contents = f.read() - audio_op = ffmpeg.decode_audio( - original_contents, + self._contents, file_format='wav', samples_per_second=10000, channel_count=1) encode_op = ffmpeg.encode_audio( audio_op, file_format='wav', samples_per_second=10000) encoded_contents = encode_op.eval() - self._compareWavFiles(original_contents, encoded_contents) + self._compareWavFiles(self._contents, encoded_contents) + + def testRoundTripWithPlaceholderSampleRate(self): + with self.test_session(): + placeholder = array_ops.placeholder(dtypes.int32) + audio_op = ffmpeg.decode_audio( + self._contents, + file_format='wav', + samples_per_second=placeholder, + channel_count=1) + encode_op = ffmpeg.encode_audio( + audio_op, file_format='wav', samples_per_second=placeholder) + encoded_contents = encode_op.eval(feed_dict={placeholder: 10000}) + self._compareWavFiles(self._contents, encoded_contents) + + def testFloatingPointSampleRateInvalid(self): + with self.test_session(): + with self.assertRaises(TypeError): + ffmpeg.encode_audio( + [[0.0], [1.0]], + file_format='wav', + samples_per_second=12345.678) + + def testZeroSampleRateInvalid(self): + with self.test_session() as sess: + encode_op = ffmpeg.encode_audio( + [[0.0], [1.0]], + file_format='wav', + samples_per_second=0) + with six.assertRaisesRegex(self, Exception, 'must be positive'): + sess.run(encode_op) + + def testNegativeSampleRateInvalid(self): + with self.test_session() as sess: + encode_op = ffmpeg.encode_audio( + [[0.0], [1.0]], + file_format='wav', + samples_per_second=-2) + with six.assertRaisesRegex(self, Exception, 'must be positive'): + sess.run(encode_op) if __name__ == '__main__': diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 5f608cdb82..18b0b8b812 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -38,23 +38,26 @@ def decode_audio(contents, file_format=None, samples_per_second=None, Args: contents: The binary contents of the audio file to decode. This is a scalar. - file_format: A string specifying which format the contents will conform - to. This can be mp3, mp4, ogg, or wav. - samples_per_second: The number of samples per second that is assumed. - In some cases, resampling will occur to generate the correct sample - rate. + file_format: A string or scalar string tensor specifying which + format the contents will conform to. This can be mp3, mp4, ogg, + or wav. + samples_per_second: The number of samples per second that is + assumed, as an `int` or scalar `int32` tensor. In some cases, + resampling will occur to generate the correct sample rate. channel_count: The number of channels that should be created from the - audio contents. If the contents have more than this number, then - some channels will be merged or dropped. If contents has fewer than - this, then additional channels will be created from the existing ones. + audio contents, as an `int` or scalar `int32` tensor. If the + `contents` have more than this number, then some channels will + be merged or dropped. If `contents` has fewer than this, then + additional channels will be created from the existing ones. Returns: - A rank 2 tensor that has time along dimension 0 and channels along - dimension 1. Dimension 0 will be `samples_per_second * length` wide, and - dimension 1 will be `channel_count` wide. If ffmpeg fails to decode the - audio then an empty tensor will be returned. + A rank-2 tensor that has time along dimension 0 and channels along + dimension 1. Dimension 0 will be `samples_per_second * + length_in_seconds` wide, and dimension 1 will be `channel_count` + wide. If ffmpeg fails to decode the audio then an empty tensor will + be returned. """ - return gen_decode_audio_op_py.decode_audio( + return gen_decode_audio_op_py.decode_audio_v2( contents, file_format=file_format, samples_per_second=samples_per_second, channel_count=channel_count) @@ -66,19 +69,23 @@ def encode_audio(audio, file_format=None, samples_per_second=None): """Creates an op that encodes an audio file using sampled audio from a tensor. Args: - audio: A rank 2 tensor that has time along dimension 0 and channels along - dimension 1. Dimension 0 is `samples_per_second * length` long in - seconds. - file_format: The type of file to encode. "wav" is the only supported format. - samples_per_second: The number of samples in the audio tensor per second of - audio. + audio: A rank-2 `Tensor` that has time along dimension 0 and + channels along dimension 1. Dimension 0 is `samples_per_second * + length_in_seconds` long. + file_format: The type of file to encode, as a string or rank-0 + string tensor. "wav" is the only supported format. + samples_per_second: The number of samples in the audio tensor per + second of audio, as an `int` or rank-0 `int32` tensor. Returns: A scalar tensor that contains the encoded audio in the specified file format. """ - return gen_encode_audio_op_py.encode_audio( - audio, file_format=file_format, samples_per_second=samples_per_second) + return gen_encode_audio_op_py.encode_audio_v2( + audio, + file_format=file_format, + samples_per_second=samples_per_second, + bits_per_second=192000) # not used by WAV ops.NotDifferentiable('EncodeAudio') |