diff --git a/source/Data/Conduit/BufferedSource.hs b/source/Data/Conduit/BufferedSource.hs index cfb620e..efe6a76 100644 --- a/source/Data/Conduit/BufferedSource.hs +++ b/source/Data/Conduit/BufferedSource.hs @@ -14,21 +14,19 @@ data SourceClosed = SourceClosed deriving (Show, Typeable) instance Exception SourceClosed +newtype BufferedSource m o = BufferedSource + { bs :: IORef (ResumableSource m o) + } + + -- | Buffered source from conduit 0.3 -bufferSource :: MonadIO m => Source m o -> IO (Source m o) +bufferSource :: Monad m => Source m o -> IO (BufferedSource m o) bufferSource s = do - srcRef <- newIORef . Just $ DCI.ResumableSource s (return ()) - return $ do - src' <- liftIO $ readIORef srcRef - src <- case src' of - Just s -> return s - Nothing -> liftIO $ throwIO SourceClosed - let go src = do - (src', res) <- lift $ src $$++ CL.head - case res of - Nothing -> liftIO $ writeIORef srcRef Nothing - Just x -> do - liftIO (writeIORef srcRef $ Just src') - yield x - go src' - in go src + srcRef <- newIORef $ DCI.ResumableSource s (return ()) + return $ BufferedSource srcRef + +(.$$+) (BufferedSource bs) snk = do + src <- liftIO $ readIORef bs + (src', r) <- src $$++ snk + liftIO $ writeIORef bs src' + return r diff --git a/source/Data/Conduit/TLS.hs b/source/Data/Conduit/TLS.hs index 2dbdb28..68fa23b 100644 --- a/source/Data/Conduit/TLS.hs +++ b/source/Data/Conduit/TLS.hs @@ -8,38 +8,42 @@ module Data.Conduit.TLS ) where -import Control.Monad(liftM, when) -import Control.Monad.IO.Class +import Control.Monad +import Control.Monad (liftM, when) +import Control.Monad.IO.Class -import Crypto.Random +import Crypto.Random import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL -import Data.Conduit -import Control.Monad +import Data.Conduit +import qualified Data.Conduit.Binary as CB +import Data.IORef -import Network.TLS as TLS -import Network.TLS.Extra as TLSExtra +import Network.TLS as TLS +import Network.TLS.Extra as TLSExtra -import System.IO(Handle) +import System.IO (Handle) -client params gen handle = do - contextNewOnHandle handle params gen +client params gen backend = do + contextNew backend params gen defaultParams = defaultParamsClient tlsinit :: (MonadIO m, MonadIO m1) => Bool -> TLSParams - -> Handle -> m ( Source m1 BS.ByteString - , Sink BS.ByteString m1 () - , BS.ByteString -> IO () - , Context - ) -tlsinit debug tlsParams handle = do + -> Backend + -> m ( Source m1 BS.ByteString + , Sink BS.ByteString m1 () + , BS.ByteString -> IO () + , Int -> m1 BS.ByteString + , Context + ) +tlsinit debug tlsParams backend = do when debug . liftIO $ putStrLn "TLS with debug mode enabled" gen <- liftIO $ (newGenIO :: IO SystemRandom) -- TODO: Find better random source? - con <- client tlsParams gen handle + con <- client tlsParams gen backend handshake con let src = forever $ do dt <- liftIO $ recvData con @@ -53,10 +57,24 @@ tlsinit debug tlsParams handle = do sendData con (BL.fromChunks [x]) when debug (liftIO $ putStr "out: " >> BS.putStrLn x) snk + read <- liftIO $ mkReadBuffer (recvData con) return ( src , snk , \s -> do when debug (liftIO $ BS.putStrLn s) sendData con $ BL.fromChunks [s] + , liftIO . read , con ) + +mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString) +mkReadBuffer read = do + buffer <- newIORef BS.empty + let read' n = do + nc <- readIORef buffer + bs <- if BS.null nc then read + else return nc + let (result, rest) = BS.splitAt n bs + writeIORef buffer rest + return result + return read' diff --git a/source/Network/Xmpp/Concurrent/Monad.hs b/source/Network/Xmpp/Concurrent/Monad.hs index 9046e6f..2f17ea9 100644 --- a/source/Network/Xmpp/Concurrent/Monad.hs +++ b/source/Network/Xmpp/Concurrent/Monad.hs @@ -56,7 +56,7 @@ withConnection a session = do (do (res, s') <- runStateT a s atomically $ do - putTMVar (writeRef session) (sConPushBS s') + putTMVar (writeRef session) (cSend . sCon $ s') putTMVar (conStateRef session) s' return $ Right res ) @@ -102,7 +102,7 @@ endSession session = do -- TODO: This has to be idempotent (is it?) closeConnection :: Session -> IO () closeConnection session = Ex.mask_ $ do send <- atomically $ takeTMVar (writeRef session) - cc <- sCloseConnection <$> ( atomically $ readTMVar (conStateRef session)) + cc <- cClose . sCon <$> ( atomically $ readTMVar (conStateRef session)) send "" void . forkIO $ do threadDelay 3000000 diff --git a/source/Network/Xmpp/Monad.hs b/source/Network/Xmpp/Monad.hs index 63503ba..e2cea4e 100644 --- a/source/Network/Xmpp/Monad.hs +++ b/source/Network/Xmpp/Monad.hs @@ -16,9 +16,10 @@ import Control.Monad.State.Strict import Data.ByteString as BS import Data.Conduit -import qualified Data.Conduit.List as CL -import Data.Conduit.BufferedSource import Data.Conduit.Binary as CB +import Data.Conduit.BufferedSource +import qualified Data.Conduit.List as CL +import Data.IORef import Data.Text(Text) import Data.XML.Pickle import Data.XML.Types @@ -42,8 +43,8 @@ debug = False pushElement :: Element -> XmppConMonad Bool pushElement x = do - sink <- gets sConPushBS - liftIO . sink $ renderElement x + send <- gets (cSend . sCon) + liftIO . send $ renderElement x -- | Encode and send stanza pushStanza :: Stanza -> XmppConMonad Bool @@ -55,26 +56,26 @@ pushStanza = pushElement . pickleElem xpStanza -- XMPP streams. RFC 6120 defines XMPP only in terms of XML 1.0. pushXmlDecl :: XmppConMonad Bool pushXmlDecl = do - sink <- gets sConPushBS - liftIO $ sink "" + con <- gets sCon + liftIO $ (cSend con) "" pushOpenElement :: Element -> XmppConMonad Bool pushOpenElement e = do - sink <- gets sConPushBS + sink <- gets (cSend . sCon ) liftIO . sink $ renderOpenElement e -- `Connect-and-resumes' the given sink to the connection source, and pulls a -- `b' value. -pullToSink :: Sink Event IO b -> XmppConMonad b -pullToSink snk = do - source <- gets sConSrc - (_, r) <- lift $ source $$+ snk +pullToSinkEvents :: Sink Event IO b -> XmppConMonad b +pullToSinkEvents snk = do + source <- gets (cEventSource . sCon) + r <- lift $ source .$$+ snk return r pullElement :: XmppConMonad Element pullElement = do Ex.catches (do - e <- pullToSink (elements =$ CL.head) + e <- pullToSinkEvents (elements =$ await) case e of Nothing -> liftIO $ Ex.throwIO StreamConnectionError Just r -> return r @@ -106,8 +107,8 @@ pullStanza = do -- Performs the given IO operation, catches any errors and re-throws everything -- except 'ResourceVanished' and IllegalOperation, in which case it will return False instead -catchPush :: IO () -> IO Bool -catchPush p = Ex.catch +catchSend :: IO () -> IO Bool +catchSend p = Ex.catch (p >> return True) (\e -> case GIE.ioe_type e of GIE.ResourceVanished -> return False @@ -115,18 +116,20 @@ catchPush p = Ex.catch _ -> Ex.throwIO e ) --- XmppConnection state used when there is no connection. +-- -- XmppConnection state used when there is no connection. xmppNoConnection :: XmppConnection xmppNoConnection = XmppConnection - { sConSrc = zeroSource - , sRawSrc = zeroSource - , sConPushBS = \_ -> return False -- Nothing has been sent. - , sConHandle = Nothing + { sCon = Connection { cSend = \_ -> return False + , cRecv = \_ -> Ex.throwIO + $ StreamConnectionError + , cEventSource = undefined + , cFlush = return () + , cClose = return () + } , sFeatures = SF Nothing [] [] , sConnectionState = XmppConnectionClosed , sHostname = Nothing , sJid = Nothing - , sCloseConnection = return () , sStreamLang = Nothing , sStreamId = Nothing , sPreferredLang = Nothing @@ -140,30 +143,34 @@ xmppNoConnection = XmppConnection -- Connects to the given hostname on port 5222 (TODO: Make this dynamic) and -- updates the XmppConMonad XmppConnection state. -xmppRawConnect :: HostName -> PortID -> Text -> XmppConMonad () -xmppRawConnect host port hostname = do - con <- liftIO $ do - con <- connectTo host port - hSetBuffering con NoBuffering - return con - let raw = if debug - then sourceHandle con $= debugConduit - else sourceHandle con - src <- liftIO . bufferSource $ raw $= XP.parseBytes def +xmppConnectTCP :: HostName -> PortID -> Text -> XmppConMonad () +xmppConnectTCP host port hostname = do + hand <- liftIO $ do + h <- connectTo host port + hSetBuffering h NoBuffering + return h + eSource <- liftIO . bufferSource $ (sourceHandle hand) $= XP.parseBytes def + let con = Connection { cSend = if debug + then \d -> do + BS.putStrLn (BS.append "out: " d) + catchSend $ BS.hPut hand d + else catchSend . BS.hPut hand + , cRecv = if debug then + \n -> do + bs <- BS.hGetSome hand n + BS.putStrLn bs + return bs + else BS.hGetSome hand + , cEventSource = eSource + , cFlush = hFlush hand + , cClose = hClose hand + } let st = XmppConnection - { sConSrc = src - , sRawSrc = raw - , sConPushBS = if debug - then \d -> do - BS.putStrLn (BS.append "out: " d) - catchPush $ BS.hPut con d - else catchPush . BS.hPut con - , sConHandle = (Just con) + { sCon = con , sFeatures = (SF Nothing [] []) , sConnectionState = XmppConnectionPlain , sHostname = (Just hostname) , sJid = Nothing - , sCloseConnection = (hClose con) , sPreferredLang = Nothing -- TODO: Allow user to set , sStreamLang = Nothing , sStreamId = Nothing @@ -180,11 +187,18 @@ xmppNewSession action = runStateT action xmppNoConnection -- Closes the connection and updates the XmppConMonad XmppConnection state. xmppKillConnection :: XmppConMonad (Either Ex.SomeException ()) xmppKillConnection = do - cc <- gets sCloseConnection + cc <- gets (cClose . sCon) err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ())) put xmppNoConnection return err +xmppReplaceConnection :: XmppConnection -> XmppConMonad (Either Ex.SomeException ()) +xmppReplaceConnection newCon = do + cc <- gets (cClose . sCon) + err <- liftIO $ (Ex.try cc :: IO (Either Ex.SomeException ())) + put newCon + return err + -- Sends an IQ request and waits for the response. If the response ID does not -- match the outgoing ID, an error is thrown. xmppSendIQ' :: StanzaId @@ -211,8 +225,8 @@ xmppSendIQ' iqID to tp lang body = do -- not we received a element from the server is returned. xmppCloseStreams :: XmppConMonad ([Element], Bool) xmppCloseStreams = do - send <- gets sConPushBS - cc <- gets sCloseConnection + send <- gets (cSend . sCon) + cc <- gets (cClose . sCon) liftIO $ send "" void $ liftIO $ forkIO $ do threadDelay 3000000 diff --git a/source/Network/Xmpp/Session.hs b/source/Network/Xmpp/Session.hs index a282512..8e84ad2 100644 --- a/source/Network/Xmpp/Session.hs +++ b/source/Network/Xmpp/Session.hs @@ -58,7 +58,7 @@ simpleConnect host port hostname username password resource = do -- | Connect to host with given address. connect :: HostName -> PortID -> Text -> XmppConMonad (Either StreamError ()) connect address port hostname = do - xmppRawConnect address port hostname + xmppConnectTCP address port hostname result <- xmppStartStream case result of Left e -> do diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index 888c2ad..6f8dd5e 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -9,14 +9,15 @@ import qualified Control.Exception as Ex import Control.Monad.Error import Control.Monad.State.Strict +import qualified Data.ByteString as BS import Data.Conduit import Data.Conduit.BufferedSource import Data.Conduit.List as CL import Data.Maybe (fromJust, isJust, isNothing) import Data.Text as Text +import Data.Void (Void) import Data.XML.Pickle import Data.XML.Types -import Data.Void (Void) import Network.Xmpp.Monad import Network.Xmpp.Pickle @@ -79,7 +80,8 @@ xmppStartStream = runErrorT $ do , Nothing , sPreferredLang state ) - (lt, from, id, features) <- ErrorT . pullToSink $ runErrorT $ xmppStream from + (lt, from, id, features) <- ErrorT . pullToSinkEvents $ runErrorT $ + xmppStream from modify (\s -> s { sFeatures = features , sStreamLang = Just lt , sStreamId = id @@ -92,10 +94,16 @@ xmppStartStream = runErrorT $ do -- and calls xmppStartStream. xmppRestartStream :: XmppConMonad (Either StreamError ()) xmppRestartStream = do - raw <- gets sRawSrc - newsrc <- liftIO . bufferSource $ raw $= XP.parseBytes def - modify (\s -> s{sConSrc = newsrc}) + raw <- gets (cRecv . sCon) + newSrc <- liftIO . bufferSource $ loopRead raw $= XP.parseBytes def + modify (\s -> s{sCon = (sCon s){cEventSource = newSrc}}) xmppStartStream + where + loopRead read = do + bs <- liftIO (read 4096) + if BS.null bs + then return () + else yield bs >> loopRead read -- Reads the (partial) stream:stream and the server features from the stream. -- Also validates the stream element's attributes and throws an error if @@ -170,4 +178,4 @@ xpStreamFeatures = xpWrap pickleSaslFeature = xpElemNodes "{urn:ietf:params:xml:ns:xmpp-sasl}mechanisms" (xpAll $ xpElemNodes - "{urn:ietf:params:xml:ns:xmpp-sasl}mechanism" (xpContent xpId)) \ No newline at end of file + "{urn:ietf:params:xml:ns:xmpp-sasl}mechanism" (xpContent xpId)) diff --git a/source/Network/Xmpp/TLS.hs b/source/Network/Xmpp/TLS.hs index c50aebf..26bbe6a 100644 --- a/source/Network/Xmpp/TLS.hs +++ b/source/Network/Xmpp/TLS.hs @@ -9,6 +9,10 @@ import Control.Monad import Control.Monad.Error import Control.Monad.State.Strict +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BL +import Data.Conduit +import qualified Data.Conduit.Binary as CB import Data.Conduit.TLS as TLS import Data.Typeable import Data.XML.Types @@ -18,6 +22,42 @@ import Network.Xmpp.Pickle(ppElement) import Network.Xmpp.Stream import Network.Xmpp.Types +mkBackend con = Backend { backendSend = \bs -> void (cSend con bs) + , backendRecv = cRecv con + , backendFlush = cFlush con + , backendClose = cClose con + } + where + cutBytes n = do + liftIO $ putStrLn "awaiting" + mbs <- await + liftIO $ putStrLn "done awaiting" + case mbs of + Nothing -> return BS.empty + Just bs -> do + let (a, b) = BS.splitAt n bs + liftIO . putStrLn $ + "remaining" ++ (show $ BS.length b) ++ " of " ++ (show n) + + unless (BS.null b) $ leftover b + return a + + +cutBytes n = do + liftIO $ putStrLn "awaiting" + mbs <- await + liftIO $ putStrLn "done awaiting" + case mbs of + Nothing -> return False + Just bs -> do + let (a, b) = BS.splitAt n bs + liftIO . putStrLn $ + "remaining" ++ (show $ BS.length b) ++ " of " ++ (show n) + + unless (BS.null b) $ leftover b + return True + + starttlsE :: Element starttlsE = Element "{urn:ietf:params:xml:ns:xmpp-tls}starttls" [] [] @@ -36,6 +76,7 @@ exampleParams = TLS.defaultParamsClient data XmppTLSError = TLSError TLSError | TLSNoServerSupport | TLSNoConnection + | TLSConnectionSecured -- ^ Connection already secured | TLSStreamError StreamError | XmppTLSError -- General instance used for the Error instance deriving (Show, Eq, Typeable) @@ -48,8 +89,12 @@ instance Error XmppTLSError where startTLS :: TLS.TLSParams -> XmppConMonad (Either XmppTLSError ()) startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do features <- lift $ gets sFeatures - handle' <- lift $ gets sConHandle - handle <- maybe (throwError TLSNoConnection) return handle' + state <- gets sConnectionState + case state of + XmppConnectionPlain -> return () + XmppConnectionClosed -> throwError TLSNoConnection + XmppConnectionSecured -> throwError TLSConnectionSecured + con <- lift $ gets sCon when (stls features == Nothing) $ throwError TLSNoServerSupport lift $ pushElement starttlsE answer <- lift $ pullElement @@ -60,14 +105,15 @@ startTLS params = Ex.handle (return . Left . TLSError) . runErrorT $ do -- TODO: find something more suitable e -> lift . Ex.throwIO . StreamXMLError $ "Unexpected element: " ++ ppElement e - (raw, _snk, psh, ctx) <- lift $ TLS.tlsinit debug params handle - lift $ modify ( \x -> x - { sRawSrc = raw --- , sConSrc = -- Note: this momentarily leaves us in an - -- inconsistent state - , sConPushBS = catchPush . psh - , sCloseConnection = TLS.bye ctx >> sCloseConnection x - }) + liftIO $ putStrLn "#" + (raw, _snk, psh, read, ctx) <- lift $ TLS.tlsinit debug params (mkBackend con) + liftIO $ putStrLn "*" + let newCon = Connection { cSend = catchSend . psh + , cRecv = read + , cFlush = contextFlush ctx + , cClose = bye ctx >> cClose con + } + lift $ modify ( \x -> x {sCon = newCon}) either (lift . Ex.throwIO) return =<< lift xmppRestartStream modify (\s -> s{sConnectionState = XmppConnectionSecured}) return () diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs index a357437..3806477 100644 --- a/source/Network/Xmpp/Types.hs +++ b/source/Network/Xmpp/Types.hs @@ -32,6 +32,7 @@ module Network.Xmpp.Types , StreamErrorCondition(..) , Version(..) , XmppConMonad + , Connection(..) , XmppConnection(..) , XmppConnectionState(..) , XmppT(..) @@ -50,8 +51,10 @@ import Control.Monad.Error import qualified Data.Attoparsec.Text as AP import qualified Data.ByteString as BS import Data.Conduit -import Data.String(IsString(..)) +import Data.Conduit.BufferedSource +import Data.IORef import Data.Maybe (fromJust, fromMaybe, maybeToList) +import Data.String(IsString(..)) import Data.Text (Text) import qualified Data.Text as Text import Data.Typeable(Typeable) @@ -740,21 +743,24 @@ data XmppConnectionState | XmppConnectionSecured -- ^ Connection established and secured via TLS. deriving (Show, Eq, Typeable) +data Connection = Connection { cSend :: BS.ByteString -> IO Bool + , cRecv :: Int -> IO BS.ByteString + -- This is to hold the state of the XML parser + -- (otherwise we will receive lot's of EvenBegin + -- Document and forger about name prefixes) + , cEventSource :: BufferedSource IO Event + + , cFlush :: IO () + , cClose :: IO () + } + data XmppConnection = XmppConnection - { sConSrc :: !(Source IO Event) -- ^ inbound connection - , sRawSrc :: !(Source IO BS.ByteString) -- ^ inbound - -- connection - , sConPushBS :: !(BS.ByteString -> IO Bool) -- ^ outbound - -- connection - , sConHandle :: !(Maybe Handle) -- ^ Handle for TLS + { sCon :: Connection , sFeatures :: !ServerFeatures -- ^ Features the server -- advertised , sConnectionState :: !XmppConnectionState -- ^ State of connection , sHostname :: !(Maybe Text) -- ^ Hostname of the server , sJid :: !(Maybe Jid) -- ^ Our JID - , sCloseConnection :: !(IO ()) -- ^ necessary steps to cleanly - -- close the connection (send TLS - -- bye etc.) , sPreferredLang :: !(Maybe LangTag) -- ^ Default language when -- no explicit language -- tag is set diff --git a/tests/Tests.hs b/tests/Tests.hs index 1e7f45e..d9b2db3 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -173,9 +173,11 @@ runMain debug number multi = do endSession s) (session context) debug' "running" flip withConnection (session context) $ Ex.catch (do + debug' "connect" connect "localhost" (PortNumber 5222) "species64739.dyndns.org" +-- debug' "tls start" startTLS exampleParams - -- debug' "ibr start" + debug' "ibr start" -- ibrTest debug' (localpart we) "pwd" -- debug' "ibr end" saslResponse <- simpleAuth