diff --git a/pontarius-xmpp.cabal b/pontarius-xmpp.cabal index 625b60f..7eb9bbb 100644 --- a/pontarius-xmpp.cabal +++ b/pontarius-xmpp.cabal @@ -72,6 +72,7 @@ Library , xml-types >=0.3.1 , xml-conduit >=1.1.0.7 , xml-picklers >=0.3.3 + , x509-system >=1.4 If impl(ghc ==7.0.1) { Build-Depends: bytestring >=0.9.1.9 && <=0.9.2.1 diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index 7f0723f..9c35131 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -36,6 +36,7 @@ import Data.XML.Pickle import Data.XML.Types import qualified GHC.IO.Exception as GIE import Network +import Network.TLS import Network.DNS hiding (encode, lookup) import Network.Xmpp.Marshal import Network.Xmpp.Types @@ -517,7 +518,7 @@ createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream createStream realm config = do result <- connect realm config case result of - Just hand -> ErrorT $ do + Just (host, hand) -> ErrorT $ do debugM "Pontarius.Xmpp" "Acquired handle." debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle." eSource <- liftIO . bufferSrc $ @@ -533,7 +534,7 @@ createStream realm config = do , streamId = Nothing , streamLang = Nothing , streamJid = Nothing - , streamConfiguration = config + , streamConfiguration = setCertificateHost host config } stream' <- mkStream stream return $ Right stream' @@ -546,10 +547,17 @@ createStream realm config = do liftIO . debugM "Pontarius.Xmpp" $ "In: " ++ (BSC8.unpack d) ++ "." return d + setCertificateHost host conf = + conf{tlsParams = + (tlsParams conf){clientServerIdentification = + case clientServerIdentification(tlsParams conf) of + (_, blob) -> (host, blob)}} + -- Connects using the specified method. Returns the Handle acquired, if any. -connect :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO - (Maybe StreamHandle) +connect :: HostName + -> StreamConfiguration + -> ErrorT XmppFailure IO (Maybe (HostName, StreamHandle)) connect realm config = do case connectionDetails config of UseHost host port -> lift $ do @@ -559,24 +567,26 @@ connect realm config = do Nothing -> return Nothing Just h' -> do liftIO $ hSetBuffering h' NoBuffering - return . Just $ handleToStreamHandle h' + return . Just $ (host, handleToStreamHandle h') UseSrv host -> do h <- connectSrv (resolvConf config) host case h of Nothing -> return Nothing - Just h' -> do + Just (hn, h') -> do liftIO $ hSetBuffering h' NoBuffering - return . Just $ handleToStreamHandle h' + return . Just $ (hn, handleToStreamHandle h') UseRealm -> do h <- connectSrv (resolvConf config) realm case h of Nothing -> return Nothing - Just h' -> do + Just (hn, h') -> do liftIO $ hSetBuffering h' NoBuffering - return . Just $ handleToStreamHandle h' + return $ Just (hn, handleToStreamHandle h') UseConnection mkC -> Just <$> mkC -connectSrv :: ResolvConf -> String -> ErrorT XmppFailure IO (Maybe Handle) +connectSrv :: ResolvConf + -> String + -> ErrorT XmppFailure IO (Maybe (HostName, Handle)) connectSrv config host = do case checkHostName (Text.pack host) of Just host' -> do @@ -587,8 +597,9 @@ connectSrv config host = do Nothing -> do lift $ debugM "Pontarius.Xmpp" "No SRV records, using fallback process." - lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host) - 5222 + h <- lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host) + 5222 + return $ (\h' -> (host, h')) <$> h Just srvRecords' -> do lift $ debugM "Pontarius.Xmpp" "SRV records found, performing A/AAAA lookups." @@ -668,12 +679,17 @@ resolvAndConnectTcp resolvSeed domain port = do -- Tries `resolvAndConnectTcp' for every SRV record, stopping if a handle is -- acquired. -resolvSrvsAndConnectTcp :: ResolvSeed -> [(Domain, Int)] -> IO (Maybe Handle) +resolvSrvsAndConnectTcp :: ResolvSeed + -> [(Domain, Int)] + -> IO (Maybe (HostName, Handle)) resolvSrvsAndConnectTcp _ [] = return Nothing resolvSrvsAndConnectTcp resolvSeed ((domain, port):remaining) = do result <- resolvAndConnectTcp resolvSeed domain port case result of - Just handle -> return $ Just handle + -- The last character of the target is always a dot in SRV records, so + -- we drop it. (Presumably the dns library should do that?) + Just handle -> return $ Just ( init . Text.unpack $ Text.decodeUtf8 $ domain + , handle) Nothing -> resolvSrvsAndConnectTcp resolvSeed remaining diff --git a/source/Network/Xmpp/Tls.hs b/source/Network/Xmpp/Tls.hs index 7854171..2d2f044 100644 --- a/source/Network/Xmpp/Tls.hs +++ b/source/Network/Xmpp/Tls.hs @@ -16,12 +16,14 @@ import qualified Data.ByteString.Char8 as BSC8 import qualified Data.ByteString.Lazy as BL import Data.Conduit import Data.IORef +import Data.Monoid import Data.XML.Types import Network.DNS.Resolver (ResolvConf) import Network.TLS import Network.Xmpp.Stream import Network.Xmpp.Types import System.Log.Logger (debugM, errorM, infoM) +import System.X509 mkBackend :: StreamHandle -> Backend mkBackend con = Backend { backendSend = \bs -> void (streamSend con bs) @@ -54,7 +56,7 @@ tls con = fmap join -- We can have Left values both from exceptions and the . wrapExceptions . flip withStream con . runErrorT $ do - conf <- gets $ streamConfiguration + conf <- gets streamConfiguration sState <- gets streamConnectionState case sState of Plain -> return () @@ -123,7 +125,11 @@ tlsinit :: (MonadIO m, MonadIO m1) => tlsinit params backend = do liftIO $ debugM "Pontarius.Xmpp.Tls" "TLS with debug mode enabled." gen <- liftIO (cprgCreate <$> createEntropyPool :: IO SystemRNG) - con <- client params gen backend + sysCStore <- liftIO getSystemCertificateStore + let params' = params{clientShared = + (clientShared params){ sharedCAStore = + sysCStore <> sharedCAStore (clientShared params)}} + con <- client params' gen backend handshake con let src = forever $ do dt <- liftIO $ recvData con @@ -167,18 +173,20 @@ connectTls :: ResolvConf -- ^ Resolv conf to use (try 'defaultResolvConf' as a -> ClientParams -- ^ TLS parameters to use when securing the connection -> String -- ^ Host to use when connecting (will be resolved -- using SRV records) - -> ErrorT XmppFailure IO StreamHandle + -> ErrorT XmppFailure IO (String, StreamHandle) connectTls config params host = do - h <- connectSrv config host >>= \h' -> case h' of + (hn, h) <- connectSrv config host >>= \h' -> case h' of Nothing -> throwError TcpConnectionFailure Just h'' -> return h'' let hand = handleToStreamHandle h (_raw, _snk, psh, recv, ctx) <- tlsinit params $ mkBackend hand - return $ StreamHandle { streamSend = catchPush . psh - , streamReceive = wrapExceptions . recv - , streamFlush = contextFlush ctx - , streamClose = bye ctx >> streamClose hand - } + return $ ( hn + , StreamHandle { streamSend = catchPush . psh + , streamReceive = wrapExceptions . recv + , streamFlush = contextFlush ctx + , streamClose = bye ctx >> streamClose hand + } + ) wrapExceptions :: IO a -> IO (Either XmppFailure a) wrapExceptions f = Ex.catches (liftM Right $ f) diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs index 1d9e559..6fa1045 100644 --- a/source/Network/Xmpp/Types.hs +++ b/source/Network/Xmpp/Types.hs @@ -1158,12 +1158,15 @@ data ConnectionDetails = UseRealm -- ^ Use realm to resolv host. This is the -- default. | UseSrv HostName -- ^ Use this hostname for a SRV lookup | UseHost HostName PortID -- ^ Use specified host - | UseConnection (ErrorT XmppFailure IO StreamHandle) - -- ^ Use custom method to create a StreamHandle. This + | UseConnection (ErrorT XmppFailure IO (HostName, StreamHandle)) + -- ^ Use a custom method to create a StreamHandle. This -- will also be used by reconnect. For example, to -- establish TLS before starting the stream as done by -- GCM, see 'connectTls'. You can also return an - -- already established connection. + -- already established connection. This method should + -- also return a hostname that is used for TLS + -- signature verification. If startTLS is not used it + -- can be left empty -- | Configuration settings related to the stream. data StreamConfiguration = @@ -1201,6 +1204,7 @@ xmppDefaultParams = (defaultParamsClient "" BS.empty) , cipher_AES128_SHA1 ] } + , clientUseServerNameIndication = True } instance Default StreamConfiguration where diff --git a/tests/Run.hs b/tests/Run.hs index 32b110f..fcc281e 100644 --- a/tests/Run.hs +++ b/tests/Run.hs @@ -16,13 +16,14 @@ import qualified Data.Text as Text import Network import Network.Xmpp import System.Directory +import System.Exit import System.FilePath import System.Log.Logger import System.Timeout import Test.HUnit import Test.Hspec.Expectations -import Run.Payload +import Run.Payload xmppConfig :: ConnectionDetails -> SessionConfiguration xmppConfig det = def{sessionStreamConfiguration @@ -82,10 +83,20 @@ main = void $ do Just "emergency" -> return EMERGENCY Just e -> error $ "Log level " ++ (Text.unpack e) ++ " unknown" updateGlobalLogger "Pontarius.Xmpp" $ setLevel loglevel - Right sess1 <- session realm (simpleAuth uname1 pwd1) + mbSess1 <- session realm (simpleAuth uname1 pwd1) ((xmppConfig conDetails)) - Right sess2 <- session realm (simpleAuth uname2 pwd2) + sess1 <- case mbSess1 of + Left e -> do + assertFailure $ "session 1 could not be initialized" ++ show e + exitFailure + Right r -> return r + mbSess2 <- session realm (simpleAuth uname2 pwd2) ((xmppConfig conDetails)) + sess2 <- case mbSess2 of + Left e -> do + assertFailure $ "session 2 could not be initialized" ++ show e + exitFailure + Right r -> return r Just jid1 <- getJid sess1 Just jid2 <- getJid sess2 _ <- sendPresence presenceOnline sess1