summary refs log tree commit diff homepage
path: root/Network/DNS/StateBinary.hs
blob: 6898d3b86389f7ed616e9a2971a531dad111d9ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Network.DNS.StateBinary where

import Blaze.ByteString.Builder
import Control.Applicative
import Control.Monad.State
import Data.Monoid
import Data.Attoparsec
import Data.Attoparsec.Enumerator
import qualified Data.Attoparsec.Lazy as AL
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (unpack, length)
import qualified Data.ByteString.Lazy as BL (ByteString)
import Data.Enumerator (Iteratee)
import Data.Int
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM (insert, lookup, empty)
import Data.Map (Map)
import qualified Data.Map as M (insert, lookup, empty)
import Data.Word
import Network.DNS.Types
import Prelude hiding (lookup, take)

----------------------------------------------------------------

-- | A serialization action: a 'State' computation over 'WState' whose result
-- is the 'Write' to emit.
type SPut = State WState Write

-- | State threaded through 'SPut' actions while serializing.
data WState = WState
    { wsDomain   :: Map Domain Int
      -- ^ Integer recorded for each domain by 'wsPush' (presumably the output
      -- offset at which the domain was written -- confirm against callers).
    , wsPosition :: Int
      -- ^ Running total of the sizes passed to 'addPositionW'.
    }

-- | The 'WState' a serialization starts from: no domains recorded and the
-- position at zero.
initialWState :: WState
initialWState = WState
    { wsDomain   = M.empty
    , wsPosition = 0
    }

-- | 'SPut' actions form a monoid: 'mempty' yields an empty 'Write' without
-- touching the state; 'mappend' runs the left action, then the right one, and
-- concatenates the two resulting 'Write's in that order.
instance Monoid SPut where
    mempty = return mempty
    mappend first second = do
        w1 <- first
        w2 <- second
        return (mconcat [w1, w2])

-- | Emit a single 'Word8' and advance the write position by 1.
put8 :: Word8 -> SPut
put8 w = fixedSized 1 writeWord8 w

-- | Emit a 'Word16' in big-endian byte order and advance the write position
-- by 2.
put16 :: Word16 -> SPut
put16 w = fixedSized 2 writeWord16be w

-- | Emit a 'Word32' in big-endian byte order and advance the write position
-- by 4.
put32 :: Word32 -> SPut
put32 w = fixedSized 4 writeWord32be w

-- | Emit an 'Int' as one byte and advance the write position by 1.
--
-- The value is narrowed with 'fromIntegral' to the type 'writeInt8' expects,
-- so anything outside the 8-bit range wraps around.
putInt8 :: Int -> SPut
putInt8 n = fixedSized 1 writeInt8 (fromIntegral n)

-- | Emit an 'Int' as two big-endian bytes and advance the write position
-- by 2.
--
-- The value is narrowed with 'fromIntegral' to the type 'writeInt16be'
-- expects, so anything outside the 16-bit range wraps around.
putInt16 :: Int -> SPut
putInt16 n = fixedSized 2 writeInt16be (fromIntegral n)

-- | Emit an 'Int' as four big-endian bytes and advance the write position
-- by 4.
--
-- The value is narrowed with 'fromIntegral' to the type 'writeInt32be'
-- expects, so anything outside the 32-bit range wraps around.
putInt32 :: Int -> SPut
putInt32 n = fixedSized 4 writeInt32be (fromIntegral n)

-- | Emit a strict 'ByteString' verbatim and advance the write position by its
-- length.
putByteString :: ByteString -> SPut
putByteString bytes = writeSized BS.length writeByteString bytes

-- | Advance the recorded write position by @n@, leaving the domain map
-- unchanged.
addPositionW :: Int -> State WState ()
addPositionW n =
    get >>= \(WState doms pos) -> put (WState doms (pos + n))

-- | Build an 'SPut' for a value whose encoded size is known up front.
--
-- The write position is advanced by @n@, and the action's result is the
-- 'Write' obtained by applying @f@ to @a@.
fixedSized :: Int -> (a -> Write) -> a -> SPut
fixedSized n f a = addPositionW n >> return (f a)

-- | Build an 'SPut' for a value whose encoded size depends on the value.
--
-- The write position is advanced by @n a@, and the action's result is the
-- 'Write' obtained by applying @f@ to @a@.
--
-- No 'Show' constraint is required: the value is never shown, so demanding
-- the instance would only restrict which types callers may pass.
writeSized :: (a -> Int) -> (a -> Write) -> a -> SPut
writeSized n f a = do
    addPositionW (n a)
    return (f a)

-- | Look up the integer previously recorded for @dom@ by 'wsPush', if any.
--
-- Despite the name, nothing is removed: the state is left unchanged.
wsPop :: Domain -> State WState (Maybe Int)
wsPop dom = gets (M.lookup dom . wsDomain)

-- | Record @pos@ for @dom@ in the domain map, replacing any earlier entry for
-- the same domain.  The write position itself is not changed.
wsPush :: Domain -> Int -> State WState ()
wsPush dom pos =
    get >>= \(WState known cur) -> put (WState (M.insert dom pos known) cur)

