diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index d12dae4..0c7f908 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -50,7 +50,7 @@ import Data.Ord import Data.Maybe import Data.List import Data.IP - +import System.Random -- import Text.XML.Stream.Elements @@ -555,10 +555,8 @@ srvLookup realm resolvSeed = ErrorT $ do Just srvResult -> do debugM "Pontarius.Xmpp" $ "SRV result: " ++ (show srvResult) -- Get [(Domain, PortNumber)] of SRV request, if any. - return $ Just $ Prelude.map (\(_, _, port, domain) -> (domain, fromIntegral port)) $ - sortBy (comparing $ \(prio, _weight, _, _) -> prio) srvResult - -- TODO: Perform the `Weight' probability calculations of - -- . + srvResult' <- orderSrvResult srvResult + return $ Just $ Prelude.map (\(_, _, port, domain) -> (domain, fromIntegral port)) srvResult' -- The service is not available at this domain. -- Sorts the records based on the priority value. Just [(_, _, _, ".")] -> do @@ -570,6 +568,44 @@ srvLookup realm resolvSeed = ErrorT $ do case result of Right result' -> return $ Right result' Left e -> return $ Left $ XmppIOException e + where + -- This function orders the SRV result in accordance with RFC + -- 2782. It sorts the SRV results in order of priority, and then + -- uses a random process to order the records with the same + -- priority based on their weight. + orderSrvResult :: [(Int, Int, Int, Domain)] -> IO [(Int, Int, Int, Domain)] + orderSrvResult srvResult = do + -- Order the result set by priority. + let srvResult' = sortBy (comparing (\(priority, _, _, _) -> priority)) srvResult + -- Group elements in sublists based on their priority. The + -- type is `[[(Int, Int, Int, Domain)]]'. + let srvResult'' = Data.List.groupBy (\(priority, _, _, _) (priority', _, _, _) -> priority == priority') srvResult' :: [[(Int, Int, Int, Domain)]] + -- For each sublist, put records with a weight of zero first. + let srvResult''' = Data.List.map (\sublist -> let (a, b) = partition (\(_, weight, _, _) -> weight == 0) sublist in Data.List.concat [a, b]) srvResult'' + -- Order each sublist. + srvResult'''' <- mapM orderSublist srvResult''' + -- Concatinated the results. + return $ Data.List.concat srvResult'''' + where + orderSublist :: [(Int, Int, Int, Domain)] -> IO [(Int, Int, Int, Domain)] + orderSublist [] = return [] + orderSublist sublist = do + -- Compute the running sum, as well as the total sum of + -- the sublist. Add the running sum to the SRV tuples. + let (total, sublist') = Data.List.mapAccumL (\total (priority, weight, port, domain) -> (total + weight, (priority, weight, port, domain, total + weight))) 0 sublist + -- Choose a random number between 0 and the total sum + -- (inclusive). + randomNumber <- randomRIO (0, total) + -- Select the first record with its running sum greater + -- than or equal to the random number. + let (beginning, ((priority, weight, port, domain, _):end)) = Data.List.break (\(_, _, _, _, running) -> randomNumber <= running) sublist' + -- Remove the running total number from the remaining + -- elements. + let sublist'' = Data.List.map (\(priority, weight, port, domain, _) -> (priority, weight, port, domain)) (Data.List.concat [beginning, end]) + -- Repeat the ordering procedure on the remaining + -- elements. + tail <- orderSublist sublist'' + return $ ((priority, weight, port, domain):tail) -- Closes the connection and updates the XmppConMonad Stream state. -- killStream :: TMVar Stream -> IO (Either ExL.SomeException ())