Browse Source

Fix BufferedSource (couldn't handle leftovers and was therefore useless)

Factor out connection from Session
Abstract over connection type (remove mention of Handle)
master
Philipp Balzarek 13 years ago
parent
commit
1e42a760de
  1. 30
      source/Data/Conduit/BufferedSource.hs
  2. 52
      source/Data/Conduit/TLS.hs
  3. 4
      source/Network/Xmpp/Concurrent/Monad.hs
  4. 98
      source/Network/Xmpp/Monad.hs
  5. 2
      source/Network/Xmpp/Session.hs
  6. 18
      source/Network/Xmpp/Stream.hs
  7. 66
      source/Network/Xmpp/TLS.hs
  8. 26
      source/Network/Xmpp/Types.hs
  9. 4
      tests/Tests.hs

30
source/Data/Conduit/BufferedSource.hs

@ -14,21 +14,19 @@ data SourceClosed = SourceClosed deriving (Show, Typeable) @@ -14,21 +14,19 @@ data SourceClosed = SourceClosed deriving (Show, Typeable)
instance Exception SourceClosed
newtype BufferedSource m o = BufferedSource
{ bs :: IORef (ResumableSource m o)
}
-- | Buffered source from conduit 0.3
bufferSource :: MonadIO m => Source m o -> IO (Source m o)
bufferSource :: Monad m => Source m o -> IO (BufferedSource m o)
bufferSource s = do
srcRef <- newIORef . Just $ DCI.ResumableSource s (return ())
return $ do
src' <- liftIO $ readIORef srcRef
src <- case src' of
Just s -> return s
Nothing -> liftIO $ throwIO SourceClosed
let go src = do
(src', res) <- lift $ src $$++ CL.head
case res of
Nothing -> liftIO $ writeIORef srcRef Nothing
Just x -> do
liftIO (writeIORef srcRef $ Just src')
yield x
go src'
in go src
srcRef <- newIORef $ DCI.ResumableSource s (return ())
return $ BufferedSource srcRef
(.$$+) (BufferedSource bs) snk = do
src <- liftIO $ readIORef bs
(src', r) <- src $$++ snk
liftIO $ writeIORef bs src'
return r

52
source/Data/Conduit/TLS.hs

@ -8,38 +8,42 @@ module Data.Conduit.TLS @@ -8,38 +8,42 @@ module Data.Conduit.TLS
)
where
import Control.Monad(liftM, when)
import Control.Monad.IO.Class
import Control.Monad
import Control.Monad (liftM, when)
import Control.Monad.IO.Class
import Crypto.Random
import Crypto.Random
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import Control.Monad
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.IORef
import Network.TLS as TLS
import Network.TLS.Extra as TLSExtra
import Network.TLS as TLS
import Network.TLS.Extra as TLSExtra
import System.IO(Handle)
import System.IO (Handle)
client params gen handle = do
contextNewOnHandle handle params gen
client params gen backend = do
contextNew backend params gen
defaultParams = defaultParamsClient
tlsinit :: (MonadIO m, MonadIO m1) =>
Bool
-> TLSParams
-> Handle -> m ( Source m1 BS.ByteString
, Sink BS.ByteString m1 ()
, BS.ByteString -> IO ()
, Context
)
tlsinit debug tlsParams handle = do
-> Backend
-> m ( Source m1 BS.ByteString
, Sink BS.ByteString m1 ()
, BS.ByteString -> IO ()
, Int -> m1 BS.ByteString
, Context
)
tlsinit debug tlsParams backend = do
when debug . liftIO $ putStrLn "TLS with debug mode enabled"
gen <- liftIO $ (newGenIO :: IO SystemRandom) -- TODO: Find better random source?
con <- client tlsParams gen handle
con <- client tlsParams gen backend
handshake con
let src = forever $ do
dt <- liftIO $ recvData con
@ -53,10 +57,24 @@ tlsinit debug tlsParams handle = do @@ -53,10 +57,24 @@ tlsinit debug tlsParams handle = do
sendData con (BL.fromChunks [x])
when debug (liftIO $ putStr "out: " >> BS.putStrLn x)
snk
read <- liftIO $ mkReadBuffer (recvData con)
return ( src
, snk
, \s -> do
when debug (liftIO $ BS.putStrLn s)
sendData con $ BL.fromChunks [s]
, liftIO . read
, con
)
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer read = do
buffer <- newIORef BS.empty
let read' n = do
nc <- readIORef buffer
bs <- if BS.null nc then read
else return nc
let (result, rest) = BS.splitAt n bs
writeIORef buffer rest
return result
return read'

4
source/Network/Xmpp/Concurrent/Monad.hs

@ -56,7 +56,7 @@ withConnection a session = do @@ -56,7 +56,7 @@ withConnection a session = do
(do
(res, s') <- runStateT a s
atomically $ do
putTMVar (writeRef session) (sConPushBS s')
putTMVar (writeRef session) (cSend . sCon $ s')
putTMVar (conStateRef session) s'
return $ Right res
)
@ -102,7 +102,7 @@ endSession session = do -- TODO: This has to be idempotent (is it?) @@ -102,7 +102,7 @@ endSession session = do -- TODO: This has to be idempotent (is it?)
closeConnection :: Session -> IO ()
closeConnection session = Ex.mask_ $ do
send <- atomically $ takeTMVar (writeRef session)
cc <- sCloseConnection <$> ( atomically $ readTMVar (conStateRef session))
cc <- cClose . sCon <$> ( atomically $ readTMVar (conStateRef session))
send "</stream:stream>"
void . forkIO $ do
threadDelay 3000000

98
source/Network/Xmpp/Monad.hs

@ -16,9 +16,10 @@ import Control.Monad.State.Strict @@ -16,9 +16,10 @@ import Control.Monad.State.Strict
import Data.ByteString as BS
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Conduit.BufferedSource
import Data.Conduit.Binary as CB
import Data.Conduit.BufferedSource
import qualified Data.Conduit.List as CL
import Data.IORef
import Data.Text(Text)
import Data.XML.Pickle
import Data.XML.Types
@ -42,8 +43,8 @@ debug = False @@ -42,8 +43,8 @@ debug = False
pushElement :: Element -> XmppConMonad Bool
pushElement x = do
sink <- gets sConPushBS
liftIO . sink $ renderElement x
send <- gets (cSend . sCon)
liftIO . send $ renderElement x
-- | Encode and send stanza
pushStanza :: Stanza -> XmppConMonad Bool
@ -55,26 +56,26 @@ pushStanza = pushElement . pickleElem xpStanza @@ -55,26 +56,26 @@ pushStanza = pushElement . pickleElem xpStanza
-- XMPP streams. RFC 6120 defines XMPP only in terms of XML 1.0.
pushXmlDecl :: XmppConMonad Bool
pushXmlDecl = do
sink <- gets sConPushBS
liftIO $ sink "<?xml version='1.0' encoding='UTF-8' ?>"
con <- gets sCon
liftIO $ (cSend con) "<?xml version='1.0' encoding='UTF-8' ?>"
pushOpenElement :: Element -> XmppConMonad Bool
pushOpenElement e = do
sink <- gets sConPushBS
sink <- gets (cSend . sCon )
liftIO . sink $ renderOpenElement e
-- `Connect-and-resumes' the given sink to the connection source, and pulls a
-- `b' value.
pullToSink :: Sink Event IO b -> XmppConMonad b
pullToSink snk = do
source <- gets sConSrc
(_, r) <- lift $ source $$+ snk
pullToSinkEvents :: Sink Event IO b -> XmppConMonad b
pullToSinkEvents snk = do
source <- gets (cEventSource . sCon)
r <- lift $ source .$$+ snk
return r
pullElement :: XmppConMonad Element
pullElement = do
Ex.catches (do
e <- pullToSink (elements =$ CL.head)
e <- pullToSinkEvents (elements =$ await)
case e of
Nothing -> liftIO $ Ex.throwIO StreamConnectionError
Just r -> return r
@ -106,8 +107,8 @@ pullStanza = do @@ -106,8 +107,8 @@ pullStanza = do
-- Performs the given IO operation, catches any errors and re-throws everything
-- except 'ResourceVanished' and IllegalOperation, in which case it will return False instead
catchPush :: IO () -> IO Bool
catchPush p = Ex.catch
catchSend :: IO () -> IO Bool
catchSend p = Ex.catch
(p >> return True)
(\e -> case GIE.ioe_type e of
GIE.ResourceVanished -> return False
@ -115,18 +116,20 @@ catchPush p = Ex.catch @@ -115,18 +116,20 @@ catchPush p = Ex.catch
_ -> Ex.throwIO e
)
-- XmppConnection state used when there is no connection.
-- -- XmppConnection state used when there is no connection.
xmppNoConnection :: XmppConnection
xmppNoConnection = XmppConnection
{ sConSrc = zeroSource
, sRawSrc = zeroSource
, sConPushBS = \_ -> return False -- Nothing has been sent.
, sConHandle = Nothing
{ sCon = Connection { cSend = \_ -> return False
, cRecv = \_ -> Ex.throwIO
$ StreamConnectionError
, cEventSource = undefined
, cFlush = return ()
, cClose = return ()
}
, sFeatures = SF Nothing [] []
, sConnectionState = XmppConnectionClosed
, sHostname = Nothing
, sJid = Nothing
, sCloseConnection = return ()
, sStreamLang = Nothing
, sStreamId = Nothing
, sPreferredLang = Nothing
@ -140,30 +143,34 @@ xmppNoConnection = XmppConnection @@ -140,30 +143,34 @@ xmppNoConnection = XmppConnection
-- Connects to the given hostname on port 5222 (TODO: Make this dynamic) and
-- updates the XmppConMonad XmppConnection state.
xmppRawConnect :: HostName -> PortID -> Text -> XmppConMonad ()
xmppRawConnect host port hostname = do
con <- liftIO $ do
con <- connectTo host port
hSetBuffering con NoBuffering
return con
let raw = if debug
then sourceHandle con $= debugConduit
else sourceHandle con
src <- liftIO . bufferSource $ raw $= XP.parseBytes def
xmppConnectTCP :: HostName -> PortID -> Text -> XmppConMonad ()
xmppConnectTCP host port hostname = do
hand <- liftIO $ do
h <- connectTo host port
hSetBuffering h NoBuffering
return h
eSource <- liftIO . bufferSource $ (sourceHandle hand) $= XP.parseBytes def
let con = Connection { cSend = if debug
then \d -> do
BS.putStrLn (BS.append "out: " d)
catchSend $ BS.hPut hand d
else catchSend . BS.hPut hand
, cRecv = if debug then
\n -> do
bs <- BS.hGetSome hand n
BS.putStrLn bs
return bs
else BS.hGetSome hand
, cEventSource = eSource
, cFlush = hFlush hand
, cClose = hClose hand
}
let st = XmppConnection
{ sConSrc = src
, sRawSrc = raw
, sConPushBS = if debug
then \d -> do
BS.putStrLn (BS.append "out: " d)
catchPush $ BS.hPut con d
else catchPush . BS.hPut con
, sConHandle = (Just con)
{ sCon = con
, sFeatures = (SF Nothing [] [])
, sConnectionState = XmppConnectionPlain
, sHostname = (Just hostname)
, sJid = Nothing
, sCloseConnection = (hClose con)
, sPreferredLang = Nothing -- TODO: Allow user to set
, sStreamLang = Nothing
, sStreamId = Nothing
@ -180,11 +187,18 @@ xmppNewSession action = runStateT action xmppNoConnection @@ -180,11 +187,18 @@ xmppNewSession action = runStateT action xmppNoConnection
-- Closes the connection and updates the XmppConMonad XmppConnection state.
xmppKillConnection :: XmppConMonad (Either Ex.SomeException ())
xmppKillConnection = do
cc <- gets sCloseConnection
cc <- gets (cClose . sCon)
err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ()))
put xmppNoConnection
return err
xmppReplaceConnection :: XmppConnection -> XmppConMonad (Either Ex.SomeException ())
xmppReplaceConnection newCon = do
cc <- gets (cClose . sCon)
err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ()))
put newCon
return err
-- Sends an IQ request and waits for the response. If the response ID does not
-- match the outgoing ID, an error is thrown.
xmppSendIQ' :: StanzaId
@ -211,8 +225,8 @@ xmppSendIQ' iqID to tp lang body = do @@ -211,8 +225,8 @@ xmppSendIQ' iqID to tp lang body = do
-- not we received a </stream:stream> element from the server is returned.
xmppCloseStreams :: XmppConMonad ([Element], Bool)
xmppCloseStreams = do
send <- gets sConPushBS
cc <- gets sCloseConnection
send <- gets (cSend . sCon)
cc <- gets (cClose . sCon)
liftIO $ send "</stream:stream>"
void $ liftIO $ forkIO $ do
threadDelay 3000000

2
source/Network/Xmpp/Session.hs

@ -58,7 +58,7 @@ simpleConnect host port hostname username password resource = do @@ -58,7 +58,7 @@ simpleConnect host port hostname username password resource = do
-- | Connect to host with given address.
connect :: HostName -> PortID -> Text -> XmppConMonad (Either StreamError ())
connect address port hostname = do
xmppRawConnect address port hostname
xmppConnectTCP address port hostname
result <- xmppStartStream
case result of
Left e -> do

18
source/Network/Xmpp/Stream.hs

@ -9,14 +9,15 @@ import qualified Control.Exception as Ex @@ -9,14 +9,15 @@ import qualified Control.Exception as Ex
import Control.Monad.Error
import Control.Monad.State.Strict
import qualified Data.ByteString as BS
import Data.Conduit
import Data.Conduit.BufferedSource
import Data.Conduit.List as CL
import Data.Maybe (fromJust, isJust, isNothing)
import Data.Text as Text
import Data.Void (Void)
import Data.XML.Pickle
import Data.XML.Types
import Data.Void (Void)
import Network.Xmpp.Monad
import Network.Xmpp.Pickle
@ -79,7 +80,8 @@ xmppStartStream = runErrorT $ do @@ -79,7 +80,8 @@ xmppStartStream = runErrorT $ do
, Nothing
, sPreferredLang state
)
(lt, from, id, features) <- ErrorT . pullToSink $ runErrorT $ xmppStream from
(lt, from, id, features) <- ErrorT . pullToSinkEvents $ runErrorT $
xmppStream from
modify (\s -> s { sFeatures = features
, sStreamLang = Just lt
, sStreamId = id
@ -92,10 +94,16 @@ xmppStartStream = runErrorT $ do @@ -92,10 +94,16 @@ xmppStartStream = runErrorT $ do
-- and calls xmppStartStream.
xmppRestartStream :: XmppConMonad (Either StreamError ())
xmppRestartStream = do
raw <- gets sRawSrc
newsrc <- liftIO . bufferSource $ raw $= XP.parseBytes def
modify (\s -> s{sConSrc = newsrc})
raw <- gets (cRecv . sCon)
newSrc <- liftIO . bufferSource $ loopRead raw $= XP.parseBytes def
modify (\s -> s{sCon = (sCon s){cEventSource = newSrc}})
xmppStartStream
where
loopRead read = do
bs <- liftIO (read 4096)
if BS.null bs
then return ()
else yield bs >> loopRead read
-- Reads the (partial) stream:stream and the server features from the stream.
-- Also validates the stream element's attributes and throws an error if

66
source/Network/Xmpp/TLS.hs

@ -9,6 +9,10 @@ import Control.Monad @@ -9,6 +9,10 @@ import Control.Monad
import Control.Monad.Error
import Control.Monad.State.Strict
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.Conduit.TLS as TLS
import Data.Typeable
import Data.XML.Types
@ -18,6 +22,42 @@ import Network.Xmpp.Pickle(ppElement) @@ -18,6 +22,42 @@ import Network.Xmpp.Pickle(ppElement)
import Network.Xmpp.Stream
import Network.Xmpp.Types
mkBackend con = Backend { backendSend = \bs -> void (cSend con bs)
, backendRecv = cRecv con
, backendFlush = cFlush con
, backendClose = cClose con
}
where
cutBytes n = do
liftIO $ putStrLn "awaiting"
mbs <- await
liftIO $ putStrLn "done awaiting"
case mbs of
Nothing -> return BS.empty
Just bs -> do
let (a, b) = BS.splitAt n bs
liftIO . putStrLn $
"remaining" ++ (show $ BS.length b) ++ " of " ++ (show n)
unless (BS.null b) $ leftover b
return a
cutBytes n = do
liftIO $ putStrLn "awaiting"
mbs <- await
liftIO $ putStrLn "done awaiting"
case mbs of
Nothing -> return False
Just bs -> do
let (a, b) = BS.splitAt n bs
liftIO . putStrLn $
"remaining" ++ (show $ BS.length b) ++ " of " ++ (show n)
unless (BS.null b) $ leftover b
return True
starttlsE :: Element
starttlsE = Element "{urn:ietf:params:xml:ns:xmpp-tls}starttls" [] []
@ -36,6 +76,7 @@ exampleParams = TLS.defaultParamsClient @@ -36,6 +76,7 @@ exampleParams = TLS.defaultParamsClient
data XmppTLSError = TLSError TLSError
| TLSNoServerSupport
| TLSNoConnection
| TLSConnectionSecured -- ^ Connection already secured
| TLSStreamError StreamError
| XmppTLSError -- General instance used for the Error instance
deriving (Show, Eq, Typeable)
@ -48,8 +89,12 @@ instance Error XmppTLSError where @@ -48,8 +89,12 @@ instance Error XmppTLSError where
startTLS :: TLS.TLSParams -> XmppConMonad (Either XmppTLSError ())
startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do
features <- lift $ gets sFeatures
handle' <- lift $ gets sConHandle
handle <- maybe (throwError TLSNoConnection) return handle'
state <- gets sConnectionState
case state of
XmppConnectionPlain -> return ()
XmppConnectionClosed -> throwError TLSNoConnection
XmppConnectionSecured -> throwError TLSConnectionSecured
con <- lift $ gets sCon
when (stls features == Nothing) $ throwError TLSNoServerSupport
lift $ pushElement starttlsE
answer <- lift $ pullElement
@ -60,14 +105,15 @@ startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do @@ -60,14 +105,15 @@ startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do
-- TODO: find something more suitable
e -> lift . Ex.throwIO . StreamXMLError $
"Unexpected element: " ++ ppElement e
(raw, _snk, psh, ctx) <- lift $ TLS.tlsinit debug params handle
lift $ modify ( \x -> x
{ sRawSrc = raw
-- , sConSrc = -- Note: this momentarily leaves us in an
-- inconsistent state
, sConPushBS = catchPush . psh
, sCloseConnection = TLS.bye ctx >> sCloseConnection x
})
liftIO $ putStrLn "#"
(raw, _snk, psh, read, ctx) <- lift $ TLS.tlsinit debug params (mkBackend con)
liftIO $ putStrLn "*"
let newCon = Connection { cSend = catchSend . psh
, cRecv = read
, cFlush = contextFlush ctx
, cClose = bye ctx >> cClose con
}
lift $ modify ( \x -> x {sCon = newCon})
either (lift . Ex.throwIO) return =<< lift xmppRestartStream
modify (\s -> s{sConnectionState = XmppConnectionSecured})
return ()

26
source/Network/Xmpp/Types.hs

@ -32,6 +32,7 @@ module Network.Xmpp.Types @@ -32,6 +32,7 @@ module Network.Xmpp.Types
, StreamErrorCondition(..)
, Version(..)
, XmppConMonad
, Connection(..)
, XmppConnection(..)
, XmppConnectionState(..)
, XmppT(..)
@ -50,8 +51,10 @@ import Control.Monad.Error @@ -50,8 +51,10 @@ import Control.Monad.Error
import qualified Data.Attoparsec.Text as AP
import qualified Data.ByteString as BS
import Data.Conduit
import Data.String(IsString(..))
import Data.Conduit.BufferedSource
import Data.IORef
import Data.Maybe (fromJust, fromMaybe, maybeToList)
import Data.String(IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Typeable(Typeable)
@ -740,21 +743,24 @@ data XmppConnectionState @@ -740,21 +743,24 @@ data XmppConnectionState
| XmppConnectionSecured -- ^ Connection established and secured via TLS.
deriving (Show, Eq, Typeable)
data Connection = Connection { cSend :: BS.ByteString -> IO Bool
, cRecv :: Int -> IO BS.ByteString
-- This is to hold the state of the XML parser
-- (otherwise we will receive lot's of EvenBegin
-- Document and forger about name prefixes)
, cEventSource :: BufferedSource IO Event
, cFlush :: IO ()
, cClose :: IO ()
}
data XmppConnection = XmppConnection
{ sConSrc :: !(Source IO Event) -- ^ inbound connection
, sRawSrc :: !(Source IO BS.ByteString) -- ^ inbound
-- connection
, sConPushBS :: !(BS.ByteString -> IO Bool) -- ^ outbound
-- connection
, sConHandle :: !(Maybe Handle) -- ^ Handle for TLS
{ sCon :: Connection
, sFeatures :: !ServerFeatures -- ^ Features the server
-- advertised
, sConnectionState :: !XmppConnectionState -- ^ State of connection
, sHostname :: !(Maybe Text) -- ^ Hostname of the server
, sJid :: !(Maybe Jid) -- ^ Our JID
, sCloseConnection :: !(IO ()) -- ^ necessary steps to cleanly
-- close the connection (send TLS
-- bye etc.)
, sPreferredLang :: !(Maybe LangTag) -- ^ Default language when
-- no explicit language
-- tag is set

4
tests/Tests.hs

@ -173,9 +173,11 @@ runMain debug number multi = do @@ -173,9 +173,11 @@ runMain debug number multi = do
endSession s) (session context)
debug' "running"
flip withConnection (session context) $ Ex.catch (do
debug' "connect"
connect "localhost" (PortNumber 5222) "species64739.dyndns.org"
-- debug' "tls start"
startTLS exampleParams
-- debug' "ibr start"
debug' "ibr start"
-- ibrTest debug' (localpart we) "pwd"
-- debug' "ibr end"
saslResponse <- simpleAuth

Loading…
Cancel
Save