----------------------------------------------------------------

-- | A parsing action: a 'Parser' with a 'PState' threaded through it by
-- 'StateT'.
type SGet = StateT PState Parser

-- | State threaded through 'SGet' actions while parsing.
data PState = PState
    { psDomain   :: IntMap Domain
      -- ^ Domains recorded by 'push', keyed by the integer given to 'push'
      -- (presumably the input offset of the domain -- confirm against
      -- callers).
    , psPosition :: Int
      -- ^ Running total of the counts passed to 'addPosition'.
    }

----------------------------------------------------------------

-- | The current value of 'psPosition'; consumes no input.
getPosition :: SGet Int
getPosition = gets psPosition

-- | Advance the recorded read position by @n@, leaving the domain table
-- unchanged.  No input is consumed.
addPosition :: Int -> SGet ()
addPosition n =
    get >>= \(PState doms pos) -> put (PState doms (pos + n))

-- | Record domain @d@ under key @n@ in the domain table, replacing any
-- earlier entry for that key.  The read position is not changed.
push :: Int -> Domain -> SGet ()
push n d =
    get >>= \(PState doms pos) -> put (PState (IM.insert n d doms) pos)

-- | Look up the domain previously recorded under key @n@ by 'push', if any.
--
-- Despite the name, nothing is removed: the state is left unchanged.
pop :: Int -> SGet (Maybe Domain)
pop n = gets (IM.lookup n . psDomain)

----------------------------------------------------------------

-- | Read one byte and advance the position by 1.
get8 :: SGet Word8
get8 = do
    byte <- lift anyWord8
    addPosition 1
    return byte

-- | Read a big-endian 16-bit unsigned integer and advance the position by 2.
get16 :: SGet Word16
get16 = do
    hi <- octet
    lo <- octet
    addPosition 2
    return (hi * 256 + lo)
  where
    -- One input byte, widened to the result type.
    octet = fromIntegral <$> lift anyWord8

-- | Read a big-endian 32-bit unsigned integer and advance the position by 4.
get32 :: SGet Word32
get32 = lift getWord32be <* addPosition 4
  where
    -- One input byte, widened to the result type.
    word8' = fromIntegral <$> anyWord8
    getWord32be = do
        a <- word8'
        b <- word8'
        c <- word8'
        d <- word8'
        -- Horner form: each step shifts the accumulator up by one byte, so
        -- the first byte ends up weighted by 256^3 = 16777216.  Writing it
        -- this way avoids a hand-typed large constant (the previous literal
        -- 1677721 was missing a digit and corrupted the top byte).
        return $ ((a * 256 + b) * 256 + c) * 256 + d

-- | Read one byte with 'get8' and widen it to 'Int'.  The source is a
-- 'Word8', so the result is never negative.
getInt8 :: SGet Int
getInt8 = fmap fromIntegral get8

-- | Read two bytes with 'get16' and widen the result to 'Int'.  The source is
-- a 'Word16', so the result is never negative.
getInt16 :: SGet Int
getInt16 = fmap fromIntegral get16

-- | Read four bytes with 'get32' and convert the 'Word32' result to 'Int'
-- using 'fromIntegral'.
getInt32 :: SGet Int
getInt32 = fmap fromIntegral get32

----------------------------------------------------------------

-- | Read @len@ bytes via 'getNByteString' and return them as a list of
-- 'Int's, one element per byte, in input order.
getNBytes :: Int -> SGet [Int]
getNBytes len = do
    chunk <- getNByteString len
    return [fromIntegral octet | octet <- BS.unpack chunk]

-- | Read exactly @n@ bytes as a strict 'ByteString' and advance the position
-- by @n@.
getNByteString :: Int -> SGet ByteString
getNByteString n = do
    chunk <- lift (take n)
    addPosition n
    return chunk

----------------------------------------------------------------

-- | The 'PState' a parse starts from: an empty domain table and the position
-- at zero.
initialState :: PState
initialState = PState
    { psDomain   = IM.empty
    , psPosition = 0
    }

-- | Turn an 'SGet' parser into an 'Iteratee' over strict 'ByteString' chunks.
--
-- The parser is run from 'initialState'; the iteratee yields the parsed value
-- paired with the final 'PState'.
iterSGet :: Monad m => SGet a -> Iteratee ByteString m (a, PState)
iterSGet parser = iterParser fromStart
  where
    fromStart = runStateT parser initialState

-- | Run an 'SGet' parser over a lazy 'BL.ByteString', starting from
-- 'initialState'.
--
-- Returns 'Left' with the parser's error message on failure, or 'Right' with
-- the parsed value and the final 'PState' on success.
runSGet :: SGet a -> BL.ByteString -> Either String (a, PState)
runSGet parser bs = AL.eitherResult (AL.parse fromStart bs)
  where
    fromStart = runStateT parser initialState

-- | Run an 'SPut' from 'initialWState' and render the resulting 'Write' as a
-- lazy 'BL.ByteString'.  The final 'WState' is discarded.
runSPut :: SPut -> BL.ByteString
runSPut action = toLazyByteString (fromWrite written)
  where
    written = evalState action initialWState