I'm using Data.Serialize.Get and am trying to define the following combinator:
    getConsumed :: Get a -> Get (ByteString, a)
which should act like the passed-in Get action, but also return the ByteString that the Get consumed. The use case is that I have a binary structure that I need to both parse and hash, and I don't know the length before parsing it.
This combinator, despite its simple semantics, is proving surprisingly tricky to implement.
Without delving into the internals of Get, my instinct was to use this monstrosity:
    getConsumed :: Get a -> Get (B.ByteString, a)
    getConsumed g = do
      (len, r) <- lookAhead $ do
        before <- remaining
        res <- g
        after <- remaining
        return (before - after, res)
      bs <- getBytes len
      return (bs, r)
This uses lookAhead to read the remaining byte count before and after running the action, returns the action's result, and then consumes the measured length. This shouldn't duplicate any work, but it occasionally fails with:
*** Exception: GetException "Failed reading: getBytes: negative length requested\nEmpty call stack\n"
so I must be misunderstanding something about cereal somewhere.
Does anyone see what's wrong with my definition of getConsumed, or have a better idea for how to implement it?
Edit: Dan Doel points out that remaining can just return the remaining length of a given chunk, which isn't very useful if you cross a chunk boundary. I'm not sure what the point of the action is, in that case, but that explains why my code wasn't working! Now I just need to find a viable alternative.
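To make the chunk-boundary problem concrete, here is a minimal sketch that feeds a parser two chunks through runGetPartial. The names probe and feed are hypothetical helpers made up for this illustration, and the chunk contents are arbitrary. It assumes remaining only counts the bytes currently buffered, as described above.

```haskell
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Data.Serialize.Get (Get, Result (..), getBytes, remaining, runGetPartial)

-- Hypothetical probe: sample 'remaining' before and after consuming 4 bytes.
probe :: Get (Int, Int)
probe = do
  before <- remaining
  _ <- getBytes 4
  after <- remaining
  return (before, after)

-- Hypothetical driver: feed chunks one at a time until the parser finishes.
-- Anything other than Done or a resumable Partial is treated as failure.
feed :: Result a -> [B.ByteString] -> Maybe a
feed (Done r _) _ = Just r
feed (Partial k) (c : cs) = feed (k c) cs
feed _ _ = Nothing

-- First chunk has 2 bytes, second chunk has 6 bytes.
demo :: Maybe (Int, Int)
demo = feed (runGetPartial probe (BC.pack "ab")) [BC.pack "cdefgh"]

main :: IO ()
main = print demo
```

With these chunks, before is sampled when only "ab" is buffered, while after is sampled once the second chunk has been appended, so before - after can come out negative even though 4 bytes were consumed. That would match the "negative length requested" failure in the question.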
Edit 2: after thinking about it some more, it seems like the fact that remaining gives me the length of the current chunk can be to my advantage if I feed the Get manually with individual chunks (remaining >>= getBytes) in a loop and keep track of what it's eating as I do it. I haven't managed to get this approach working either yet, but it seems more promising than the original one.
Edit 3: if anyone's curious, here's code from edit 2 above:
    getChunk :: Get B.ByteString
    getChunk = remaining >>= getBytes
    getConsumed :: Get a -> Get (B.ByteString, a)
    getConsumed g = do
      (len, res) <- lookAhead $ measure g
      bs <- getBytes len
      return (bs, res)
      where
        measure :: Get a -> Get (Int, a)
        measure g = do
          chunk <- getChunk
          measure' (B.length chunk) (runGetPartial g chunk)
        measure' :: Int -> Result a -> Get (Int, a)
        measure' !n (Fail e) = fail e
        measure' !n (Done r bs) = return (n - B.length bs, r)
        measure' !n (Partial f) = do
          chunk <- getChunk
          measure' (n + B.length chunk) (f chunk)
Unfortunately, it still seems to fail after a while on my sample input with:
*** Exception: GetException "Failed reading: too few bytes\nFrom:\tdemandInput\n\n\nEmpty call stack\n"
EDIT: Another solution, which does no extra computation!
    getConsumed :: Get a -> Get (B.ByteString, a)
    getConsumed g = do
      (len, r) <- lookAhead $ do
        (res, after) <- lookAhead $ liftM2 (,) g remaining
        total <- remaining
        return (total - after, res)
      bs <- getBytes len
      return (bs, r)
One solution is to call lookAhead twice. The first time makes sure that all necessary chunks are loaded, and the second performs the actual length computation (along with returning the deserialized data).
    getConsumed :: Get a -> Get (B.ByteString, a)
    getConsumed g = do
      _ <- lookAhead g -- Make sure all necessary chunks are preloaded
      (len, r) <- lookAhead $ do
        before <- remaining
        res <- g
        after <- remaining
        return (before - after, res)
      bs <- getBytes len
      return (bs, r)
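As a quick sanity check of the double-lookAhead version, here is a small self-contained sketch that runs it over a strict ByteString with runGet. The record parser (one length byte followed by that many payload bytes) and the sample input are hypothetical, chosen only for illustration. It does not exercise the multi-chunk case.

```haskell
import qualified Data.ByteString as B
import Data.Serialize.Get (Get, getBytes, getWord8, lookAhead, remaining, runGet)

-- The double-lookAhead combinator from the answer above, repeated so the
-- example compiles on its own.
getConsumed :: Get a -> Get (B.ByteString, a)
getConsumed g = do
  _ <- lookAhead g -- preload every chunk that g needs
  (len, r) <- lookAhead $ do
    before <- remaining
    res <- g
    after <- remaining
    return (before - after, res)
  bs <- getBytes len
  return (bs, r)

-- Hypothetical record format: a length byte, then that many payload bytes.
record :: Get B.ByteString
record = do
  n <- getWord8
  getBytes (fromIntegral n)

-- Input holds one record (length 3, payload "abc") plus two trailing bytes.
demo :: Either String (B.ByteString, B.ByteString)
demo = runGet (getConsumed record) (B.pack [3, 97, 98, 99, 120, 121])

main :: IO ()
main = print demo
```

The first component of the result should be the raw bytes of the record including its length prefix, which is the part you would pass to the hash function, and the second is the parsed payload.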
The Cereal package does not store enough information to simply implement what you want. I expect that your idea of using chunks might work, or perhaps a special runGet. Forking Cereal and using the internals is probably your easiest path.
Writing your own can work; this is what I did when making the protocol-buffers library. My custom Text.ProtocolBuffers.Get library does implement enough machinery to do what you want:
    import Text.ProtocolBuffers.Get
    import Control.Applicative
    import qualified Data.ByteString as B
    getConsumed :: Get a -> Get (B.ByteString, a)
    getConsumed thing = do
      start <- bytesRead
      (a, stop) <- lookAhead ((,) <$> thing <*> bytesRead)
      bs <- getByteString (fromIntegral (stop - start))
      return (bs, a)
This is straightforward because my library tracks the number of bytes read (bytesRead). Otherwise the API is quite similar to Cereal.