summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar huangyi <yi.codeplayer@gmail.com>2011-10-23 22:03:27 +0800
committerGravatar huangyi <yi.codeplayer@gmail.com>2011-10-23 22:03:27 +0800
commit89d6ab583274e7e10a69bc915b0e48cfdbc6207a (patch)
treec368ba3b3664fbfe9b2be45353d07e1e0960adea
parentf16d70af84c736b986153727e2bbcb11ec5da7bd (diff)
add domain compress
-rw-r--r--Network/DNS/Query.hs32
-rw-r--r--Network/DNS/StateBinary.hs60
-rw-r--r--TestProtocol.hs20
3 files changed, 81 insertions, 31 deletions
diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs
index 0d164b4..3ebd5e0 100644
--- a/Network/DNS/Query.hs
+++ b/Network/DNS/Query.hs
@@ -2,9 +2,7 @@
module Network.DNS.Query (composeQuery, composeDNSFormat) where
import qualified Data.ByteString.Lazy.Char8 as BL (ByteString)
-import qualified Data.ByteString as BS (unpack)
-import qualified Data.ByteString.Char8 as BS (length, split, null)
-import Blaze.ByteString.Builder.ByteString (writeByteString)
+import qualified Data.ByteString.Char8 as BS (length, null, break, drop)
import Network.DNS.StateBinary
import Network.DNS.Internal
import Data.Monoid
@@ -109,7 +107,7 @@ encodeRDATA rd = case rd of
(RD_CNAME dom) -> encodeDomain dom
(RD_PTR dom) -> encodeDomain dom
(RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom]
- (RD_TXT txt) -> writeByteString txt
+ (RD_TXT txt) -> putByteString txt
(RD_OTH bytes) -> mconcat $ map putInt8 bytes
(RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $
[ encodeDomain d1
@@ -130,13 +128,23 @@ encodeRDATA rd = case rd of
----------------------------------------------------------------
encodeDomain :: Domain -> SPut
-encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss
+encodeDomain dom | BS.null dom = put8 0
+encodeDomain dom = do
+ mpos <- wsPop dom
+ cur <- gets wsPosition
+ case mpos of
+ Just pos -> encodePointer pos
+ Nothing -> wsPush dom cur >>
+ mconcat [ encodePartialDomain hd
+ , encodeDomain tl
+ ]
where
- ss = filter (not . BS.null) $ BS.split '.' dom
- ls = map BS.length ss
+ (hd, tl') = BS.break (=='.') dom
+ tl = if BS.null tl' then tl' else BS.drop 1 tl'
-encodeSubDomain :: (Int, Domain) -> SPut
-encodeSubDomain (len,sub) = putInt8 len
- +++ foldr ((+++) . put8) mempty ss
- where
- ss = BS.unpack sub
+encodePointer :: Int -> SPut
+encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w
+
+encodePartialDomain :: Domain -> SPut
+encodePartialDomain sub = putInt8 (BS.length sub)
+ +++ putByteString sub
diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs
index e44ee98..6898d3b 100644
--- a/Network/DNS/StateBinary.hs
+++ b/Network/DNS/StateBinary.hs
@@ -1,43 +1,85 @@
+{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Network.DNS.StateBinary where
import Blaze.ByteString.Builder
import Control.Applicative
import Control.Monad.State
+import Data.Monoid
import Data.Attoparsec
import Data.Attoparsec.Enumerator
import qualified Data.Attoparsec.Lazy as AL
import Data.ByteString (ByteString)
-import qualified Data.ByteString as BS (unpack)
+import qualified Data.ByteString as BS (unpack, length)
import qualified Data.ByteString.Lazy as BL (ByteString)
import Data.Enumerator (Iteratee)
import Data.Int
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM (insert, lookup, empty)
+import Data.Map (Map)
+import qualified Data.Map as M (insert, lookup, empty)
import Data.Word
import Network.DNS.Types
import Prelude hiding (lookup, take)
----------------------------------------------------------------
-type SPut = Write
+type SPut = State WState Write
+
+data WState = WState {
+ wsDomain :: Map Domain Int
+ , wsPosition :: Int
+}
+
+initialWState :: WState
+initialWState = WState M.empty 0
+
+instance Monoid SPut where
+ mempty = return mempty
+ mappend a b = mconcat <$> sequence [a, b]
put8 :: Word8 -> SPut
-put8 = writeWord8
+put8 = fixedSized 1 writeWord8
put16 :: Word16 -> SPut
-put16 = writeWord16be
+put16 = fixedSized 2 writeWord16be
put32 :: Word32 -> SPut
-put32 = writeWord32be
+put32 = fixedSized 4 writeWord32be
putInt8 :: Int -> SPut
-putInt8 = writeInt8 . fromIntegral
+putInt8 = fixedSized 1 (writeInt8 . fromIntegral)
putInt16 :: Int -> SPut
-putInt16 = writeInt16be . fromIntegral
+putInt16 = fixedSized 2 (writeInt16be . fromIntegral)
putInt32 :: Int -> SPut
-putInt32 = writeInt32be . fromIntegral
+putInt32 = fixedSized 4 (writeInt32be . fromIntegral)
+
+putByteString :: ByteString -> SPut
+putByteString = writeSized BS.length writeByteString
+
+addPositionW :: Int -> State WState ()
+addPositionW n = do
+ (WState m cur) <- get
+ put $ WState m (cur+n)
+
+fixedSized :: Int -> (a -> Write) -> a -> SPut
+fixedSized n f a = do addPositionW n
+ return (f a)
+
+writeSized :: Show a => (a -> Int) -> (a -> Write) -> a -> SPut
+writeSized n f a = do addPositionW (n a)
+ return (f a)
+
+wsPop :: Domain -> State WState (Maybe Int)
+wsPop dom = do
+ doms <- gets wsDomain
+ return $ M.lookup dom doms
+
+wsPush :: Domain -> Int -> State WState ()
+wsPush dom pos = do
+ (WState m cur) <- get
+ put $ WState (M.insert dom pos m) cur
----------------------------------------------------------------
@@ -122,4 +164,4 @@ runSGet :: SGet a -> BL.ByteString -> Either String (a, PState)
runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs
runSPut :: SPut -> BL.ByteString
-runSPut = toLazyByteString . fromWrite
+runSPut = toLazyByteString . fromWrite . flip evalState initialWState
diff --git a/TestProtocol.hs b/TestProtocol.hs
index bde01dd..cef00ee 100644
--- a/TestProtocol.hs
+++ b/TestProtocol.hs
@@ -15,17 +15,17 @@ import Test.HUnit hiding (Test)
tests :: [Test]
tests =
[ testGroup "Test case"
- [ testCase "QueryA" (test_Format queryA)
- , testCase "QueryAAAA" (test_Format queryAAAA)
- , testCase "ResponseA" (test_Format responseA)
+ [ testCase "QueryA" (test_Format testQueryA)
+ , testCase "QueryAAAA" (test_Format testQueryAAAA)
+ , testCase "ResponseA" (test_Format $ testResponseA)
]
]
defaultHeader :: DNSHeader
defaultHeader = header defaultQuery
-queryA :: DNSFormat
-queryA = defaultQuery
+testQueryA :: DNSFormat
+testQueryA = defaultQuery
{ header = defaultHeader
{ identifier = 1000
, qdCount = 1
@@ -33,8 +33,8 @@ queryA = defaultQuery
, question = [makeQuestion "www.mew.org." A]
}
-queryAAAA :: DNSFormat
-queryAAAA = defaultQuery
+testQueryAAAA :: DNSFormat
+testQueryAAAA = defaultQuery
{ header = defaultHeader
{ identifier = 1000
, qdCount = 1
@@ -42,8 +42,8 @@ queryAAAA = defaultQuery
, question = [makeQuestion "www.mew.org." AAAA]
}
-responseA :: DNSFormat
-responseA = DNSFormat { header = DNSHeader { identifier = 61046
+testResponseA :: DNSFormat
+testResponseA = DNSFormat { header = DNSHeader { identifier = 61046
, flags = DNSFlags { qOrR = QR_Response
, opcode = OP_STD
, authAnswer = False
@@ -157,7 +157,7 @@ test_Format fmt = do
assertEqual "fail" fmt fmt'
where
bs = composeDNSFormat fmt
- result = runResponse_ bs
+ result = runDNSFormat_ bs
main :: IO ()
main = defaultMain tests