diff --git a/source/Network/Xmpp/Lens.hs b/source/Network/Xmpp/Lens.hs index 0454af6..9b0dcd8 100644 --- a/source/Network/Xmpp/Lens.hs +++ b/source/Network/Xmpp/Lens.hs @@ -48,6 +48,16 @@ module Network.Xmpp.Lens , establishSessionL , tlsBehaviourL , tlsParamsL + -- **** TLS parameters + , clientServerIdentificationL + , tlsServerIdentificationL + , clientSupportedL + , supportedCiphersL + , supportedVersionsL + , tlsSupportedCiphersL + , tlsSupportedVersionsL + , clientUseServerNameIndicationL + , tlsUseNameIndicationL -- *** 'SessionConfiguration' , streamConfigurationL , onConnectionClosedL @@ -105,12 +115,13 @@ import qualified Data.Text as Text import Data.Text(Text) import Data.XML.Types(Element) import Network.DNS(ResolvConf) -import Network.TLS (ClientParams) +import Network.TLS as TLS import Network.Xmpp.Concurrent.Types -import Network.Xmpp.IM.Roster.Types import Network.Xmpp.IM.Message import Network.Xmpp.IM.Presence +import Network.Xmpp.IM.Roster.Types import Network.Xmpp.Types +import qualified Data.ByteString as BS -- | Van-Laarhoven lenses. type Lens a b = Functor f => (b -> f b) -> a -> f a @@ -121,7 +132,6 @@ type Traversal a b = Applicative f => (b -> f b) -> a -> f a -- Accessors --------------- - -- | Read the value the lens is pointing to view :: Lens a b -> a -> b view l x = getConst $ l Const x @@ -390,11 +400,36 @@ tlsBehaviourL :: Lens StreamConfiguration TlsBehaviour tlsBehaviourL inj sc@StreamConfiguration{tlsBehaviour = x} = (\x' -> sc{tlsBehaviour = x'}) <$> inj x + tlsParamsL :: Lens StreamConfiguration ClientParams tlsParamsL inj sc@StreamConfiguration{tlsParams = x} = (\x' -> sc{tlsParams = x'}) <$> inj x --- SessioConfiguration +-- TLS parameters +----------------- + +clientServerIdentificationL :: Lens ClientParams (String, BS.ByteString) +clientServerIdentificationL inj cp@ClientParams{clientServerIdentification = x} + = (\x' -> cp{clientServerIdentification = x'}) <$> inj x + +clientSupportedL :: Lens ClientParams Supported +clientSupportedL inj cp@ClientParams{clientSupported = x} + = (\x' -> cp{clientSupported = x'}) <$> inj x + +clientUseServerNameIndicationL :: Lens ClientParams Bool +clientUseServerNameIndicationL inj + cp@ClientParams{clientUseServerNameIndication = x} + = (\x' -> cp{clientUseServerNameIndication = x'}) <$> inj x + +supportedCiphersL :: Lens Supported [Cipher] +supportedCiphersL inj s@Supported{supportedCiphers = x} + = (\x' -> s{supportedCiphers = x'}) <$> inj x + +supportedVersionsL :: Lens Supported [TLS.Version] +supportedVersionsL inj s@Supported{supportedVersions = x} + = (\x' -> s{supportedVersions = x'}) <$> inj x + +-- SessionConfiguration ----------------------- streamConfigurationL :: Lens SessionConfiguration StreamConfiguration streamConfigurationL inj sc@SessionConfiguration{sessionStreamConfiguration = x} @@ -416,6 +451,29 @@ pluginsL :: Lens SessionConfiguration [Plugin] pluginsL inj sc@SessionConfiguration{plugins = x} = (\x' -> sc{plugins = x'}) <$> inj x +-- | Access clientServerIdentification inside tlsParams inside streamConfiguration +tlsServerIdentificationL :: Lens SessionConfiguration (String, BS.ByteString) +tlsServerIdentificationL = streamConfigurationL + . tlsParamsL + . clientServerIdentificationL + +-- | Access clientUseServerNameIndication inside tlsParams +tlsUseNameIndicationL :: Lens SessionConfiguration Bool +tlsUseNameIndicationL = streamConfigurationL + . tlsParamsL + . clientUseServerNameIndicationL + +-- | Access supportedCiphers inside clientSupported inside tlsParams +tlsSupportedCiphersL :: Lens SessionConfiguration [Cipher] +tlsSupportedCiphersL = streamConfigurationL + . tlsParamsL . clientSupportedL . supportedCiphersL + +-- | Access supportedVersions inside clientSupported inside tlsParams +tlsSupportedVersionsL :: Lens SessionConfiguration [TLS.Version] +tlsSupportedVersionsL = streamConfigurationL + . tlsParamsL . clientSupportedL . supportedVersionsL + + -- Roster ------------------ diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index ebd19b8..daa949f 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -36,7 +36,6 @@ import Data.XML.Pickle import Data.XML.Types import qualified GHC.IO.Exception as GIE import Network -import Network.TLS import Network.DNS hiding (encode, lookup) import Network.Xmpp.Marshal import Network.Xmpp.Types @@ -47,6 +46,7 @@ import System.Random (randomRIO) import Text.XML.Stream.Parse as XP import Network.Xmpp.Utilities +import qualified Network.Xmpp.Lens as L -- "readMaybe" definition, as readMaybe is not introduced in the `base' package -- until version 4.6. @@ -518,10 +518,9 @@ createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream createStream realm config = do result <- connect realm config case result of - Just (host, hand) -> ErrorT $ do + Just hand -> ErrorT $ do debugM "Pontarius.Xmpp" "Acquired handle." debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle." - debugM "Pontarius.Xmpp" $ "Setting TLS expected host to " ++ show host eSource <- liftIO . bufferSrc $ (sourceStreamHandle hand $= logConduit) $= XP.parseBytes def @@ -535,7 +534,7 @@ createStream realm config = do , streamId = Nothing , streamLang = Nothing , streamJid = Nothing - , streamConfiguration = setCertificateHost host config + , streamConfiguration = maybeSetTlsHost realm config } stream' <- mkStream stream return $ Right stream' @@ -548,17 +547,14 @@ createStream realm config = do liftIO . debugM "Pontarius.Xmpp" $ "In: " ++ (BSC8.unpack d) ++ "." return d - setCertificateHost host conf = - conf{tlsParams = - (tlsParams conf){clientServerIdentification = - case clientServerIdentification(tlsParams conf) of - (_, blob) -> (host, blob)}} - + tlsIdentL = L.tlsParamsL . L.clientServerIdentificationL + updateHost host ("", _) = (host, "") + updateHost _ hst = hst + maybeSetTlsHost host = L.modify tlsIdentL (updateHost host) -- Connects using the specified method. Returns the Handle acquired, if any. -connect :: HostName - -> StreamConfiguration - -> ErrorT XmppFailure IO (Maybe (HostName, StreamHandle)) +connect :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO + (Maybe StreamHandle) connect realm config = do case connectionDetails config of UseHost host port -> lift $ do @@ -568,26 +564,24 @@ connect realm config = do Nothing -> return Nothing Just h' -> do liftIO $ hSetBuffering h' NoBuffering - return . Just $ (host, handleToStreamHandle h') + return . Just $ handleToStreamHandle h' UseSrv host -> do h <- connectSrv (resolvConf config) host case h of Nothing -> return Nothing - Just (hn, h') -> do + Just h' -> do liftIO $ hSetBuffering h' NoBuffering - return . Just $ (hn, handleToStreamHandle h') + return . Just $ handleToStreamHandle h' UseRealm -> do h <- connectSrv (resolvConf config) realm case h of Nothing -> return Nothing - Just (hn, h') -> do + Just h' -> do liftIO $ hSetBuffering h' NoBuffering - return $ Just (hn, handleToStreamHandle h') + return . Just $ handleToStreamHandle h' UseConnection mkC -> Just <$> mkC -connectSrv :: ResolvConf - -> String - -> ErrorT XmppFailure IO (Maybe (HostName, Handle)) +connectSrv :: ResolvConf -> String -> ErrorT XmppFailure IO (Maybe Handle) connectSrv config host = do case checkHostName (Text.pack host) of Just host' -> do @@ -598,9 +592,8 @@ connectSrv config host = do Nothing -> do lift $ debugM "Pontarius.Xmpp" "No SRV records, using fallback process." - h <- lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host) - 5222 - return $ (\h' -> (host, h')) <$> h + lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host) + 5222 Just srvRecords' -> do lift $ debugM "Pontarius.Xmpp" "SRV records found, performing A/AAAA lookups." @@ -681,17 +674,12 @@ resolvAndConnectTcp resolvSeed domain port = do -- Tries `resolvAndConnectTcp' for every SRV record, stopping if a handle is -- acquired. -resolvSrvsAndConnectTcp :: ResolvSeed - -> [(Domain, Int)] - -> IO (Maybe (HostName, Handle)) +resolvSrvsAndConnectTcp :: ResolvSeed -> [(Domain, Int)] -> IO (Maybe Handle) resolvSrvsAndConnectTcp _ [] = return Nothing resolvSrvsAndConnectTcp resolvSeed ((domain, port):remaining) = do result <- resolvAndConnectTcp resolvSeed domain port case result of - -- The last character of the target is always a dot in SRV records, so - -- we drop it. (Presumably the dns library should do that?) - Just handle -> return $ Just ( init . Text.unpack $ Text.decodeUtf8 $ domain - , handle) + Just handle -> return $ Just handle Nothing -> resolvSrvsAndConnectTcp resolvSeed remaining diff --git a/source/Network/Xmpp/Tls.hs b/source/Network/Xmpp/Tls.hs index 15f2524..3926c84 100644 --- a/source/Network/Xmpp/Tls.hs +++ b/source/Network/Xmpp/Tls.hs @@ -173,9 +173,9 @@ connectTls :: ResolvConf -- ^ Resolv conf to use (try 'defaultResolvConf' as a -> ClientParams -- ^ TLS parameters to use when securing the connection -> String -- ^ Host to use when connecting (will be resolved -- using SRV records) - -> ErrorT XmppFailure IO (String, StreamHandle) + -> ErrorT XmppFailure IO StreamHandle connectTls config params host = do - (hn, h) <- connectSrv config host >>= \h' -> case h' of + h <- connectSrv config host >>= \h' -> case h' of Nothing -> throwError TcpConnectionFailure Just h'' -> return h'' let hand = handleToStreamHandle h @@ -185,13 +185,11 @@ connectTls config params host = do csi -> csi } (_raw, _snk, psh, recv, ctx) <- tlsinit params' $ mkBackend hand - return $ ( hn - , StreamHandle { streamSend = catchPush . psh - , streamReceive = wrapExceptions . recv - , streamFlush = contextFlush ctx - , streamClose = bye ctx >> streamClose hand - } - ) + return StreamHandle{ streamSend = catchPush . psh + , streamReceive = wrapExceptions . recv + , streamFlush = contextFlush ctx + , streamClose = bye ctx >> streamClose hand + } wrapExceptions :: IO a -> IO (Either XmppFailure a) wrapExceptions f = Ex.catches (liftM Right $ f) diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs index 6db6e6f..541e44a 100644 --- a/source/Network/Xmpp/Types.hs +++ b/source/Network/Xmpp/Types.hs @@ -1158,7 +1158,7 @@ data ConnectionDetails = UseRealm -- ^ Use realm to resolv host. This is the -- default. | UseSrv HostName -- ^ Use this hostname for a SRV lookup | UseHost HostName PortID -- ^ Use specified host - | UseConnection (ErrorT XmppFailure IO (HostName, StreamHandle)) + | UseConnection (ErrorT XmppFailure IO StreamHandle) -- ^ Use a custom method to create a StreamHandle. This -- will also be used by reconnect. For example, to -- establish TLS before starting the stream as done by @@ -1223,6 +1223,7 @@ instance Default StreamConfiguration where , establishSession = True , tlsBehaviour = PreferTls , tlsParams = xmppDefaultParams + , tlsOverrideHostname = Nothing } -- | How the client should behave in regards to TLS. diff --git a/tests/Run/Google.hs b/tests/Run/Google.hs index 5d12567..398e981 100644 --- a/tests/Run/Google.hs +++ b/tests/Run/Google.hs @@ -6,6 +6,7 @@ module Run.Google where import qualified Data.Configurator as Conf import Network.Xmpp +import Network.Xmpp.Lens import System.Exit import System.Log.Logger import Test.HUnit @@ -13,9 +14,7 @@ import Network.TLS import Run.Config -xmppConf = def {sessionStreamConfiguration = - def{tlsParams = (tlsParams def){clientUseServerNameIndication = False}} - } +xmppConf = set tlsServerIdentificationL ("talk.google.com", "") $ def connectGoogle = do conf <- loadConfig