From 597ae83b60dd28706cd29e4df5c69204a4b983a8 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Tue, 30 Aug 2011 16:23:07 +0900 Subject: attoparsec + enumerator. --- Network/DNS/Lookup.hs | 9 ++-- Network/DNS/Query.hs | 4 +- Network/DNS/Resolver.hs | 10 +++-- Network/DNS/Response.hs | 32 ++++++++------ Network/DNS/StateBinary.hs | 105 ++++++++++++++++++++++++--------------------- dns.cabal | 8 +++- 6 files changed, 93 insertions(+), 75 deletions(-) diff --git a/Network/DNS/Lookup.hs b/Network/DNS/Lookup.hs index b61d543..79af339 100644 --- a/Network/DNS/Lookup.hs +++ b/Network/DNS/Lookup.hs @@ -66,11 +66,10 @@ lookupXviaMX rlv dom func = do mdps <- lookupMX rlv dom maybe (return Nothing) lookup' mdps where - lookup' dps = do - as <- catMaybes <$> mapM (func . fst) dps - case as of - [] -> return Nothing - ass -> return $ Just (concat ass) + lookup' dps = check . catMaybes <$> mapM (func . fst) dps + check as = case as of + [] -> Nothing + ass -> Just (concat ass) ---------------------------------------------------------------- diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs index 7c27b63..30c7e24 100644 --- a/Network/DNS/Query.hs +++ b/Network/DNS/Query.hs @@ -62,13 +62,13 @@ encodeQuestion qs = encodeDomain dom ---------------------------------------------------------------- encodeDomain :: Domain -> SPut -encodeDomain dom = foldr (+++) (put8 0) (map encodeSubDomain $ zip ls ss) +encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss where ss = filter (not . BS.null) $ BS.split '.' dom ls = map BS.length ss encodeSubDomain :: (Int, Domain) -> SPut encodeSubDomain (len,sub) = putInt8 len - +++ foldr (+++) mempty (map put8 ss) + +++ foldr ((+++) . put8) mempty ss where ss = BS.unpack sub diff --git a/Network/DNS/Resolver.hs b/Network/DNS/Resolver.hs index 1d19362..5919088 100644 --- a/Network/DNS/Resolver.hs +++ b/Network/DNS/Resolver.hs @@ -40,6 +40,7 @@ import Network.Socket.ByteString.Lazy import Prelude hiding (lookup) import System.Random import System.Timeout +import Network.Socket.Enumerator ---------------------------------------------------------------- @@ -55,7 +56,7 @@ data FileOrNumericHost = RCFilePath FilePath | RCHostName HostName data ResolvConf = ResolvConf { resolvInfo :: FileOrNumericHost , resolvTimeout :: Int - , resolvBufsize :: Int64 + , resolvBufsize :: Integer } {-| @@ -79,7 +80,7 @@ defaultResolvConf = ResolvConf { data ResolvSeed = ResolvSeed { addrInfo :: AddrInfo , rsTimeout :: Int - , rsBufsize :: Int64 + , rsBufsize :: Integer } {-| @@ -89,7 +90,7 @@ data Resolver = Resolver { genId :: IO Int , dnsSock :: Socket , dnsTimeout :: Int - , dnsBufsize :: Int64 + , dnsBufsize :: Integer } ---------------------------------------------------------------- @@ -169,7 +170,8 @@ lookupRaw :: Resolver -> Domain -> TYPE -> IO (Maybe DNSFormat) lookupRaw rlv dom typ = do seqno <- genId rlv sendAll sock (composeQuery seqno [q]) - (>>= check seqno) <$> timeout tm (parseResponse <$> recv sock bufsize) + let responseEnum = enumSocket bufsize sock + (>>= check seqno) <$> timeout tm (parseResponse responseEnum responseIter) where sock = dnsSock rlv bufsize = dnsBufsize rlv diff --git a/Network/DNS/Response.hs b/Network/DNS/Response.hs index 629cd92..d746fc2 100644 --- a/Network/DNS/Response.hs +++ b/Network/DNS/Response.hs @@ -1,20 +1,25 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.DNS.Response (parseResponse) where +module Network.DNS.Response (responseIter, parseResponse) where +import Control.Applicative import Control.Monad import Data.Bits import qualified Data.ByteString.Char8 as BS -import qualified Data.ByteString.Lazy.Char8 as L import Data.IP import Data.Maybe import Network.DNS.Internal import Network.DNS.StateBinary +import Data.Enumerator (Enumerator, Iteratee, run_, ($$)) +import Data.ByteString (ByteString) ----------------------------------------------------------------- +responseIter :: Iteratee ByteString IO (DNSFormat, PState) +responseIter = runSGet decodeResponse -parseResponse :: L.ByteString -> DNSFormat -parseResponse = runSGet decodeResponse +parseResponse :: Enumerator ByteString IO (a,b) + -> Iteratee ByteString IO (a,b) + -> IO a +parseResponse enum iter = fst <$> run_ (enum $$ iter) ---------------------------------------------------------------- @@ -29,16 +34,15 @@ decodeResponse = do ---------------------------------------------------------------- decodeFlags :: SGet DNSFlags -decodeFlags = do - flgs <- get16 - return $ DNSFlags (getQorR flgs) - (getOpcode flgs) - (getAuthAnswer flgs) - (getTrunCation flgs) - (getRecDesired flgs) - (getRecAvailable flgs) - (getRcode flgs) +decodeFlags = toFlags <$> get16 where + toFlags flgs = DNSFlags (getQorR flgs) + (getOpcode flgs) + (getAuthAnswer flgs) + (getTrunCation flgs) + (getRecDesired flgs) + (getRecAvailable flgs) + (getRcode flgs) getQorR w = if testBit w 15 then QR_Response else QR_Query getOpcode w = toEnum $ fromIntegral $ shiftR w 11 .&. 0x0f getAuthAnswer w = testBit w 10 diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs index 2bd610c..e214dbf 100644 --- a/Network/DNS/StateBinary.hs +++ b/Network/DNS/StateBinary.hs @@ -1,41 +1,20 @@ module Network.DNS.StateBinary where import Blaze.ByteString.Builder +import Control.Applicative import Control.Monad.State -import Data.Binary.Get +import Data.Attoparsec +import Data.Attoparsec.Enumerator import Data.ByteString (ByteString) -import qualified Data.ByteString.Char8 as BS -import qualified Data.ByteString.Lazy.Char8 as L -import Data.Char +import qualified Data.ByteString as BS (unpack) +import qualified Data.ByteString.Lazy as L (ByteString) +import Data.Enumerator (Iteratee) import Data.Int import Data.IntMap (IntMap) import qualified Data.IntMap as IM (insert, lookup, empty) import Data.Word import Network.DNS.Types -import Prelude hiding (lookup) - ----------------------------------------------------------------- - -type SGet = StateT PState Get - -type PState = IntMap Domain - ----------------------------------------------------------------- - -(<$>) :: (Monad m) => (a -> b) -> m a -> m b -(<$>) = liftM - -(<$) :: (Monad m) => b -> m a -> m b -x <$ y = y >> return x - -(<*>) :: (Monad m) => m (a -> b) -> m a -> m b -(<*>) = ap - -(<*) :: (Monad m) => m a -> m b -> m a -(<*) ma mb = do - a <- ma - mb - return a +import Prelude hiding (lookup, take) ---------------------------------------------------------------- @@ -61,14 +40,55 @@ putInt32 = writeInt32be . fromIntegral ---------------------------------------------------------------- +type SGet = StateT PState Parser + +data PState = PState { + psDomain :: IntMap Domain + , psPosition :: Int + } + +---------------------------------------------------------------- + +getPosition :: SGet Int +getPosition = psPosition <$> get + +addPosition :: Int -> SGet () +addPosition n = do + PState dom pos <- get + put $ PState dom (pos + n) + +push :: Int -> Domain -> SGet () +push n d = do + PState dom pos <- get + put $ PState (IM.insert n d dom) pos + +pop :: Int -> SGet (Maybe Domain) +pop n = IM.lookup n . psDomain <$> get + +---------------------------------------------------------------- + get8 :: SGet Word8 -get8 = lift getWord8 +get8 = lift anyWord8 <* addPosition 1 get16 :: SGet Word16 -get16 = lift getWord16be +get16 = lift getWord16be <* addPosition 2 + where + word8' = fromIntegral <$> anyWord8 + getWord16be = do + a <- word8' + b <- word8' + return $ a * 256 + b get32 :: SGet Word32 -get32 = lift getWord32be +get32 = lift getWord32be <* addPosition 4 + where + word8' = fromIntegral <$> anyWord8 + getWord32be = do + a <- word8' + b <- word8' + c <- word8' + d <- word8' + return $ a * 1677721 + b * 65536 + c * 256 + d getInt8 :: SGet Int getInt8 = fromIntegral <$> get8 @@ -81,32 +101,21 @@ getInt32 = fromIntegral <$> get32 ---------------------------------------------------------------- -getPosition :: SGet Int -getPosition = fromIntegral <$> lift bytesRead - getNBytes :: Int -> SGet [Int] getNBytes len = toInts <$> getNByteString len where - toInts = map ord . BS.unpack + toInts = map fromIntegral . BS.unpack getNByteString :: Int -> SGet ByteString -getNByteString = lift . getByteString . fromIntegral - ----------------------------------------------------------------- - -push :: Int -> Domain -> SGet () -push n d = modify (IM.insert n d) - -pop :: Int -> SGet (Maybe Domain) -pop n = IM.lookup n <$> get +getNByteString n = lift (take n) <* addPosition n ---------------------------------------------------------------- -initialState :: IntMap Domain -initialState = IM.empty +initialState :: PState +initialState = PState IM.empty 0 -runSGet :: SGet DNSFormat -> L.ByteString -> DNSFormat -runSGet res bs = fst $ runGet (runStateT res initialState) bs +runSGet :: SGet a -> Iteratee ByteString IO (a, PState) +runSGet parser = iterParser (runStateT parser initialState) runSPut :: SPut -> L.ByteString runSPut = toLazyByteString . fromWrite diff --git a/dns.cabal b/dns.cabal index 316dcac..100c6a9 100644 --- a/dns.cabal +++ b/dns.cabal @@ -30,12 +30,16 @@ library Build-Depends: base >= 4 && < 5, binary, iproute, containers, mtl, bytestring, random, - network >= 2.3, blaze-builder + network >= 2.3, blaze-builder, + attoparsec, enumerator, attoparsec-enumerator, + network-enumerator else Build-Depends: base >= 4 && < 5, binary, iproute, containers, mtl, bytestring, random, - network, network-bytestring, blaze-builder + network, network-bytestring, blaze-builder, + attoparsec, enumerator, attoparsec-enumerator, + network-enumerator Source-Repository head Type: git Location: git://github.com/kazu-yamamoto/dns.git -- cgit v1.2.3