From ee8d2b9c69dd7dd4d47e5d88f47150770a15129a Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Mon, 24 Oct 2011 15:57:04 +0900 Subject: Supporting server side and brushing up. --- Network/DNS.hs | 13 +++- Network/DNS/Decode.hs | 188 ++++++++++++++++++++++++++++++++++++++++++++++++ Network/DNS/Encode.hs | 140 ++++++++++++++++++++++++++++++++++++ Network/DNS/Query.hs | 150 -------------------------------------- Network/DNS/Resolver.hs | 24 +++++-- Network/DNS/Response.hs | 175 -------------------------------------------- SimpleServer.hs | 31 ++++---- TestProtocol.hs | 10 ++- dns.cabal | 11 ++- 9 files changed, 378 insertions(+), 364 deletions(-) create mode 100644 Network/DNS/Decode.hs create mode 100644 Network/DNS/Encode.hs delete mode 100644 Network/DNS/Query.hs delete mode 100644 Network/DNS/Response.hs diff --git a/Network/DNS.hs b/Network/DNS.hs index acf56c6..73b44db 100644 --- a/Network/DNS.hs +++ b/Network/DNS.hs @@ -1,10 +1,9 @@ {-| Thread-safe DNS library written in Haskell. - Currently, only resolver side is supported. This code is written in - Haskell, not using FFI. + This code is written in Haskell, not using FFI. - Sample code: + Sample code for DNS lookup: @ import qualified Network.DNS as DNS (lookup) @@ -18,11 +17,19 @@ -} module Network.DNS ( + -- * High level module Network.DNS.Lookup , module Network.DNS.Resolver , module Network.DNS.Types + -- * Low level + , module Network.DNS.Decode + , module Network.DNS.Encode ) where import Network.DNS.Lookup import Network.DNS.Resolver import Network.DNS.Types +import Network.DNS.Decode +import Network.DNS.Encode + + diff --git a/Network/DNS/Decode.hs b/Network/DNS/Decode.hs new file mode 100644 index 0000000..c84a06d --- /dev/null +++ b/Network/DNS/Decode.hs @@ -0,0 +1,188 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Network.DNS.Decode ( + receive + , decode + ) where + +import Control.Applicative +import Control.Monad +import Data.Bits +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS +import qualified Data.ByteString.Lazy as BL +import Data.Enumerator (Enumerator, run_, ($$)) +import Data.IP +import Data.Maybe +import Network +import Network.DNS.Internal +import Network.DNS.StateBinary +import Network.Socket.Enumerator + +---------------------------------------------------------------- + +{-| Receiving DNS data from 'Socket' and parse it. + The second argument is a buffer size for the socket. +-} +receive :: Socket -> Integer -> IO DNSFormat +receive sock bufsize = receiveDNSFormat responseEnum + where + responseEnum = enumSocket bufsize sock + +{-| Parsing DNS data. +-} +decode :: BL.ByteString -> Either String DNSFormat +decode bs = fst <$> runSGet decodeResponse bs + +---------------------------------------------------------------- + +receiveDNSFormat :: Enumerator ByteString IO (DNSFormat, PState) + -> IO DNSFormat +receiveDNSFormat enum = fst <$> run_ (enum $$ iter) + where + iter = iterSGet decodeResponse + +---------------------------------------------------------------- + +decodeResponse :: SGet DNSFormat +decodeResponse = do + hd <- decodeHeader + DNSFormat hd <$> decodeQueries (qdCount hd) + <*> decodeRRs (anCount hd) + <*> decodeRRs (nsCount hd) + <*> decodeRRs (arCount hd) + +---------------------------------------------------------------- + +decodeFlags :: SGet DNSFlags +decodeFlags = toFlags <$> get16 + where + toFlags flgs = DNSFlags (getQorR flgs) + (getOpcode flgs) + (getAuthAnswer flgs) + (getTrunCation flgs) + (getRecDesired flgs) + (getRecAvailable flgs) + (getRcode flgs) + getQorR w = if testBit w 15 then QR_Response else QR_Query + getOpcode w = toEnum $ fromIntegral $ shiftR w 11 .&. 0x0f + getAuthAnswer w = testBit w 10 + getTrunCation w = testBit w 9 + getRecDesired w = testBit w 8 + getRecAvailable w = testBit w 7 + getRcode w = toEnum $ fromIntegral $ w .&. 0x0f + +---------------------------------------------------------------- + +decodeHeader :: SGet DNSHeader +decodeHeader = DNSHeader <$> decodeIdentifier + <*> decodeFlags + <*> decodeQdCount + <*> decodeAnCount + <*> decodeNsCount + <*> decodeArCount + where + decodeIdentifier = getInt16 + decodeQdCount = getInt16 + decodeAnCount = getInt16 + decodeNsCount = getInt16 + decodeArCount = getInt16 + +---------------------------------------------------------------- + +decodeQueries :: Int -> SGet [Question] +decodeQueries n = replicateM n decodeQuery + +decodeType :: SGet TYPE +decodeType = intToType <$> getInt16 + +decodeQuery :: SGet Question +decodeQuery = Question <$> decodeDomain + <*> (decodeType <* ignoreClass) + +decodeRRs :: Int -> SGet [ResourceRecord] +decodeRRs n = replicateM n decodeRR + +decodeRR :: SGet ResourceRecord +decodeRR = do + Question dom typ <- decodeQuery + ttl <- decodeTTL + len <- decodeRLen + dat <- decodeRData typ len + return ResourceRecord { rrname = dom + , rrtype = typ + , rrttl = ttl + , rdlen = len + , rdata = dat + } + where + decodeTTL = fromIntegral <$> get32 + decodeRLen = getInt16 + +decodeRData :: TYPE -> Int -> SGet RDATA +decodeRData NS _ = RD_NS <$> decodeDomain +decodeRData MX _ = RD_MX <$> decodePreference <*> decodeDomain + where + decodePreference = getInt16 +decodeRData CNAME _ = RD_CNAME <$> decodeDomain +decodeRData TXT len = (RD_TXT . ignoreLength) <$> getNByteString len + where + ignoreLength = BS.tail +decodeRData A len = (RD_A . toIPv4) <$> getNBytes len +decodeRData AAAA len = (RD_AAAA . toIPv6 . combine) <$> getNBytes len + where + combine [] = [] + combine [_] = error "combine" + combine (a:b:cs) = a * 256 + b : combine cs +decodeRData SOA _ = RD_SOA <$> decodeDomain + <*> decodeDomain + <*> decodeSerial + <*> decodeRefesh + <*> decodeRetry + <*> decodeExpire + <*> decodeMinumun + where + decodeSerial = getInt32 + decodeRefesh = getInt32 + decodeRetry = getInt32 + decodeExpire = getInt32 + decodeMinumun = getInt32 +decodeRData PTR _ = RD_PTR <$> decodeDomain +decodeRData SRV _ = RD_SRV <$> decodePriority + <*> decodeWeight + <*> decodePort + <*> decodeDomain + where + decodePriority = getInt16 + decodeWeight = getInt16 + decodePort = getInt16 + +decodeRData _ len = RD_OTH <$> getNBytes len + +---------------------------------------------------------------- + +decodeDomain :: SGet Domain +decodeDomain = do + pos <- getPosition + c <- getInt8 + if c == 0 + then return "" + else do + let n = getValue c + if isPointer c + then do + d <- getInt8 + let offset = n * 256 + d + fromMaybe (error $ "decodeDomain: " ++ show offset) <$> pop offset + else do + hs <- getNByteString n + ds <- decodeDomain + let dom = hs `BS.append` "." `BS.append` ds + push pos dom + return dom + where + getValue c = c .&. 0x3f + isPointer c = testBit c 7 && testBit c 6 + +ignoreClass :: SGet () +ignoreClass = () <$ get16 diff --git a/Network/DNS/Encode.hs b/Network/DNS/Encode.hs new file mode 100644 index 0000000..6116a04 --- /dev/null +++ b/Network/DNS/Encode.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE RecordWildCards #-} +module Network.DNS.Encode (encode) where + +import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) +import qualified Data.ByteString.Char8 as BS (length, null, break, drop) +import Network.DNS.StateBinary +import Network.DNS.Internal +import Data.Monoid +import Control.Monad.State +import Data.Bits +import Data.Word +import Data.IP + +(+++) :: Monoid a => a -> a -> a +(+++) = mappend + +---------------------------------------------------------------- + +{-| Composing DNS data. +-} +encode :: DNSFormat -> BL.ByteString +encode fmt = runSPut (encodeDNSFormat fmt) + +---------------------------------------------------------------- + +encodeDNSFormat :: DNSFormat -> SPut +encodeDNSFormat fmt = encodeHeader hdr + +++ mconcat (map encodeQuestion qs) + +++ mconcat (map encodeRR an) + +++ mconcat (map encodeRR au) + +++ mconcat (map encodeRR ad) + where + hdr = header fmt + qs = question fmt + an = answer fmt + au = authority fmt + ad = additional fmt + +encodeHeader :: DNSHeader -> SPut +encodeHeader hdr = encodeIdentifier (identifier hdr) + +++ encodeFlags (flags hdr) + +++ decodeQdCount (qdCount hdr) + +++ decodeAnCount (anCount hdr) + +++ decodeNsCount (nsCount hdr) + +++ decodeArCount (arCount hdr) + where + encodeIdentifier = putInt16 + decodeQdCount = putInt16 + decodeAnCount = putInt16 + decodeNsCount = putInt16 + decodeArCount = putInt16 + +encodeFlags :: DNSFlags -> SPut +encodeFlags DNSFlags{..} = put16 word + where + word16 :: Enum a => a -> Word16 + word16 = toEnum . fromEnum + + set :: Word16 -> State Word16 () + set byte = modify (.|. byte) + + st :: State Word16 () + st = sequence_ + [ set (word16 rcode) + , when recAvailable $ set (bit 7) + , when recDesired $ set (bit 8) + , when trunCation $ set (bit 9) + , when authAnswer $ set (bit 10) + , set (word16 opcode `shiftL` 11) + , when (qOrR==QR_Response) $ set (bit 15) + ] + + word = execState st 0 + +encodeQuestion :: Question -> SPut +encodeQuestion Question{..} = + encodeDomain qname + +++ putInt16 (typeToInt qtype) + +++ put16 1 + +encodeRR :: ResourceRecord -> SPut +encodeRR ResourceRecord{..} = + mconcat + [ encodeDomain rrname + , putInt16 (typeToInt rrtype) + , put16 1 + , putInt32 rrttl + , putInt16 rdlen + , encodeRDATA rdata + ] + +encodeRDATA :: RDATA -> SPut +encodeRDATA rd = case rd of + (RD_A ip) -> mconcat $ map putInt8 (fromIPv4 ip) + (RD_AAAA ip) -> mconcat $ map putInt16 (fromIPv6 ip) + (RD_NS dom) -> encodeDomain dom + (RD_CNAME dom) -> encodeDomain dom + (RD_PTR dom) -> encodeDomain dom + (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] + (RD_TXT txt) -> putByteString txt + (RD_OTH bytes) -> mconcat $ map putInt8 bytes + (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat + [ encodeDomain d1 + , encodeDomain d2 + , putInt32 serial + , putInt32 refresh + , putInt32 retry + , putInt32 expire + , putInt32 min' + ] + (RD_SRV prio weight port dom) -> mconcat + [ putInt16 prio + , putInt16 weight + , putInt16 port + , encodeDomain dom + ] + +---------------------------------------------------------------- + +encodeDomain :: Domain -> SPut +encodeDomain dom | BS.null dom = put8 0 +encodeDomain dom = do + mpos <- wsPop dom + cur <- gets wsPosition + case mpos of + Just pos -> encodePointer pos + Nothing -> wsPush dom cur >> + mconcat [ encodePartialDomain hd + , encodeDomain tl + ] + where + (hd, tl') = BS.break (=='.') dom + tl = if BS.null tl' then tl' else BS.drop 1 tl' + +encodePointer :: Int -> SPut +encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w + +encodePartialDomain :: Domain -> SPut +encodePartialDomain sub = putInt8 (BS.length sub) + +++ putByteString sub diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs deleted file mode 100644 index bdf86bc..0000000 --- a/Network/DNS/Query.hs +++ /dev/null @@ -1,150 +0,0 @@ -{-# LANGUAGE RecordWildCards #-} -module Network.DNS.Query (composeQuery, composeDNSFormat) where - -import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) -import qualified Data.ByteString.Char8 as BS (length, null, break, drop) -import Network.DNS.StateBinary -import Network.DNS.Internal -import Data.Monoid -import Control.Monad.State -import Data.Bits -import Data.Word -import Data.IP - -(+++) :: Monoid a => a -> a -> a -(+++) = mappend - ----------------------------------------------------------------- - -composeDNSFormat :: DNSFormat -> BL.ByteString -composeDNSFormat fmt = runSPut (encodeDNSFormat fmt) - -composeQuery :: Int -> [Question] -> BL.ByteString -composeQuery idt qs = composeDNSFormat qry - where - hdr = header defaultQuery - qry = defaultQuery { - header = hdr { - identifier = idt - , qdCount = length qs - } - , question = qs - } - ----------------------------------------------------------------- - -encodeDNSFormat :: DNSFormat -> SPut -encodeDNSFormat fmt = encodeHeader hdr - +++ mconcat (map encodeQuestion qs) - +++ mconcat (map encodeRR an) - +++ mconcat (map encodeRR au) - +++ mconcat (map encodeRR ad) - where - hdr = header fmt - qs = question fmt - an = answer fmt - au = authority fmt - ad = additional fmt - -encodeHeader :: DNSHeader -> SPut -encodeHeader hdr = encodeIdentifier (identifier hdr) - +++ encodeFlags (flags hdr) - +++ decodeQdCount (qdCount hdr) - +++ decodeAnCount (anCount hdr) - +++ decodeNsCount (nsCount hdr) - +++ decodeArCount (arCount hdr) - where - encodeIdentifier = putInt16 - decodeQdCount = putInt16 - decodeAnCount = putInt16 - decodeNsCount = putInt16 - decodeArCount = putInt16 - -encodeFlags :: DNSFlags -> SPut -encodeFlags DNSFlags{..} = put16 word - where - word16 :: Enum a => a -> Word16 - word16 = toEnum . fromEnum - - set :: Word16 -> State Word16 () - set byte = modify (.|. byte) - - st :: State Word16 () - st = sequence_ - [ set (word16 rcode) - , when recAvailable $ set (bit 7) - , when recDesired $ set (bit 8) - , when trunCation $ set (bit 9) - , when authAnswer $ set (bit 10) - , set (word16 opcode `shiftL` 11) - , when (qOrR==QR_Response) $ set (bit 15) - ] - - word = execState st 0 - -encodeQuestion :: Question -> SPut -encodeQuestion Question{..} = - encodeDomain qname - +++ putInt16 (typeToInt qtype) - +++ put16 1 - -encodeRR :: ResourceRecord -> SPut -encodeRR ResourceRecord{..} = - mconcat - [ encodeDomain rrname - , putInt16 (typeToInt rrtype) - , put16 1 - , putInt32 rrttl - , putInt16 rdlen - , encodeRDATA rdata - ] - -encodeRDATA :: RDATA -> SPut -encodeRDATA rd = case rd of - (RD_A ip) -> mconcat $ map putInt8 (fromIPv4 ip) - (RD_AAAA ip) -> mconcat $ map putInt16 (fromIPv6 ip) - (RD_NS dom) -> encodeDomain dom - (RD_CNAME dom) -> encodeDomain dom - (RD_PTR dom) -> encodeDomain dom - (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] - (RD_TXT txt) -> putByteString txt - (RD_OTH bytes) -> mconcat $ map putInt8 bytes - (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat - [ encodeDomain d1 - , encodeDomain d2 - , putInt32 serial - , putInt32 refresh - , putInt32 retry - , putInt32 expire - , putInt32 min' - ] - (RD_SRV prio weight port dom) -> mconcat - [ putInt16 prio - , putInt16 weight - , putInt16 port - , encodeDomain dom - ] - ----------------------------------------------------------------- - -encodeDomain :: Domain -> SPut -encodeDomain dom | BS.null dom = put8 0 -encodeDomain dom = do - mpos <- wsPop dom - cur <- gets wsPosition - case mpos of - Just pos -> encodePointer pos - Nothing -> wsPush dom cur >> - mconcat [ encodePartialDomain hd - , encodeDomain tl - ] - where - (hd, tl') = BS.break (=='.') dom - tl = if BS.null tl' then tl' else BS.drop 1 tl' - -encodePointer :: Int -> SPut -encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w - -encodePartialDomain :: Domain -> SPut -encodePartialDomain sub = putInt8 (BS.length sub) - +++ putByteString sub diff --git a/Network/DNS/Resolver.hs b/Network/DNS/Resolver.hs index b3182f6..44aee7d 100644 --- a/Network/DNS/Resolver.hs +++ b/Network/DNS/Resolver.hs @@ -28,19 +28,20 @@ module Network.DNS.Resolver ( import Control.Applicative import Control.Exception +import qualified Data.ByteString.Lazy as BL import Data.Char import Data.Int import Data.List hiding (find, lookup) import Network.BSD -import Network.DNS.Query -import Network.DNS.Response +import Network.DNS.Decode +import Network.DNS.Encode +import Network.DNS.Internal import Network.DNS.Types import Network.Socket hiding (send, sendTo, recv, recvFrom) import Network.Socket.ByteString.Lazy import Prelude hiding (lookup) import System.Random import System.Timeout -import Network.Socket.Enumerator ---------------------------------------------------------------- @@ -170,8 +171,7 @@ lookupRaw :: Resolver -> Domain -> TYPE -> IO (Maybe DNSFormat) lookupRaw rlv dom typ = do seqno <- genId rlv sendAll sock (composeQuery seqno [q]) - let responseEnum = enumSocket bufsize sock - (>>= check seqno) <$> timeout tm (parseResponse responseEnum responseIter) + (>>= check seqno) <$> timeout tm (receive sock bufsize) where sock = dnsSock rlv bufsize = dnsBufsize rlv @@ -182,3 +182,17 @@ lookupRaw rlv dom typ = do if identifier hdr == seqno && anCount hdr /= 0 then Just res else Nothing + +---------------------------------------------------------------- + +composeQuery :: Int -> [Question] -> BL.ByteString +composeQuery idt qs = encode qry + where + hdr = header defaultQuery + qry = defaultQuery { + header = hdr { + identifier = idt + , qdCount = length qs + } + , question = qs + } diff --git a/Network/DNS/Response.hs b/Network/DNS/Response.hs deleted file mode 100644 index 523335b..0000000 --- a/Network/DNS/Response.hs +++ /dev/null @@ -1,175 +0,0 @@ -{-# LANGUAGE OverloadedStrings #-} - -module Network.DNS.Response (responseIter, parseResponse, runDNSFormat, runDNSFormat_) where - -import Control.Applicative -import Control.Monad -import Data.Bits -import qualified Data.ByteString.Char8 as BS -import Data.IP -import Data.Maybe -import Network.DNS.Internal -import Network.DNS.StateBinary -import Data.Enumerator (Enumerator, Iteratee, run_, ($$)) -import Data.ByteString (ByteString) -import qualified Data.ByteString.Lazy as BL - -runDNSFormat :: BL.ByteString -> Either String (DNSFormat, PState) -runDNSFormat = runSGet decodeResponse - -runDNSFormat_ :: BL.ByteString -> Either String DNSFormat -runDNSFormat_ bs = fst <$> runDNSFormat bs - -responseIter :: Monad m => Iteratee ByteString m (DNSFormat, PState) -responseIter = iterSGet decodeResponse - -parseResponse :: (Functor m, Monad m) - => Enumerator ByteString m (a,b) - -> Iteratee ByteString m (a,b) - -> m a -parseResponse enum iter = fst <$> run_ (enum $$ iter) - ----------------------------------------------------------------- - -decodeResponse :: SGet DNSFormat -decodeResponse = do - hd <- decodeHeader - DNSFormat hd <$> decodeQueries (qdCount hd) - <*> decodeRRs (anCount hd) - <*> decodeRRs (nsCount hd) - <*> decodeRRs (arCount hd) - ----------------------------------------------------------------- - -decodeFlags :: SGet DNSFlags -decodeFlags = toFlags <$> get16 - where - toFlags flgs = DNSFlags (getQorR flgs) - (getOpcode flgs) - (getAuthAnswer flgs) - (getTrunCation flgs) - (getRecDesired flgs) - (getRecAvailable flgs) - (getRcode flgs) - getQorR w = if testBit w 15 then QR_Response else QR_Query - getOpcode w = toEnum $ fromIntegral $ shiftR w 11 .&. 0x0f - getAuthAnswer w = testBit w 10 - getTrunCation w = testBit w 9 - getRecDesired w = testBit w 8 - getRecAvailable w = testBit w 7 - getRcode w = toEnum $ fromIntegral $ w .&. 0x0f - ----------------------------------------------------------------- - -decodeHeader :: SGet DNSHeader -decodeHeader = DNSHeader <$> decodeIdentifier - <*> decodeFlags - <*> decodeQdCount - <*> decodeAnCount - <*> decodeNsCount - <*> decodeArCount - where - decodeIdentifier = getInt16 - decodeQdCount = getInt16 - decodeAnCount = getInt16 - decodeNsCount = getInt16 - decodeArCount = getInt16 - ----------------------------------------------------------------- - -decodeQueries :: Int -> SGet [Question] -decodeQueries n = replicateM n decodeQuery - -decodeType :: SGet TYPE -decodeType = intToType <$> getInt16 - -decodeQuery :: SGet Question -decodeQuery = Question <$> decodeDomain - <*> (decodeType <* ignoreClass) - -decodeRRs :: Int -> SGet [ResourceRecord] -decodeRRs n = replicateM n decodeRR - -decodeRR :: SGet ResourceRecord -decodeRR = do - Question dom typ <- decodeQuery - ttl <- decodeTTL - len <- decodeRLen - dat <- decodeRData typ len - return ResourceRecord { rrname = dom - , rrtype = typ - , rrttl = ttl - , rdlen = len - , rdata = dat - } - where - decodeTTL = fromIntegral <$> get32 - decodeRLen = getInt16 - -decodeRData :: TYPE -> Int -> SGet RDATA -decodeRData NS _ = RD_NS <$> decodeDomain -decodeRData MX _ = RD_MX <$> decodePreference <*> decodeDomain - where - decodePreference = getInt16 -decodeRData CNAME _ = RD_CNAME <$> decodeDomain -decodeRData TXT len = (RD_TXT . ignoreLength) <$> getNByteString len - where - ignoreLength = BS.tail -decodeRData A len = (RD_A . toIPv4) <$> getNBytes len -decodeRData AAAA len = (RD_AAAA . toIPv6 . combine) <$> getNBytes len - where - combine [] = [] - combine [_] = error "combine" - combine (a:b:cs) = a * 256 + b : combine cs -decodeRData SOA _ = RD_SOA <$> decodeDomain - <*> decodeDomain - <*> decodeSerial - <*> decodeRefesh - <*> decodeRetry - <*> decodeExpire - <*> decodeMinumun - where - decodeSerial = getInt32 - decodeRefesh = getInt32 - decodeRetry = getInt32 - decodeExpire = getInt32 - decodeMinumun = getInt32 -decodeRData PTR _ = RD_PTR <$> decodeDomain -decodeRData SRV _ = RD_SRV <$> decodePriority - <*> decodeWeight - <*> decodePort - <*> decodeDomain - where - decodePriority = getInt16 - decodeWeight = getInt16 - decodePort = getInt16 - -decodeRData _ len = RD_OTH <$> getNBytes len - ----------------------------------------------------------------- - -decodeDomain :: SGet Domain -decodeDomain = do - pos <- getPosition - c <- getInt8 - if c == 0 - then return "" - else do - let n = getValue c - if isPointer c - then do - d <- getInt8 - let offset = n * 256 + d - fromMaybe (error $ "decodeDomain: " ++ show offset) <$> pop offset - else do - hs <- getNByteString n - ds <- decodeDomain - let dom = hs `BS.append` "." `BS.append` ds - push pos dom - return dom - where - getValue c = c .&. 0x3f - isPointer c = testBit c 7 && testBit c 6 - -ignoreClass :: SGet () -ignoreClass = () <$ get16 diff --git a/SimpleServer.hs b/SimpleServer.hs index 401e7f4..8e12a29 100644 --- a/SimpleServer.hs +++ b/SimpleServer.hs @@ -1,23 +1,20 @@ {-# LANGUAGE RecordWildCards, OverloadedStrings #-} -import System.Environment -import Debug.Trace -import Control.Monad -import Control.Concurrent import Control.Applicative -import Data.Monoid -import Data.Maybe +import Control.Concurrent +import Control.Monad import qualified Data.ByteString as S import Data.ByteString.Lazy hiding (putStrLn, filter, length) -import System.Timeout +import Data.Default +import Data.Maybe +import Data.Monoid +import Debug.Trace import Network.BSD import Network.DNS hiding (lookup) -import Network.DNS.Response -import Network.DNS.Query import Network.Socket hiding (recvFrom) import Network.Socket.ByteString -import Network.Socket.Enumerator -import Data.Default +import System.Environment +import System.Timeout data Conf = Conf { bufSize :: Int @@ -42,11 +39,9 @@ proxyRequest :: Conf -> ResolvConf -> DNSFormat -> IO (Maybe DNSFormat) proxyRequest Conf{..} rc req = do let worker Resolver{..} = do - let packet = mconcat . toChunks $ composeDNSFormat req + let packet = mconcat . toChunks $ encode req sendAll dnsSock packet - let responseEnum = enumSocket dnsBufsize dnsSock - parseResponse responseEnum responseIter - + receive dnsSock dnsBufsize rs <- makeResolvSeed rc withResolver rs $ \r -> (>>= check) <$> timeout' "proxy timeout" timeOut (worker r) @@ -76,7 +71,7 @@ handleRequest conf rc req = maybe (proxyRequest conf rc req) (trace "return A re ] handlePacket :: Conf -> Socket -> SockAddr -> S.ByteString -> IO () -handlePacket conf@Conf{..} sock addr bs = case runDNSFormat_ (fromChunks [bs]) of +handlePacket conf@Conf{..} sock addr bs = case decode (fromChunks [bs]) of Right req -> do print req let rc = defaultResolvConf { resolvInfo = RCHostName realDNS } @@ -84,7 +79,7 @@ handlePacket conf@Conf{..} sock addr bs = case runDNSFormat_ (fromChunks [bs]) o print mrsp case mrsp of Just rsp -> - let packet = mconcat . toChunks $ composeDNSFormat rsp + let packet = mconcat . toChunks $ encode rsp in timeout' "send timeout" timeOut (sendAllTo sock packet addr) >> print (S.length packet) >> return () @@ -95,7 +90,7 @@ main :: IO () main = withSocketsDo $ do dns <- fromMaybe (realDNS def) . listToMaybe <$> getArgs let conf = def { realDNS=dns } - addrinfos <- getAddrInfo + addrinfos <- getAddrInfo (Just (defaultHints {addrFlags = [AI_PASSIVE]})) Nothing (Just "domain") addrinfo <- maybe (fail "no addr info") return (listToMaybe addrinfos) diff --git a/TestProtocol.hs b/TestProtocol.hs index 3c78cb6..5e2abe7 100644 --- a/TestProtocol.hs +++ b/TestProtocol.hs @@ -2,17 +2,15 @@ module TestProtocol where +import Data.IP import Network.DNS import Network.DNS.Internal -import Network.DNS.Query -import Network.DNS.Response -import Data.IP import Test.Framework (defaultMain, testGroup, Test) import Test.Framework.Providers.HUnit import Test.HUnit hiding (Test) tests :: [Test] -tests = +tests = [ testGroup "Test case" [ testCase "QueryA" (test_Format testQueryA) , testCase "QueryAAAA" (test_Format testQueryAAAA) @@ -155,8 +153,8 @@ test_Format fmt = do let (Right fmt') = result assertEqual "fail" fmt fmt' where - bs = composeDNSFormat fmt - result = runDNSFormat_ bs + bs = encode fmt + result = decode bs main :: IO () main = defaultMain tests diff --git a/dns.cabal b/dns.cabal index 1d1856a..7c975ed 100644 --- a/dns.cabal +++ b/dns.cabal @@ -1,14 +1,11 @@ Name: dns -Version: 0.2.1 +Version: 0.3.0 Author: Kazu Yamamoto Maintainer: Kazu Yamamoto License: BSD3 License-File: LICENSE Synopsis: DNS libary in Haskell -Description: DNS libary. Currently only resolver side - is supported. That is, this library includes - a composer of DNS query and a parser of DNS - response. +Description: DNS libary for clients and servers. Category: Network Cabal-Version: >= 1.6 Build-Type: Simple @@ -22,10 +19,10 @@ library Network.DNS.Lookup Network.DNS.Resolver Network.DNS.Types + Network.DNS.Encode + Network.DNS.Decode Other-Modules: Network.DNS.Internal Network.DNS.StateBinary - Network.DNS.Query - Network.DNS.Response if impl(ghc >= 7) Build-Depends: base >= 4 && < 5, binary, iproute, -- cgit v1.2.3