diff --git a/pontarius-xmpp.cabal b/pontarius-xmpp.cabal index 0f9b3c1..97c25ad 100644 --- a/pontarius-xmpp.cabal +++ b/pontarius-xmpp.cabal @@ -57,8 +57,7 @@ Library Exposed-modules: Network.Xmpp , Network.Xmpp.Internal , Network.Xmpp.IM - Other-modules: Data.Conduit.Tls - , Network.Xmpp.Concurrent + Other-modules: Network.Xmpp.Concurrent , Network.Xmpp.Concurrent.Types , Network.Xmpp.Concurrent.Basic , Network.Xmpp.Concurrent.IQ diff --git a/source/Data/Conduit/Tls.hs b/source/Data/Conduit/Tls.hs deleted file mode 100644 index 0842ae5..0000000 --- a/source/Data/Conduit/Tls.hs +++ /dev/null @@ -1,81 +0,0 @@ -{-# Language NoMonomorphismRestriction #-} -{-# OPTIONS_HADDOCK hide #-} -module Data.Conduit.Tls - ( tlsinit --- , conduitStdout - , module TLS - , module TLSExtra - ) - where - -import Control.Monad -import Control.Monad (liftM, when) -import Control.Monad.IO.Class - -import Crypto.Random - -import qualified Data.ByteString as BS -import qualified Data.ByteString.Lazy as BL -import Data.Conduit -import qualified Data.Conduit.Binary as CB -import Data.IORef - -import Network.TLS as TLS -import Crypto.Random.API -import Network.TLS.Extra as TLSExtra - -import System.IO (Handle) - -client params gen backend = do - contextNew backend params gen - -defaultParams = defaultParamsClient - -tlsinit :: (MonadIO m, MonadIO m1) => - Bool - -> TLSParams - -> Backend - -> m ( Source m1 BS.ByteString - , Sink BS.ByteString m1 () - , BS.ByteString -> IO () - , Int -> m1 BS.ByteString - , Context - ) -tlsinit debug tlsParams backend = do - when debug . liftIO $ putStrLn "TLS with debug mode enabled" - gen <- liftIO $ getSystemRandomGen -- TODO: Find better random source? - con <- client tlsParams gen backend - handshake con - let src = forever $ do - dt <- liftIO $ recvData con - when debug (liftIO $ putStr "in: " >> BS.putStrLn dt) - yield dt - let snk = do - d <- await - case d of - Nothing -> return () - Just x -> do - sendData con (BL.fromChunks [x]) - when debug (liftIO $ putStr "out: " >> BS.putStrLn x) - snk - read <- liftIO $ mkReadBuffer (recvData con) - return ( src - , snk - , \s -> do - when debug (liftIO $ BS.putStrLn s) - sendData con $ BL.fromChunks [s] - , liftIO . read - , con - ) - -mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString) -mkReadBuffer read = do - buffer <- newIORef BS.empty - let read' n = do - nc <- readIORef buffer - bs <- if BS.null nc then read - else return nc - let (result, rest) = BS.splitAt n bs - writeIORef buffer rest - return result - return read' diff --git a/source/Network/Xmpp/Tls.hs b/source/Network/Xmpp/Tls.hs index eccfb93..88cf37e 100644 --- a/source/Network/Xmpp/Tls.hs +++ b/source/Network/Xmpp/Tls.hs @@ -13,7 +13,6 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import Data.Conduit import qualified Data.Conduit.Binary as CB -import Data.Conduit.Tls as TLS import Data.Typeable import Data.XML.Types @@ -22,6 +21,11 @@ import Network.Xmpp.Types import Control.Concurrent.STM.TMVar +import Data.IORef +import Crypto.Random.API +import Network.TLS +import Network.TLS.Extra + mkBackend con = Backend { backendSend = \bs -> void (streamSend con bs) , backendRecv = streamReceive con , backendFlush = streamFlush con @@ -61,20 +65,20 @@ cutBytes n = do starttlsE :: Element starttlsE = Element "{urn:ietf:params:xml:ns:xmpp-tls}starttls" [] [] -exampleParams :: TLS.TLSParams -exampleParams = TLS.defaultParamsClient - { pConnectVersion = TLS.TLS10 - , pAllowedVersions = [TLS.SSL3, TLS.TLS10, TLS.TLS11] - , pCiphers = [TLS.cipher_AES128_SHA1] - , pCompressions = [TLS.nullCompression] +exampleParams :: TLSParams +exampleParams = defaultParamsClient + { pConnectVersion = TLS10 + , pAllowedVersions = [SSL3, TLS10, TLS11] + , pCiphers = [cipher_AES128_SHA1] + , pCompressions = [nullCompression] , pUseSecureRenegotiation = False -- No renegotiation , onCertificatesRecv = \_certificate -> - return TLS.CertificateUsageAccept + return CertificateUsageAccept } -- Pushes ", waits for "", performs the TLS handshake, and -- restarts the stream. -startTls :: TLS.TLSParams -> TMVar Stream -> IO (Either XmppFailure ()) +startTls :: TLSParams -> TMVar Stream -> IO (Either XmppFailure ()) startTls params con = Ex.handle (return . Left . TlsError) . flip withStream con . runErrorT $ do @@ -92,7 +96,7 @@ startTls params con = Ex.handle (return . Left . TlsError) Left e -> return $ Left e Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) -> return $ Right () Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> return $ Left XmppOtherFailure - (raw, _snk, psh, read, ctx) <- lift $ TLS.tlsinit debug params (mkBackend con) + (raw, _snk, psh, read, ctx) <- lift $ tlsinit debug params (mkBackend con) let newHand = StreamHandle { streamSend = catchPush . psh , streamReceive = read , streamFlush = contextFlush ctx @@ -102,3 +106,57 @@ startTls params con = Ex.handle (return . Left . TlsError) either (lift . Ex.throwIO) return =<< lift restartStream modify (\s -> s{streamState = Secured}) return () + +client params gen backend = do + contextNew backend params gen + +defaultParams = defaultParamsClient + +tlsinit :: (MonadIO m, MonadIO m1) => + Bool + -> TLSParams + -> Backend + -> m ( Source m1 BS.ByteString + , Sink BS.ByteString m1 () + , BS.ByteString -> IO () + , Int -> m1 BS.ByteString + , Context + ) +tlsinit debug tlsParams backend = do + when debug . liftIO $ putStrLn "TLS with debug mode enabled" + gen <- liftIO $ getSystemRandomGen -- TODO: Find better random source? + con <- client tlsParams gen backend + handshake con + let src = forever $ do + dt <- liftIO $ recvData con + when debug (liftIO $ putStr "in: " >> BS.putStrLn dt) + yield dt + let snk = do + d <- await + case d of + Nothing -> return () + Just x -> do + sendData con (BL.fromChunks [x]) + when debug (liftIO $ putStr "out: " >> BS.putStrLn x) + snk + read <- liftIO $ mkReadBuffer (recvData con) + return ( src + , snk + , \s -> do + when debug (liftIO $ BS.putStrLn s) + sendData con $ BL.fromChunks [s] + , liftIO . read + , con + ) + +mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString) +mkReadBuffer read = do + buffer <- newIORef BS.empty + let read' n = do + nc <- readIORef buffer + bs <- if BS.null nc then read + else return nc + let (result, rest) = BS.splitAt n bs + writeIORef buffer rest + return result + return read'