From 1e42a760dece5f8dc2a48c6d6b2c73301ea3401b Mon Sep 17 00:00:00 2001
From: Philipp Balzarek
Date: Sat, 1 Dec 2012 22:47:35 +0100
Subject: [PATCH] Fix BufferedSource (couldn't handle leftovers and was
therefore useless) Factor out connection from Session Abstract over
connection type (remove mention of Handle)
---
source/Data/Conduit/BufferedSource.hs | 30 ++++----
source/Data/Conduit/TLS.hs | 52 ++++++++-----
source/Network/Xmpp/Concurrent/Monad.hs | 4 +-
source/Network/Xmpp/Monad.hs | 98 ++++++++++++++-----------
source/Network/Xmpp/Session.hs | 2 +-
source/Network/Xmpp/Stream.hs | 20 +++--
source/Network/Xmpp/TLS.hs | 66 ++++++++++++++---
source/Network/Xmpp/Types.hs | 26 ++++---
tests/Tests.hs | 4 +-
9 files changed, 197 insertions(+), 105 deletions(-)
diff --git a/source/Data/Conduit/BufferedSource.hs b/source/Data/Conduit/BufferedSource.hs
index cfb620e..efe6a76 100644
--- a/source/Data/Conduit/BufferedSource.hs
+++ b/source/Data/Conduit/BufferedSource.hs
@@ -14,21 +14,19 @@ data SourceClosed = SourceClosed deriving (Show, Typeable)
instance Exception SourceClosed
+newtype BufferedSource m o = BufferedSource
+ { bs :: IORef (ResumableSource m o)
+ }
+
+
-- | Buffered source from conduit 0.3
-bufferSource :: MonadIO m => Source m o -> IO (Source m o)
+bufferSource :: Monad m => Source m o -> IO (BufferedSource m o)
bufferSource s = do
- srcRef <- newIORef . Just $ DCI.ResumableSource s (return ())
- return $ do
- src' <- liftIO $ readIORef srcRef
- src <- case src' of
- Just s -> return s
- Nothing -> liftIO $ throwIO SourceClosed
- let go src = do
- (src', res) <- lift $ src $$++ CL.head
- case res of
- Nothing -> liftIO $ writeIORef srcRef Nothing
- Just x -> do
- liftIO (writeIORef srcRef $ Just src')
- yield x
- go src'
- in go src
+ srcRef <- newIORef $ DCI.ResumableSource s (return ())
+ return $ BufferedSource srcRef
+
+(.$$+) (BufferedSource bs) snk = do
+ src <- liftIO $ readIORef bs
+ (src', r) <- src $$++ snk
+ liftIO $ writeIORef bs src'
+ return r
diff --git a/source/Data/Conduit/TLS.hs b/source/Data/Conduit/TLS.hs
index 2dbdb28..68fa23b 100644
--- a/source/Data/Conduit/TLS.hs
+++ b/source/Data/Conduit/TLS.hs
@@ -8,38 +8,42 @@ module Data.Conduit.TLS
)
where
-import Control.Monad(liftM, when)
-import Control.Monad.IO.Class
+import Control.Monad
+import Control.Monad (liftM, when)
+import Control.Monad.IO.Class
-import Crypto.Random
+import Crypto.Random
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
-import Data.Conduit
-import Control.Monad
+import Data.Conduit
+import qualified Data.Conduit.Binary as CB
+import Data.IORef
-import Network.TLS as TLS
-import Network.TLS.Extra as TLSExtra
+import Network.TLS as TLS
+import Network.TLS.Extra as TLSExtra
-import System.IO(Handle)
+import System.IO (Handle)
-client params gen handle = do
- contextNewOnHandle handle params gen
+client params gen backend = do
+ contextNew backend params gen
defaultParams = defaultParamsClient
tlsinit :: (MonadIO m, MonadIO m1) =>
Bool
-> TLSParams
- -> Handle -> m ( Source m1 BS.ByteString
- , Sink BS.ByteString m1 ()
- , BS.ByteString -> IO ()
- , Context
- )
-tlsinit debug tlsParams handle = do
+ -> Backend
+ -> m ( Source m1 BS.ByteString
+ , Sink BS.ByteString m1 ()
+ , BS.ByteString -> IO ()
+ , Int -> m1 BS.ByteString
+ , Context
+ )
+tlsinit debug tlsParams backend = do
when debug . liftIO $ putStrLn "TLS with debug mode enabled"
gen <- liftIO $ (newGenIO :: IO SystemRandom) -- TODO: Find better random source?
- con <- client tlsParams gen handle
+ con <- client tlsParams gen backend
handshake con
let src = forever $ do
dt <- liftIO $ recvData con
@@ -53,10 +57,24 @@ tlsinit debug tlsParams handle = do
sendData con (BL.fromChunks [x])
when debug (liftIO $ putStr "out: " >> BS.putStrLn x)
snk
+ read <- liftIO $ mkReadBuffer (recvData con)
return ( src
, snk
, \s -> do
when debug (liftIO $ BS.putStrLn s)
sendData con $ BL.fromChunks [s]
+ , liftIO . read
, con
)
+
+mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
+mkReadBuffer read = do
+ buffer <- newIORef BS.empty
+ let read' n = do
+ nc <- readIORef buffer
+ bs <- if BS.null nc then read
+ else return nc
+ let (result, rest) = BS.splitAt n bs
+ writeIORef buffer rest
+ return result
+ return read'
diff --git a/source/Network/Xmpp/Concurrent/Monad.hs b/source/Network/Xmpp/Concurrent/Monad.hs
index 9046e6f..2f17ea9 100644
--- a/source/Network/Xmpp/Concurrent/Monad.hs
+++ b/source/Network/Xmpp/Concurrent/Monad.hs
@@ -56,7 +56,7 @@ withConnection a session = do
(do
(res, s') <- runStateT a s
atomically $ do
- putTMVar (writeRef session) (sConPushBS s')
+ putTMVar (writeRef session) (cSend . sCon $ s')
putTMVar (conStateRef session) s'
return $ Right res
)
@@ -102,7 +102,7 @@ endSession session = do -- TODO: This has to be idempotent (is it?)
closeConnection :: Session -> IO ()
closeConnection session = Ex.mask_ $ do
send <- atomically $ takeTMVar (writeRef session)
- cc <- sCloseConnection <$> ( atomically $ readTMVar (conStateRef session))
+ cc <- cClose . sCon <$> ( atomically $ readTMVar (conStateRef session))
send ""
void . forkIO $ do
threadDelay 3000000
diff --git a/source/Network/Xmpp/Monad.hs b/source/Network/Xmpp/Monad.hs
index 63503ba..e2cea4e 100644
--- a/source/Network/Xmpp/Monad.hs
+++ b/source/Network/Xmpp/Monad.hs
@@ -16,9 +16,10 @@ import Control.Monad.State.Strict
import Data.ByteString as BS
import Data.Conduit
-import qualified Data.Conduit.List as CL
-import Data.Conduit.BufferedSource
import Data.Conduit.Binary as CB
+import Data.Conduit.BufferedSource
+import qualified Data.Conduit.List as CL
+import Data.IORef
import Data.Text(Text)
import Data.XML.Pickle
import Data.XML.Types
@@ -42,8 +43,8 @@ debug = False
pushElement :: Element -> XmppConMonad Bool
pushElement x = do
- sink <- gets sConPushBS
- liftIO . sink $ renderElement x
+ send <- gets (cSend . sCon)
+ liftIO . send $ renderElement x
-- | Encode and send stanza
pushStanza :: Stanza -> XmppConMonad Bool
@@ -55,26 +56,26 @@ pushStanza = pushElement . pickleElem xpStanza
-- XMPP streams. RFC 6120 defines XMPP only in terms of XML 1.0.
pushXmlDecl :: XmppConMonad Bool
pushXmlDecl = do
- sink <- gets sConPushBS
- liftIO $ sink ""
+ con <- gets sCon
+ liftIO $ (cSend con) ""
pushOpenElement :: Element -> XmppConMonad Bool
pushOpenElement e = do
- sink <- gets sConPushBS
+ sink <- gets (cSend . sCon )
liftIO . sink $ renderOpenElement e
-- `Connect-and-resumes' the given sink to the connection source, and pulls a
-- `b' value.
-pullToSink :: Sink Event IO b -> XmppConMonad b
-pullToSink snk = do
- source <- gets sConSrc
- (_, r) <- lift $ source $$+ snk
+pullToSinkEvents :: Sink Event IO b -> XmppConMonad b
+pullToSinkEvents snk = do
+ source <- gets (cEventSource . sCon)
+ r <- lift $ source .$$+ snk
return r
pullElement :: XmppConMonad Element
pullElement = do
Ex.catches (do
- e <- pullToSink (elements =$ CL.head)
+ e <- pullToSinkEvents (elements =$ await)
case e of
Nothing -> liftIO $ Ex.throwIO StreamConnectionError
Just r -> return r
@@ -106,8 +107,8 @@ pullStanza = do
-- Performs the given IO operation, catches any errors and re-throws everything
-- except 'ResourceVanished' and IllegalOperation, in which case it will return False instead
-catchPush :: IO () -> IO Bool
-catchPush p = Ex.catch
+catchSend :: IO () -> IO Bool
+catchSend p = Ex.catch
(p >> return True)
(\e -> case GIE.ioe_type e of
GIE.ResourceVanished -> return False
@@ -115,18 +116,20 @@ catchPush p = Ex.catch
_ -> Ex.throwIO e
)
--- XmppConnection state used when there is no connection.
+-- -- XmppConnection state used when there is no connection.
xmppNoConnection :: XmppConnection
xmppNoConnection = XmppConnection
- { sConSrc = zeroSource
- , sRawSrc = zeroSource
- , sConPushBS = \_ -> return False -- Nothing has been sent.
- , sConHandle = Nothing
+ { sCon = Connection { cSend = \_ -> return False
+ , cRecv = \_ -> Ex.throwIO
+ $ StreamConnectionError
+ , cEventSource = undefined
+ , cFlush = return ()
+ , cClose = return ()
+ }
, sFeatures = SF Nothing [] []
, sConnectionState = XmppConnectionClosed
, sHostname = Nothing
, sJid = Nothing
- , sCloseConnection = return ()
, sStreamLang = Nothing
, sStreamId = Nothing
, sPreferredLang = Nothing
@@ -140,30 +143,34 @@ xmppNoConnection = XmppConnection
-- Connects to the given hostname on port 5222 (TODO: Make this dynamic) and
-- updates the XmppConMonad XmppConnection state.
-xmppRawConnect :: HostName -> PortID -> Text -> XmppConMonad ()
-xmppRawConnect host port hostname = do
- con <- liftIO $ do
- con <- connectTo host port
- hSetBuffering con NoBuffering
- return con
- let raw = if debug
- then sourceHandle con $= debugConduit
- else sourceHandle con
- src <- liftIO . bufferSource $ raw $= XP.parseBytes def
+xmppConnectTCP :: HostName -> PortID -> Text -> XmppConMonad ()
+xmppConnectTCP host port hostname = do
+ hand <- liftIO $ do
+ h <- connectTo host port
+ hSetBuffering h NoBuffering
+ return h
+ eSource <- liftIO . bufferSource $ (sourceHandle hand) $= XP.parseBytes def
+ let con = Connection { cSend = if debug
+ then \d -> do
+ BS.putStrLn (BS.append "out: " d)
+ catchSend $ BS.hPut hand d
+ else catchSend . BS.hPut hand
+ , cRecv = if debug then
+ \n -> do
+ bs <- BS.hGetSome hand n
+ BS.putStrLn bs
+ return bs
+ else BS.hGetSome hand
+ , cEventSource = eSource
+ , cFlush = hFlush hand
+ , cClose = hClose hand
+ }
let st = XmppConnection
- { sConSrc = src
- , sRawSrc = raw
- , sConPushBS = if debug
- then \d -> do
- BS.putStrLn (BS.append "out: " d)
- catchPush $ BS.hPut con d
- else catchPush . BS.hPut con
- , sConHandle = (Just con)
+ { sCon = con
, sFeatures = (SF Nothing [] [])
, sConnectionState = XmppConnectionPlain
, sHostname = (Just hostname)
, sJid = Nothing
- , sCloseConnection = (hClose con)
, sPreferredLang = Nothing -- TODO: Allow user to set
, sStreamLang = Nothing
, sStreamId = Nothing
@@ -180,11 +187,18 @@ xmppNewSession action = runStateT action xmppNoConnection
-- Closes the connection and updates the XmppConMonad XmppConnection state.
xmppKillConnection :: XmppConMonad (Either Ex.SomeException ())
xmppKillConnection = do
- cc <- gets sCloseConnection
+ cc <- gets (cClose . sCon)
err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ()))
put xmppNoConnection
return err
+xmppReplaceConnection :: XmppConnection -> XmppConMonad (Either Ex.SomeException ())
+xmppReplaceConnection newCon = do
+ cc <- gets (cClose . sCon)
+ err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ()))
+ put newCon
+ return err
+
-- Sends an IQ request and waits for the response. If the response ID does not
-- match the outgoing ID, an error is thrown.
xmppSendIQ' :: StanzaId
@@ -211,8 +225,8 @@ xmppSendIQ' iqID to tp lang body = do
-- not we received a element from the server is returned.
xmppCloseStreams :: XmppConMonad ([Element], Bool)
xmppCloseStreams = do
- send <- gets sConPushBS
- cc <- gets sCloseConnection
+ send <- gets (cSend . sCon)
+ cc <- gets (cClose . sCon)
liftIO $ send ""
void $ liftIO $ forkIO $ do
threadDelay 3000000
diff --git a/source/Network/Xmpp/Session.hs b/source/Network/Xmpp/Session.hs
index a282512..8e84ad2 100644
--- a/source/Network/Xmpp/Session.hs
+++ b/source/Network/Xmpp/Session.hs
@@ -58,7 +58,7 @@ simpleConnect host port hostname username password resource = do
-- | Connect to host with given address.
connect :: HostName -> PortID -> Text -> XmppConMonad (Either StreamError ())
connect address port hostname = do
- xmppRawConnect address port hostname
+ xmppConnectTCP address port hostname
result <- xmppStartStream
case result of
Left e -> do
diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs
index 888c2ad..6f8dd5e 100644
--- a/source/Network/Xmpp/Stream.hs
+++ b/source/Network/Xmpp/Stream.hs
@@ -9,14 +9,15 @@ import qualified Control.Exception as Ex
import Control.Monad.Error
import Control.Monad.State.Strict
+import qualified Data.ByteString as BS
import Data.Conduit
import Data.Conduit.BufferedSource
import Data.Conduit.List as CL
import Data.Maybe (fromJust, isJust, isNothing)
import Data.Text as Text
+import Data.Void (Void)
import Data.XML.Pickle
import Data.XML.Types
-import Data.Void (Void)
import Network.Xmpp.Monad
import Network.Xmpp.Pickle
@@ -79,7 +80,8 @@ xmppStartStream = runErrorT $ do
, Nothing
, sPreferredLang state
)
- (lt, from, id, features) <- ErrorT . pullToSink $ runErrorT $ xmppStream from
+ (lt, from, id, features) <- ErrorT . pullToSinkEvents $ runErrorT $
+ xmppStream from
modify (\s -> s { sFeatures = features
, sStreamLang = Just lt
, sStreamId = id
@@ -92,10 +94,16 @@ xmppStartStream = runErrorT $ do
-- and calls xmppStartStream.
xmppRestartStream :: XmppConMonad (Either StreamError ())
xmppRestartStream = do
- raw <- gets sRawSrc
- newsrc <- liftIO . bufferSource $ raw $= XP.parseBytes def
- modify (\s -> s{sConSrc = newsrc})
+ raw <- gets (cRecv . sCon)
+ newSrc <- liftIO . bufferSource $ loopRead raw $= XP.parseBytes def
+ modify (\s -> s{sCon = (sCon s){cEventSource = newSrc}})
xmppStartStream
+ where
+ loopRead read = do
+ bs <- liftIO (read 4096)
+ if BS.null bs
+ then return ()
+ else yield bs >> loopRead read
-- Reads the (partial) stream:stream and the server features from the stream.
-- Also validates the stream element's attributes and throws an error if
@@ -170,4 +178,4 @@ xpStreamFeatures = xpWrap
pickleSaslFeature = xpElemNodes
"{urn:ietf:params:xml:ns:xmpp-sasl}mechanisms"
(xpAll $ xpElemNodes
- "{urn:ietf:params:xml:ns:xmpp-sasl}mechanism" (xpContent xpId))
\ No newline at end of file
+ "{urn:ietf:params:xml:ns:xmpp-sasl}mechanism" (xpContent xpId))
diff --git a/source/Network/Xmpp/TLS.hs b/source/Network/Xmpp/TLS.hs
index c50aebf..26bbe6a 100644
--- a/source/Network/Xmpp/TLS.hs
+++ b/source/Network/Xmpp/TLS.hs
@@ -9,6 +9,10 @@ import Control.Monad
import Control.Monad.Error
import Control.Monad.State.Strict
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Lazy as BL
+import Data.Conduit
+import qualified Data.Conduit.Binary as CB
import Data.Conduit.TLS as TLS
import Data.Typeable
import Data.XML.Types
@@ -18,6 +22,42 @@ import Network.Xmpp.Pickle(ppElement)
import Network.Xmpp.Stream
import Network.Xmpp.Types
+mkBackend con = Backend { backendSend = \bs -> void (cSend con bs)
+ , backendRecv = cRecv con
+ , backendFlush = cFlush con
+ , backendClose = cClose con
+ }
+ where
+ cutBytes n = do
+ liftIO $ putStrLn "awaiting"
+ mbs <- await
+ liftIO $ putStrLn "done awaiting"
+ case mbs of
+ Nothing -> return BS.empty
+ Just bs -> do
+ let (a, b) = BS.splitAt n bs
+ liftIO . putStrLn $
+ "remaining" ++ (show $ BS.length b) ++ " of " ++ (show n)
+
+ unless (BS.null b) $ leftover b
+ return a
+
+
+cutBytes n = do
+ liftIO $ putStrLn "awaiting"
+ mbs <- await
+ liftIO $ putStrLn "done awaiting"
+ case mbs of
+ Nothing -> return False
+ Just bs -> do
+ let (a, b) = BS.splitAt n bs
+ liftIO . putStrLn $
+ "remaining" ++ (show $ BS.length b) ++ " of " ++ (show n)
+
+ unless (BS.null b) $ leftover b
+ return True
+
+
starttlsE :: Element
starttlsE = Element "{urn:ietf:params:xml:ns:xmpp-tls}starttls" [] []
@@ -36,6 +76,7 @@ exampleParams = TLS.defaultParamsClient
data XmppTLSError = TLSError TLSError
| TLSNoServerSupport
| TLSNoConnection
+ | TLSConnectionSecured -- ^ Connection already secured
| TLSStreamError StreamError
| XmppTLSError -- General instance used for the Error instance
deriving (Show, Eq, Typeable)
@@ -48,8 +89,12 @@ instance Error XmppTLSError where
startTLS :: TLS.TLSParams -> XmppConMonad (Either XmppTLSError ())
startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do
features <- lift $ gets sFeatures
- handle' <- lift $ gets sConHandle
- handle <- maybe (throwError TLSNoConnection) return handle'
+ state <- gets sConnectionState
+ case state of
+ XmppConnectionPlain -> return ()
+ XmppConnectionClosed -> throwError TLSNoConnection
+ XmppConnectionSecured -> throwError TLSConnectionSecured
+ con <- lift $ gets sCon
when (stls features == Nothing) $ throwError TLSNoServerSupport
lift $ pushElement starttlsE
answer <- lift $ pullElement
@@ -60,14 +105,15 @@ startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do
-- TODO: find something more suitable
e -> lift . Ex.throwIO . StreamXMLError $
"Unexpected element: " ++ ppElement e
- (raw, _snk, psh, ctx) <- lift $ TLS.tlsinit debug params handle
- lift $ modify ( \x -> x
- { sRawSrc = raw
--- , sConSrc = -- Note: this momentarily leaves us in an
- -- inconsistent state
- , sConPushBS = catchPush . psh
- , sCloseConnection = TLS.bye ctx >> sCloseConnection x
- })
+ liftIO $ putStrLn "#"
+ (raw, _snk, psh, read, ctx) <- lift $ TLS.tlsinit debug params (mkBackend con)
+ liftIO $ putStrLn "*"
+ let newCon = Connection { cSend = catchSend . psh
+ , cRecv = read
+ , cFlush = contextFlush ctx
+ , cClose = bye ctx >> cClose con
+ }
+ lift $ modify ( \x -> x {sCon = newCon})
either (lift . Ex.throwIO) return =<< lift xmppRestartStream
modify (\s -> s{sConnectionState = XmppConnectionSecured})
return ()
diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs
index a357437..3806477 100644
--- a/source/Network/Xmpp/Types.hs
+++ b/source/Network/Xmpp/Types.hs
@@ -32,6 +32,7 @@ module Network.Xmpp.Types
, StreamErrorCondition(..)
, Version(..)
, XmppConMonad
+ , Connection(..)
, XmppConnection(..)
, XmppConnectionState(..)
, XmppT(..)
@@ -50,8 +51,10 @@ import Control.Monad.Error
import qualified Data.Attoparsec.Text as AP
import qualified Data.ByteString as BS
import Data.Conduit
-import Data.String(IsString(..))
+import Data.Conduit.BufferedSource
+import Data.IORef
import Data.Maybe (fromJust, fromMaybe, maybeToList)
+import Data.String(IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Typeable(Typeable)
@@ -740,21 +743,24 @@ data XmppConnectionState
| XmppConnectionSecured -- ^ Connection established and secured via TLS.
deriving (Show, Eq, Typeable)
+data Connection = Connection { cSend :: BS.ByteString -> IO Bool
+ , cRecv :: Int -> IO BS.ByteString
+ -- This is to hold the state of the XML parser
+ -- (otherwise we will receive lot's of EvenBegin
+ -- Document and forger about name prefixes)
+ , cEventSource :: BufferedSource IO Event
+
+ , cFlush :: IO ()
+ , cClose :: IO ()
+ }
+
data XmppConnection = XmppConnection
- { sConSrc :: !(Source IO Event) -- ^ inbound connection
- , sRawSrc :: !(Source IO BS.ByteString) -- ^ inbound
- -- connection
- , sConPushBS :: !(BS.ByteString -> IO Bool) -- ^ outbound
- -- connection
- , sConHandle :: !(Maybe Handle) -- ^ Handle for TLS
+ { sCon :: Connection
, sFeatures :: !ServerFeatures -- ^ Features the server
-- advertised
, sConnectionState :: !XmppConnectionState -- ^ State of connection
, sHostname :: !(Maybe Text) -- ^ Hostname of the server
, sJid :: !(Maybe Jid) -- ^ Our JID
- , sCloseConnection :: !(IO ()) -- ^ necessary steps to cleanly
- -- close the connection (send TLS
- -- bye etc.)
, sPreferredLang :: !(Maybe LangTag) -- ^ Default language when
-- no explicit language
-- tag is set
diff --git a/tests/Tests.hs b/tests/Tests.hs
index 1e7f45e..d9b2db3 100644
--- a/tests/Tests.hs
+++ b/tests/Tests.hs
@@ -173,9 +173,11 @@ runMain debug number multi = do
endSession s) (session context)
debug' "running"
flip withConnection (session context) $ Ex.catch (do
+ debug' "connect"
connect "localhost" (PortNumber 5222) "species64739.dyndns.org"
+-- debug' "tls start"
startTLS exampleParams
- -- debug' "ibr start"
+ debug' "ibr start"
-- ibrTest debug' (localpart we) "pwd"
-- debug' "ibr end"
saslResponse <- simpleAuth