From 2e7bd84469eba730f24dd3e448cca22f5aed16f4 Mon Sep 17 00:00:00 2001 From: Benjamin Barenblat Date: Tue, 4 Sep 2018 18:29:34 -0400 Subject: Enable error checking for HKDF computations --- btls.cabal | 3 ++- src/BTLS/BoringSSL/HKDF.chs | 5 ++--- src/BTLS/Buffer.hs | 25 ++++++++++++++++++++----- src/Codec/Crypto/HKDF.hs | 33 ++++++++++++++++++++------------- tests/BTLS/Assertions.hs | 27 +++++++++++++++++++++++++++ tests/Codec/Crypto/HKDFTests.hs | 12 +++++++----- tests/Data/HMACTests.hs | 11 ++--------- 7 files changed, 80 insertions(+), 36 deletions(-) create mode 100644 tests/BTLS/Assertions.hs diff --git a/btls.cabal b/btls.cabal index d0c4442..1ff34f0 100644 --- a/btls.cabal +++ b/btls.cabal @@ -125,7 +125,8 @@ test-suite tests -threaded -optl-Wl,-z,relro -optl-Wl,-z,now -optl-Wl,-s main-is: Tests.hs - other-modules: Codec.Crypto.HKDFTests + other-modules: BTLS.Assertions + , Codec.Crypto.HKDFTests , Data.DigestTests , Data.Digest.HashTests , Data.Digest.MD5Tests diff --git a/src/BTLS/BoringSSL/HKDF.chs b/src/BTLS/BoringSSL/HKDF.chs index 1a28ccc..8ad3df2 100644 --- a/src/BTLS/BoringSSL/HKDF.chs +++ b/src/BTLS/BoringSSL/HKDF.chs @@ -22,15 +22,14 @@ import Foreign.C.Types {#import BTLS.BoringSSL.Base#} import BTLS.Buffer (unsafeUseAsCBuffer) -import BTLS.Result #include {#fun HKDF_extract as hkdfExtract { id `Ptr CUChar', id `Ptr CULong', `Ptr EVPMD' , unsafeUseAsCBuffer* `ByteString'&, unsafeUseAsCBuffer* `ByteString'& } - -> `()' requireSuccess*-#} + -> `Int'#} {#fun HKDF_expand as hkdfExpand { id `Ptr CUChar', `Int', `Ptr EVPMD', unsafeUseAsCBuffer* `ByteString'& - , unsafeUseAsCBuffer* `ByteString'& } -> `()' requireSuccess*-#} + , unsafeUseAsCBuffer* `ByteString'& } -> `Int'#} diff --git a/src/BTLS/Buffer.hs b/src/BTLS/Buffer.hs index 7168a10..354c787 100644 --- a/src/BTLS/Buffer.hs +++ b/src/BTLS/Buffer.hs @@ -15,9 +15,11 @@ module BTLS.Buffer ( unsafeUseAsCBuffer , packCUStringLen - , onBufferOfMaxSize + , onBufferOfMaxSize, onBufferOfMaxSize' ) where +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Except (ExceptT, runExceptT) import Data.ByteString (ByteString) import qualified Data.ByteString as ByteString import qualified Data.ByteString.Unsafe as ByteString @@ -44,9 +46,22 @@ onBufferOfMaxSize :: => Int -> (Ptr CUChar -> Ptr size -> IO ()) -> IO ByteString -onBufferOfMaxSize maxSize f = +onBufferOfMaxSize maxSize f = do + Right r <- onBufferOfMaxSize' maxSize (compose2 lift f) + return r + +-- | Like 'onBufferOfMaxSize' but may fail. +onBufferOfMaxSize' :: + (Integral size, Storable size) + => Int + -> (Ptr CUChar -> Ptr size -> ExceptT e IO ()) + -> IO (Either e ByteString) +onBufferOfMaxSize' maxSize f = allocaArray maxSize $ \pOut -> - alloca $ \pOutLen -> do + alloca $ \pOutLen -> runExceptT $ do f pOut pOutLen - outLen <- peek pOutLen - packCUStringLen (pOut, outLen) + outLen <- lift $ peek pOutLen + lift $ packCUStringLen (pOut, outLen) + +compose2 :: (r -> r') -> (a -> b -> r) -> a -> b -> r' +compose2 f g = \a b -> f (g a b) diff --git a/src/Codec/Crypto/HKDF.hs b/src/Codec/Crypto/HKDF.hs index bc1ea61..31d0be3 100644 --- a/src/Codec/Crypto/HKDF.hs +++ b/src/Codec/Crypto/HKDF.hs @@ -48,16 +48,23 @@ module Codec.Crypto.HKDF -- to do so, specify the empty string as the associated data. , AssociatedData(AssociatedData) + -- * Error handling + , Error + -- * Legacy functions , md5 ) where +import Control.Monad ((>=>)) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Except (runExceptT) import Foreign (allocaArray) import Foreign.Marshal.Unsafe (unsafeLocalState) import BTLS.BoringSSL.Digest (evpMaxMDSize) import BTLS.BoringSSL.HKDF -import BTLS.Buffer (onBufferOfMaxSize, packCUStringLen) +import BTLS.Buffer (onBufferOfMaxSize', packCUStringLen) +import BTLS.Result (Error, check) import BTLS.Types ( Algorithm(Algorithm), AssociatedData(AssociatedData), Salt(Salt) , SecretKey(SecretKey), noSalt @@ -66,7 +73,7 @@ import Data.Digest (md5, sha1, sha224, sha256, sha384, sha512) -- | Computes an HKDF. It is defined by -- --- prop> hkdf md salt info len = expand md info len . extract md salt +-- prop> hkdf md salt info len = extract md salt >=> expand md info len -- -- but may be faster than calling the two functions individually. hkdf :: @@ -75,16 +82,16 @@ hkdf :: -> AssociatedData -> Int -- ^ The length of the derived key, in bytes. -> SecretKey - -> SecretKey -hkdf md salt info outLen = expand md info outLen . extract md salt + -> Either [Error] SecretKey +hkdf md salt info outLen = extract md salt >=> expand md info outLen -- | Computes an HKDF pseudorandom key (PRK). -extract :: Algorithm -> Salt -> SecretKey -> SecretKey +extract :: Algorithm -> Salt -> SecretKey -> Either [Error] SecretKey extract (Algorithm md) (Salt salt) (SecretKey secret) = - SecretKey $ + fmap SecretKey $ unsafeLocalState $ - onBufferOfMaxSize evpMaxMDSize $ \pOutKey pOutLen -> do - hkdfExtract pOutKey pOutLen md secret salt + onBufferOfMaxSize' evpMaxMDSize $ \pOutKey pOutLen -> + check $ hkdfExtract pOutKey pOutLen md secret salt -- | Computes HKDF output key material (OKM). expand :: @@ -92,10 +99,10 @@ expand :: -> AssociatedData -> Int -- ^ The length of the OKM, in bytes. -> SecretKey - -> SecretKey + -> Either [Error] SecretKey expand (Algorithm md) (AssociatedData info) outLen (SecretKey secret) = - SecretKey $ + fmap SecretKey $ unsafeLocalState $ - allocaArray outLen $ \pOutKey -> do - hkdfExpand pOutKey outLen md secret info - packCUStringLen (pOutKey, outLen) + allocaArray outLen $ \pOutKey -> runExceptT $ do + check $ hkdfExpand pOutKey outLen md secret info + lift $ packCUStringLen (pOutKey, outLen) diff --git a/tests/BTLS/Assertions.hs b/tests/BTLS/Assertions.hs new file mode 100644 index 0000000..23ceb8d --- /dev/null +++ b/tests/BTLS/Assertions.hs @@ -0,0 +1,27 @@ +-- Copyright 2018 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not +-- use this file except in compliance with the License. You may obtain a copy of +-- the License at +-- +-- https://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +-- License for the specific language governing permissions and limitations under +-- the License. + +module BTLS.Assertions + ( isRightAndHolds + ) where + +import Control.Monad (unless) +import Test.Tasty.HUnit (Assertion, assertFailure) + +isRightAndHolds :: (Eq a, Show a, Show e) => Either e a -> a -> Assertion +actual@(Left _) `isRightAndHolds` _ = + assertFailure ("expected: Right _\n but got: " ++ show actual) +Right actual `isRightAndHolds` expected = + unless (expected == actual) $ + assertFailure ("expected: Right " ++ show expected ++ "\n but got: Right " ++ show actual) diff --git a/tests/Codec/Crypto/HKDFTests.hs b/tests/Codec/Crypto/HKDFTests.hs index 44a41cd..d4cfe30 100644 --- a/tests/Codec/Crypto/HKDFTests.hs +++ b/tests/Codec/Crypto/HKDFTests.hs @@ -21,8 +21,9 @@ import qualified Data.ByteString as ByteString import qualified Data.ByteString.Base16 as ByteString.Base16 import qualified Data.ByteString.Char8 as ByteString.Char8 import Test.Tasty (TestTree, testGroup) -import Test.Tasty.HUnit ((@?=), testCase) +import Test.Tasty.HUnit (testCase) +import BTLS.Assertions (isRightAndHolds) import Codec.Crypto.HKDF (AssociatedData(AssociatedData), Salt(Salt), SecretKey(SecretKey), noSalt) import qualified Codec.Crypto.HKDF as HKDF import Data.Digest (sha1, sha256) @@ -91,10 +92,11 @@ testRFC5869 = testGroup "RFC 5869 examples" ] where t name hash ikm salt info len prk okm = - testGroup name [ testCase "hkdf" $ HKDF.hkdf hash salt info len ikm @?= okm - , testCase "extract" $ HKDF.extract hash salt ikm @?= prk - , testCase "expand" $ HKDF.expand hash info len prk @?= okm - ] + testGroup name + [ testCase "hkdf" $ HKDF.hkdf hash salt info len ikm `isRightAndHolds` okm + , testCase "extract" $ HKDF.extract hash salt ikm `isRightAndHolds` prk + , testCase "expand" $ HKDF.expand hash info len prk `isRightAndHolds` okm + ] hex :: ByteString -> ByteString hex s = diff --git a/tests/Data/HMACTests.hs b/tests/Data/HMACTests.hs index 10500e4..55e795e 100644 --- a/tests/Data/HMACTests.hs +++ b/tests/Data/HMACTests.hs @@ -16,13 +16,13 @@ module Data.HMACTests (tests) where -import Control.Monad (unless) import qualified Data.ByteString as ByteString import qualified Data.ByteString.Lazy as ByteString.Lazy import qualified Data.ByteString.Lazy.Char8 as ByteString.Lazy.Char8 import Test.Tasty (TestTree, testGroup) -import Test.Tasty.HUnit (Assertion, assertFailure, testCase) +import Test.Tasty.HUnit (testCase) +import BTLS.Assertions (isRightAndHolds) import Data.Digest (md5, sha1, sha224, sha256, sha384, sha512) import Data.HMAC (Error, SecretKey(SecretKey), hmac) @@ -191,10 +191,3 @@ hmacSha224 key bytes = show <$> hmac sha224 key bytes hmacSha256 key bytes = show <$> hmac sha256 key bytes hmacSha384 key bytes = show <$> hmac sha384 key bytes hmacSha512 key bytes = show <$> hmac sha512 key bytes - -isRightAndHolds :: (Eq a, Show a, Show e) => Either e a -> a -> Assertion -actual@(Left _) `isRightAndHolds` _ = - assertFailure ("expected: Right _\n but got: " ++ show actual) -Right actual `isRightAndHolds` expected = - unless (expected == actual) $ - assertFailure ("expected: Right " ++ show expected ++ "\n but got: Right " ++ show actual) -- cgit v1.2.3