Browse Source

Broker: use tcp sockets instead of 0mq

master
Denis Tereshkin 11 months ago
parent
commit
ca36dcc0df
  1. 8
      libatrade.cabal
  2. 363
      src/ATrade/Broker/Client.hs
  3. 99
      src/ATrade/Broker/Protocol.hs
  4. 301
      src/ATrade/Broker/Server.hs

8
libatrade.cabal

@ -30,11 +30,12 @@ library
, ATrade.Broker.TradeSinks.ZMQTradeSink , ATrade.Broker.TradeSinks.ZMQTradeSink
, ATrade.Broker.TradeSinks.GotifyTradeSink , ATrade.Broker.TradeSinks.GotifyTradeSink
, ATrade.Util , ATrade.Util
, ATrade
other-modules: Paths_libatrade other-modules: Paths_libatrade
, ATrade.Utils.MessagePipe
build-depends: base >= 4.7 && < 5 build-depends: base >= 4.7 && < 5
, BoundedChan , BoundedChan
, aeson , aeson
, attoparsec
, bimap , bimap
, binary , binary
, bytestring , bytestring
@ -59,6 +60,8 @@ library
, co-log , co-log
, ansi-terminal , ansi-terminal
, net-mqtt , net-mqtt
, network
, network-run
default-language: Haskell2010 default-language: Haskell2010
@ -80,6 +83,7 @@ test-suite libatrade-test
, tuple , tuple
, time , time
, aeson , aeson
, attoparsec
, text , text
, BoundedChan , BoundedChan
, zeromq4-haskell , zeromq4-haskell
@ -88,6 +92,8 @@ test-suite libatrade-test
, monad-loops , monad-loops
, uuid , uuid
, stm , stm
, network
, network-run
ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wincomplete-patterns -Wno-orphans ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wincomplete-patterns -Wno-orphans
default-language: Haskell2010 default-language: Haskell2010
other-modules: ArbitraryInstances other-modules: ArbitraryInstances

363
src/ATrade/Broker/Client.hs

@ -1,4 +1,6 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module ATrade.Broker.Client ( module ATrade.Broker.Client (
startBrokerClient, startBrokerClient,
@ -14,14 +16,19 @@ import ATrade.Broker.Protocol (BrokerServerRequest (..),
BrokerServerResponse (..), BrokerServerResponse (..),
ClientIdentity, Notification, ClientIdentity, Notification,
NotificationSqnum (NotificationSqnum), NotificationSqnum (NotificationSqnum),
RequestSqnum, RequestId (..),
getNotificationSqnum, getNotificationSqnum,
getRequestId,
getResponseRequestId,
nextSqnum) nextSqnum)
import ATrade.Logging (Message, import ATrade.Logging (Message,
Severity (Debug, Info, Warning), Severity (Debug, Info, Warning),
logWith) logWith)
import ATrade.Types (ClientSecurityParams (cspCertificate, cspServerCertificate), import ATrade.Types (ClientSecurityParams (cspCertificate, cspServerCertificate),
Order, OrderId) Order, OrderId)
import ATrade.Util (atomicMapIORef)
import ATrade.Utils.MessagePipe (MessagePipe, emptyMessagePipe,
getMessages, push)
import Colog (LogAction) import Colog (LogAction)
import Control.Concurrent (MVar, ThreadId, forkIO, import Control.Concurrent (MVar, ThreadId, forkIO,
killThread, newEmptyMVar, killThread, newEmptyMVar,
@ -29,12 +36,14 @@ import Control.Concurrent (MVar, ThreadId, forkIO,
threadDelay, tryReadMVar, threadDelay, tryReadMVar,
yield) yield)
import Control.Concurrent.BoundedChan () import Control.Concurrent.BoundedChan ()
import Control.Concurrent.MVar () import Control.Concurrent.MVar (tryPutMVar)
import Control.Exception (SomeException, finally, handle, import Control.Exception (SomeException, bracket, catch,
throwIO) finally, handle, throwIO)
import Control.Monad (forM_, when) import Control.Monad (forM_, forever, void, when)
import Control.Monad.Loops (andM, whileM_) import Control.Monad.Loops (andM, whileM_)
import Data.Aeson (decode, encode) import Data.Aeson (decode, encode)
import Data.Attoparsec.Text (char, decimal, maybeResult,
parseOnly)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import Data.Int (Int64) import Data.Int (Int64)
@ -43,25 +52,22 @@ import Data.IORef (IORef, atomicModifyIORef',
readIORef, writeIORef) readIORef, writeIORef)
import qualified Data.List as L import qualified Data.List as L
import Data.List.NonEmpty () import Data.List.NonEmpty ()
import Data.Maybe (isNothing) import Data.Maybe (isNothing, mapMaybe)
import qualified Data.Text as T import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8) import Data.Text.Encoding (decodeUtf8)
import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as TL
import Data.Time (UTCTime, diffUTCTime,
getCurrentTime)
import Language.Haskell.Printf
import Network.Socket (Family (AF_INET),
SockAddr (SockAddrInet),
Socket, SocketType (Stream),
connect, defaultProtocol,
socket, tupleToHostAddress)
import Network.Socket.ByteString (recv, sendAll)
import Safe (lastMay) import Safe (lastMay)
import System.Timeout (timeout) import System.Timeout (timeout)
import System.ZMQ4 (Context, Event (In),
Poll (Sock), Req (Req),
Sub (Sub), Switch (On),
connect, poll, receive,
receiveMulti, restrict, send,
setLinger, setReqCorrelate,
setReqRelaxed, setTcpKeepAlive,
setTcpKeepAliveCount,
setTcpKeepAliveIdle,
setTcpKeepAliveInterval,
subscribe, withSocket)
import System.ZMQ4.ZAP (zapApplyCertificate,
zapSetServerCertificate)
type NotificationCallback = Notification -> IO () type NotificationCallback = Notification -> IO ()
@ -72,178 +78,130 @@ data BrokerClientHandle = BrokerClientHandle {
submitOrder :: Order -> IO (Either T.Text ()), submitOrder :: Order -> IO (Either T.Text ()),
cancelOrder :: OrderId -> IO (Either T.Text ()), cancelOrder :: OrderId -> IO (Either T.Text ()),
getNotifications :: IO (Either T.Text [Notification]), getNotifications :: IO (Either T.Text [Notification]),
cmdVar :: MVar (BrokerServerRequest, MVar BrokerServerResponse), cmdVar :: MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime),
lastKnownNotificationRef :: IORef NotificationSqnum, lastKnownNotificationRef :: IORef NotificationSqnum,
notificationCallback :: [NotificationCallback], notificationCallback :: [NotificationCallback],
notificationThreadId :: ThreadId notificationThreadId :: ThreadId
} }
brokerClientThread :: B.ByteString -> data BrokerClientEvent = IncomingResponse BrokerServerResponse
Context -> | IncomingNotification Notification
brokerClientThread :: T.Text ->
T.Text -> T.Text ->
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
MVar () -> MVar () ->
MVar () -> MVar () ->
ClientSecurityParams -> [NotificationCallback] ->
LogAction IO Message -> LogAction IO Message ->
IO () IO ()
brokerClientThread socketIdentity ctx ep cmd comp killMv secParams logger = finally brokerClientThread' cleanup brokerClientThread socketIdentity ep cmd comp killMv notificationCallbacks logger = finally brokerClientThread' cleanup
where where
log = logWith logger log = logWith logger
cleanup = log Info "Broker.Client" "Quitting broker client thread" >> putMVar comp () cleanup = log Info "Broker.Client" "Quitting broker client thread" >> putMVar comp ()
brokerClientThread' = whileM_ (isNothing <$> tryReadMVar killMv) $ do brokerClientThread' :: IO ()
log Debug "Broker.Client" "Starting event loop" brokerClientThread' = do
handle (\e -> do pendingResp <- newEmptyMVar
log Warning "Broker.Client" $ "Broker client: exception: " <> (T.pack . show) (e :: SomeException) <> "; isZMQ: " <> (T.pack . show) (isZMQError e) pipeRef <- newIORef emptyMessagePipe
if isZMQError e case parseHostAndPort ep of
then do Right (host, port) -> forever $ do
log Debug "Broker.Client" "Rethrowing exception" clientSocket <- socket AF_INET Stream defaultProtocol
throwIO e flip catch (\(_ :: SomeException) -> log Warning "Broker.Client" "Connection error") $ forever $ do
else do connect clientSocket $ SockAddrInet (fromIntegral port) host
return ()) $ withSocket ctx Req (\sock -> do sendAll clientSocket $ B.snoc (BL.toStrict $ encode (RequestSetClientIdentity (RequestId 0) socketIdentity)) 0
setLinger (restrict 0) sock bracket (forkIO $ sendThread cmd clientSocket pendingResp) killThread $ \_ -> do
setReqCorrelate True sock isTimeout <- newIORef False
setReqRelaxed True sock whileM_ (andM [isNothing <$> tryReadMVar killMv, not <$> readIORef isTimeout]) $ do
maybeRawData <- timeout 1000000 $ recv clientSocket 4096
case cspCertificate secParams of case maybeRawData of
Just clientCert -> zapApplyCertificate clientCert sock Just rawData -> do
Nothing -> return () if B.length rawData > 0
case cspServerCertificate secParams of then do
Just serverCert -> zapSetServerCertificate serverCert sock atomicMapIORef pipeRef (push rawData)
Nothing -> return () messages <- atomicModifyIORef' pipeRef getMessages
let parsed = mapMaybe decodeEvent messages
mapM_ (handleMessage pendingResp) parsed
else writeIORef isTimeout True
Nothing -> do
maybePending <- tryReadMVar pendingResp
case maybePending of
Just (req, respVar, timestamp) -> do
now <- getCurrentTime
when (now `diffUTCTime` timestamp > 5.0) $ do
log Warning "Broker.Client" $ TL.toStrict $ [t|Request timeout: %?|] req
void $ takeMVar pendingResp
putMVar respVar $ ResponseError (getRequestId req) "Timeout"
_ -> pure ()
log Debug "Broker.Client" "Recv thread done"
threadDelay 1000000
Left err -> log Warning "Broker.Client" $ "Error: " <> (T.pack . show) err
connect sock $ T.unpack ep
log Debug "Broker.Client" "Connected"
isTimeout <- newIORef False
whileM_ (andM [isNothing <$> tryReadMVar killMv, not <$> readIORef isTimeout]) $ do sendThread cmdvar sock pendingResp = forever $ do
(request, resp) <- takeMVar cmd (req, respVar, timestamp) <- takeMVar cmdvar
send sock [] (BL.toStrict $ encode request) putMVar pendingResp (req, respVar, timestamp)
incomingMessage <- timeout 5000000 $ receive sock let json = encode req
case incomingMessage of log Debug "Broker.Client" $ T.pack $ "sendThread: sending " <> show json
Just msg -> case decode . BL.fromStrict $ msg of sendAll sock $ BL.toStrict $ BL.snoc json 0
Just response -> putMVar resp response
Nothing -> putMVar resp (ResponseError "Unable to decode response")
Nothing -> do
putMVar resp (ResponseError "Response timeout")
writeIORef isTimeout True
threadDelay 1000000)
isZMQError e = "ZMQError" `L.isPrefixOf` show e
notificationThread :: ClientIdentity -> decodeEvent :: B.ByteString -> Maybe BrokerClientEvent
[NotificationCallback] -> decodeEvent raw = case decode $ BL.fromStrict raw :: Maybe Notification of
Context -> Just notif -> Just $ IncomingNotification notif
T.Text -> Nothing -> case decode $ BL.fromStrict raw :: Maybe BrokerServerResponse of
IORef RequestSqnum -> Just response -> Just $ IncomingResponse response
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> Nothing -> Nothing
MVar () ->
ClientSecurityParams ->
LogAction IO Message ->
IORef NotificationSqnum ->
IO ()
notificationThread clientIdentity callbacks ctx ep idCounter cmdVar killMv secParams logger lastKnownNotificationSqnum = flip finally (return ()) $ do
whileM_ (isNothing <$> tryReadMVar killMv) $
withSocket ctx Sub $ \sock -> do
setLinger (restrict 0) sock
case cspCertificate secParams of
Just clientCert -> zapApplyCertificate clientCert sock
Nothing -> return ()
case cspServerCertificate secParams of
Just serverCert -> zapSetServerCertificate serverCert sock
Nothing -> return ()
setTcpKeepAlive On sock
setTcpKeepAliveCount (restrict 5) sock
setTcpKeepAliveIdle (restrict 60) sock
setTcpKeepAliveInterval (restrict 10) sock
connect sock $ T.unpack ep
log Debug "Broker.Client" $ "Subscribing: [" <> clientIdentity <> "]"
subscribe sock $ T.encodeUtf8 clientIdentity
initialSqnum <- requestCurrentSqnum cmdVar idCounter clientIdentity handleMessage :: MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) -> BrokerClientEvent -> IO ()
handleMessage respVar (IncomingResponse resp) = do
(req, respVar, _) <- takeMVar respVar
if getRequestId req == getResponseRequestId resp
then putMVar respVar resp
else do
log Warning "Broker.Client" $ TL.toStrict $ [t|Request ID mismatch: %?/%?|] (getRequestId req) (getResponseRequestId resp)
putMVar respVar (ResponseError (getRequestId req) "Request ID mismatch")
handleMessage _ (IncomingNotification notif) = callNotificationCallbacks notif
log Debug "Broker.Client" $ "Got current sqnum: " <> (T.pack . show) initialSqnum callNotificationCallbacks notif = mapM_ (\cb -> cb notif) notificationCallbacks
notifSqnumRef <- newIORef initialSqnum parseHostAndPort = parseOnly endpointParser
whileM_ (isNothing <$> tryReadMVar killMv) $ do
evs <- poll 5000 [Sock sock [In] Nothing]
if null . L.head $ evs
then do
respVar <- newEmptyMVar
sqnum <- nextId idCounter
notifSqnum <- readIORef notifSqnumRef
putMVar cmdVar (RequestNotifications sqnum clientIdentity notifSqnum, respVar)
resp <- takeMVar respVar
case resp of
(ResponseNotifications ns) -> do
forM_ ns $ \notif -> do
lastSqnum <- readIORef notifSqnumRef
when (getNotificationSqnum notif >= lastSqnum) $ do
forM_ callbacks $ \c -> c notif
atomicWriteIORef notifSqnumRef (nextSqnum lastSqnum)
(ResponseError msg) -> log Warning "Broker.Client" $ "ResponseError: " <> msg
_ -> log Warning "Broker.Client" "Unknown error when requesting notifications"
else do
msg <- receiveMulti sock
case msg of
[_, payload] -> case decode (BL.fromStrict payload) of
Just notification -> do
currentSqnum <- readIORef notifSqnumRef
when (getNotificationSqnum notification /= currentSqnum) $ do
log Warning "Broker.Client" $
"Notification sqnum mismatch: " <> (T.pack . show) currentSqnum <> " -> " <> (T.pack . show) (getNotificationSqnum notification)
atomicWriteIORef notifSqnumRef (nextSqnum $ getNotificationSqnum notification)
forM_ callbacks $ \c -> c notification
atomicWriteIORef lastKnownNotificationSqnum currentSqnum
_ -> return ()
_ -> return ()
where
log = logWith logger
requestCurrentSqnum cmdVar idCounter clientIdentity = do
respVar <- newEmptyMVar
sqnum <- nextId idCounter
putMVar cmdVar (RequestCurrentSqnum sqnum clientIdentity, respVar)
resp <- takeMVar respVar
case resp of
(ResponseCurrentSqnum sqnum) -> return sqnum
(ResponseError msg) -> do
log Warning "Broker.Client" $ "ResponseError: " <> msg
return (NotificationSqnum 1)
_ -> do
log Warning "Broker.Client" "Unknown error when requesting notifications"
return (NotificationSqnum 1)
endpointParser = do
b1 <- decimal
void $ char '.'
b2 <- decimal
void $ char '.'
b3 <- decimal
void $ char '.'
b4 <- decimal
void $ char ':'
port <- decimal
pure (tupleToHostAddress (b1, b2, b3, b4), port)
startBrokerClient :: B.ByteString -- ^ Socket Identity startBrokerClient :: T.Text -- ^ Socket Identity
-> Context -- ^ ZeroMQ context -> T.Text -- ^ Broker endpoint
-> T.Text -- ^ Broker endpoing
-> T.Text -- ^ Notification endpoing
-> [NotificationCallback] -- ^ List of notification callbacks -> [NotificationCallback] -- ^ List of notification callbacks
-> ClientSecurityParams -- ^
-> LogAction IO Message -> LogAction IO Message
-> IO BrokerClientHandle -> IO BrokerClientHandle
startBrokerClient socketIdentity ctx endpoint notifEndpoint notificationCallbacks secParams logger = do startBrokerClient socketIdentity endpoint notificationCallbacks logger = do
idCounter <- newIORef 1
compMv <- newEmptyMVar compMv <- newEmptyMVar
killMv <- newEmptyMVar killMv <- newEmptyMVar
cmdVar <- newEmptyMVar :: IO (MVar (BrokerServerRequest, MVar BrokerServerResponse)) cmdVar <- newEmptyMVar :: IO (MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime))
tid <- forkIO (brokerClientThread socketIdentity ctx endpoint cmdVar compMv killMv secParams logger) idCounter <- newIORef 0
tid <- forkIO (brokerClientThread socketIdentity endpoint cmdVar compMv killMv notificationCallbacks logger)
notifSqnumRef <- newIORef (NotificationSqnum 0) notifSqnumRef <- newIORef (NotificationSqnum 0)
lastKnownNotification <- newIORef (NotificationSqnum 0) lastKnownNotification <- newIORef (NotificationSqnum 0)
notifThreadId <- forkIO (notificationThread (T.decodeUtf8 socketIdentity) notificationCallbacks ctx notifEndpoint idCounter cmdVar killMv secParams logger
lastKnownNotification)
return BrokerClientHandle { return BrokerClientHandle {
tid = tid, tid = tid,
completionMvar = compMv, completionMvar = compMv,
killMvar = killMv, killMvar = killMv,
submitOrder = bcSubmitOrder (decodeUtf8 socketIdentity) idCounter cmdVar, submitOrder = bcSubmitOrder socketIdentity idCounter cmdVar logger,
cancelOrder = bcCancelOrder (decodeUtf8 socketIdentity) idCounter cmdVar, cancelOrder = bcCancelOrder socketIdentity idCounter cmdVar logger,
getNotifications = bcGetNotifications (decodeUtf8 socketIdentity) idCounter notifSqnumRef cmdVar lastKnownNotification, getNotifications = bcGetNotifications socketIdentity idCounter notifSqnumRef cmdVar lastKnownNotification logger,
cmdVar = cmdVar, cmdVar = cmdVar,
lastKnownNotificationRef = notifSqnumRef, lastKnownNotificationRef = notifSqnumRef,
notificationCallback = [], notificationCallback = []
notificationThreadId = notifThreadId }
}
stopBrokerClient :: BrokerClientHandle -> IO () stopBrokerClient :: BrokerClientHandle -> IO ()
stopBrokerClient handle = do stopBrokerClient handle = do
@ -256,45 +214,84 @@ stopBrokerClient handle = do
nextId cnt = atomicModifyIORef' cnt (\v -> (v + 1, v)) nextId cnt = atomicModifyIORef' cnt (\v -> (v + 1, v))
bcSubmitOrder :: ClientIdentity -> IORef Int64 -> MVar (BrokerServerRequest, MVar BrokerServerResponse) -> Order -> IO (Either T.Text ()) bcSubmitOrder ::
bcSubmitOrder clientIdentity idCounter cmdVar order = do ClientIdentity ->
IORef Int64 ->
MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
LogAction IO Message ->
Order ->
IO (Either T.Text ())
bcSubmitOrder clientIdentity idCounter cmdVar logger order = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
putMVar cmdVar (RequestSubmitOrder sqnum clientIdentity order, respVar) now <- getCurrentTime
resp <- takeMVar respVar result <- timeout 3000000 $ do
case resp of putMVar cmdVar (RequestSubmitOrder (RequestId sqnum) clientIdentity order, respVar, now)
ResponseOk -> return $ Right () resp <- takeMVar respVar
(ResponseError msg) -> return $ Left msg case resp of
_ -> return $ Left "Unknown error" ResponseOk (RequestId requestId) -> do
if requestId == sqnum
then return $ Right ()
else do
logWith logger Warning "Broker.Client" "SubmitOrder: requestId mismatch"
pure $ Left "requestid mismatch"
(ResponseError (RequestId _) msg) -> return $ Left msg
_ -> return $ Left "Unknown error"
case result of
Just r -> pure r
_ -> pure $ Left "Request timeout"
bcCancelOrder :: ClientIdentity -> IORef RequestSqnum -> MVar (BrokerServerRequest, MVar BrokerServerResponse) -> OrderId -> IO (Either T.Text ()) bcCancelOrder ::
bcCancelOrder clientIdentity idCounter cmdVar orderId = do ClientIdentity ->
IORef Int64 ->
MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
LogAction IO Message ->
OrderId ->
IO (Either T.Text ())
bcCancelOrder clientIdentity idCounter cmdVar logger orderId = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
putMVar cmdVar (RequestCancelOrder sqnum clientIdentity orderId, respVar) now <- getCurrentTime
resp <- takeMVar respVar result <- timeout 3000000 $ do
case resp of putMVar cmdVar (RequestCancelOrder (RequestId sqnum) clientIdentity orderId, respVar, now)
ResponseOk -> return $ Right () resp <- takeMVar respVar
(ResponseError msg) -> return $ Left msg case resp of
_ -> return $ Left "Unknown error" ResponseOk (RequestId requestId) -> do
if requestId == sqnum
then return $ Right ()
else do
logWith logger Warning "Broker.Client" "CancelOrder: requestId mismatch"
pure $ Left "requestid mismatch"
(ResponseError (RequestId _) msg) -> return $ Left msg
_ -> return $ Left "Unknown error"
case result of
Just r -> pure $ r
_ -> pure $ Left "Request timeout"
bcGetNotifications :: ClientIdentity -> bcGetNotifications :: ClientIdentity ->
IORef RequestSqnum -> IORef Int64 ->
IORef NotificationSqnum -> IORef NotificationSqnum ->
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
IORef NotificationSqnum -> IORef NotificationSqnum ->
LogAction IO Message ->
IO (Either T.Text [Notification]) IO (Either T.Text [Notification])
bcGetNotifications clientIdentity idCounter notifSqnumRef cmdVar lastKnownNotification = do bcGetNotifications clientIdentity idCounter notifSqnumRef cmdVar lastKnownNotification logger = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
notifSqnum <- nextSqnum <$> readIORef notifSqnumRef notifSqnum <- nextSqnum <$> readIORef notifSqnumRef
putMVar cmdVar (RequestNotifications sqnum clientIdentity notifSqnum, respVar) now <- getCurrentTime
putMVar cmdVar (RequestNotifications (RequestId sqnum) clientIdentity notifSqnum, respVar, now)
resp <- takeMVar respVar resp <- takeMVar respVar
case resp of case resp of
(ResponseNotifications ns) -> do (ResponseNotifications (RequestId requestId) ns) ->
case lastMay ns of if (requestId == sqnum)
Just n -> atomicWriteIORef notifSqnumRef (getNotificationSqnum n) then do
Nothing -> readIORef lastKnownNotification >>= atomicWriteIORef notifSqnumRef case lastMay ns of
return $ Right ns Just n -> atomicWriteIORef notifSqnumRef (getNotificationSqnum n)
(ResponseError msg) -> return $ Left msg Nothing -> readIORef lastKnownNotification >>= atomicWriteIORef notifSqnumRef
return $ Right ns
else do
logWith logger Warning "Broker.Client" "GetNotifications: requestId mismatch"
return $ Left "requestId mismatch"
(ResponseError (RequestId requestId) msg) -> return $ Left msg
_ -> return $ Left "Unknown error" _ -> return $ Left "Unknown error"

99
src/ATrade/Broker/Protocol.hs

@ -11,11 +11,12 @@ module ATrade.Broker.Protocol (
nextSqnum, nextSqnum,
getNotificationSqnum, getNotificationSqnum,
notificationOrderId, notificationOrderId,
RequestSqnum(..),
requestSqnum,
TradeSinkMessage(..), TradeSinkMessage(..),
mkTradeMessage, mkTradeMessage,
ClientIdentity(..) ClientIdentity(..),
RequestId(..),
getRequestId,
getResponseRequestId
) where ) where
import ATrade.Types import ATrade.Types
@ -32,8 +33,10 @@ import Data.Time.Clock
import Language.Haskell.Printf import Language.Haskell.Printf
import Text.Parsec hiding ((<|>)) import Text.Parsec hiding ((<|>))
data RequestId = RequestId Int64
deriving (Eq, Show, Ord)
type ClientIdentity = T.Text type ClientIdentity = T.Text
type RequestSqnum = Int64
newtype NotificationSqnum = NotificationSqnum { unNotificationSqnum :: Int64 } newtype NotificationSqnum = NotificationSqnum { unNotificationSqnum :: Int64 }
deriving (Eq, Show, Ord) deriving (Eq, Show, Ord)
@ -77,83 +80,101 @@ instance ToJSON Notification where
toJSON (TradeNotification sqnum trade) = object [ "notification-sqnum" .= toJSON (unNotificationSqnum sqnum), "trade" .= toJSON trade ] toJSON (TradeNotification sqnum trade) = object [ "notification-sqnum" .= toJSON (unNotificationSqnum sqnum), "trade" .= toJSON trade ]
data BrokerServerRequest = RequestSubmitOrder RequestSqnum ClientIdentity Order data BrokerServerRequest = RequestSubmitOrder RequestId ClientIdentity Order
| RequestCancelOrder RequestSqnum ClientIdentity OrderId | RequestCancelOrder RequestId ClientIdentity OrderId
| RequestNotifications RequestSqnum ClientIdentity NotificationSqnum | RequestNotifications RequestId ClientIdentity NotificationSqnum
| RequestCurrentSqnum RequestSqnum ClientIdentity | RequestCurrentSqnum RequestId ClientIdentity
| RequestSetClientIdentity RequestId ClientIdentity
deriving (Eq, Show) deriving (Eq, Show)
requestSqnum :: BrokerServerRequest -> RequestSqnum getRequestId :: BrokerServerRequest -> RequestId
requestSqnum (RequestSubmitOrder sqnum _ _) = sqnum getRequestId (RequestSubmitOrder rid _ _) = rid
requestSqnum (RequestCancelOrder sqnum _ _) = sqnum getRequestId (RequestCancelOrder rid _ _) = rid
requestSqnum (RequestNotifications sqnum _ _) = sqnum getRequestId (RequestNotifications rid _ _) = rid
requestSqnum (RequestCurrentSqnum sqnum _) = sqnum getRequestId (RequestCurrentSqnum rid _) = rid
getRequestId (RequestSetClientIdentity rid _) = rid
instance FromJSON BrokerServerRequest where instance FromJSON BrokerServerRequest where
parseJSON = withObject "object" (\obj -> do parseJSON = withObject "object" (\obj -> do
sqnum <- obj .: "request-sqnum"
clientIdentity <- obj .: "client-identity" clientIdentity <- obj .: "client-identity"
parseRequest sqnum clientIdentity obj) requestId <- obj .: "request-id"
parseRequest (RequestId requestId) clientIdentity obj)
where where
parseRequest :: RequestSqnum -> ClientIdentity -> Object -> Parser BrokerServerRequest parseRequest :: RequestId -> ClientIdentity -> Object -> Parser BrokerServerRequest
parseRequest sqnum clientIdentity obj parseRequest requestId clientIdentity obj
| KM.member "order" obj = do | KM.member "order" obj = do
order <- obj .: "order" order <- obj .: "order"
RequestSubmitOrder sqnum clientIdentity <$> parseJSON order RequestSubmitOrder requestId clientIdentity <$> parseJSON order
| KM.member "cancel-order" obj = do | KM.member "cancel-order" obj = do
orderId <- obj .: "cancel-order" orderId <- obj .: "cancel-order"
RequestCancelOrder sqnum clientIdentity <$> parseJSON orderId RequestCancelOrder requestId clientIdentity <$> parseJSON orderId
| KM.member "request-notifications" obj = do | KM.member "request-notifications" obj = do
initialSqnum <- obj .: "initial-sqnum" initialSqnum <- obj .: "initial-sqnum"
return (RequestNotifications sqnum clientIdentity (NotificationSqnum initialSqnum)) return (RequestNotifications requestId clientIdentity (NotificationSqnum initialSqnum))
| KM.member "request-current-sqnum" obj = | KM.member "request-current-sqnum" obj =
return (RequestCurrentSqnum sqnum clientIdentity) return (RequestCurrentSqnum requestId clientIdentity)
| KM.member "set-client-identity" obj =
return (RequestSetClientIdentity requestId clientIdentity)
parseRequest _ _ _ = fail "Invalid request object" parseRequest _ _ _ = fail "Invalid request object"
instance ToJSON BrokerServerRequest where instance ToJSON BrokerServerRequest where
toJSON (RequestSubmitOrder sqnum clientIdentity order) = object ["request-sqnum" .= sqnum, toJSON (RequestSubmitOrder (RequestId rid) clientIdentity order) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"order" .= order ] "order" .= order ]
toJSON (RequestCancelOrder sqnum clientIdentity oid) = object ["request-sqnum" .= sqnum, toJSON (RequestCancelOrder (RequestId rid) clientIdentity oid) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"cancel-order" .= oid ] "cancel-order" .= oid ]
toJSON (RequestNotifications sqnum clientIdentity initialNotificationSqnum) = object ["request-sqnum" .= sqnum, toJSON (RequestNotifications (RequestId rid) clientIdentity initialNotificationSqnum) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"request-notifications" .= ("" :: T.Text), "request-notifications" .= ("" :: T.Text),
"initial-sqnum" .= unNotificationSqnum initialNotificationSqnum] "initial-sqnum" .= unNotificationSqnum initialNotificationSqnum]
toJSON (RequestCurrentSqnum sqnum clientIdentity) = object toJSON (RequestCurrentSqnum (RequestId rid) clientIdentity) = object
["request-sqnum" .= sqnum, ["request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"request-current-sqnum" .= ("" :: T.Text) ] "request-current-sqnum" .= ("" :: T.Text) ]
toJSON (RequestSetClientIdentity (RequestId rid) clientIdentity) = object
["request-id" .= rid,
"client-identity" .= clientIdentity,
"set-client-identity" .= ("" :: T.Text) ]
data BrokerServerResponse = ResponseOk getResponseRequestId :: BrokerServerResponse -> RequestId
| ResponseNotifications [Notification] getResponseRequestId (ResponseOk reqId) = reqId
| ResponseCurrentSqnum NotificationSqnum getResponseRequestId (ResponseNotifications reqId _) = reqId
| ResponseError T.Text getResponseRequestId (ResponseCurrentSqnum reqId _) = reqId
getResponseRequestId (ResponseError reqId _) = reqId
data BrokerServerResponse = ResponseOk RequestId
| ResponseNotifications RequestId [Notification]
| ResponseCurrentSqnum RequestId NotificationSqnum
| ResponseError RequestId T.Text
deriving (Eq, Show) deriving (Eq, Show)
instance FromJSON BrokerServerResponse where instance FromJSON BrokerServerResponse where
parseJSON = withObject "object" (\obj -> parseJSON = withObject "object" (\obj -> do
requestId <- obj .: "request-id"
if | KM.member "result" obj -> do if | KM.member "result" obj -> do
result <- obj .: "result" result <- obj .: "result"
if (result :: T.Text) == "success" if (result :: T.Text) == "success"
then return ResponseOk then return $ ResponseOk (RequestId requestId)
else do else do
msg <- obj .:? "message" .!= "" msg <- obj .:? "message" .!= ""
return (ResponseError msg) return $ (ResponseError (RequestId requestId) msg)
| KM.member "notifications" obj -> do | KM.member "notifications" obj -> do
notifications <- obj .: "notifications" notifications <- obj .: "notifications"
ResponseNotifications <$> parseJSON notifications ResponseNotifications (RequestId requestId) <$> parseJSON notifications
| KM.member "current-sqnum" obj -> do | KM.member "current-sqnum" obj -> do
rawSqnum <- obj .: "current-sqnum" rawSqnum <- obj .: "current-sqnum"
return $ ResponseCurrentSqnum (NotificationSqnum rawSqnum) return $ ResponseCurrentSqnum (RequestId requestId) (NotificationSqnum rawSqnum)
| otherwise -> fail "Unable to parse BrokerServerResponse") | otherwise -> fail "Unable to parse BrokerServerResponse")
instance ToJSON BrokerServerResponse where instance ToJSON BrokerServerResponse where
toJSON ResponseOk = object [ "result" .= ("success" :: T.Text) ] toJSON (ResponseOk (RequestId rid)) = object [ "request-id" .= rid, "result" .= ("success" :: T.Text) ]
toJSON (ResponseNotifications notifications) = object [ "notifications" .= notifications ] toJSON (ResponseNotifications (RequestId rid) notifications) = object [ "request-id" .= rid, "notifications" .= notifications ]
toJSON (ResponseCurrentSqnum sqnum) = object [ "current-sqnum" .= unNotificationSqnum sqnum ] toJSON (ResponseCurrentSqnum (RequestId rid) sqnum) = object [ "request-id" .= rid, "current-sqnum" .= unNotificationSqnum sqnum ]
toJSON (ResponseError errorMessage) = object [ "result" .= ("error" :: T.Text), "message" .= errorMessage ] toJSON (ResponseError (RequestId rid) errorMessage) = object [ "request-id" .= rid, "result" .= ("error" :: T.Text), "message" .= errorMessage ]
data TradeSinkMessage = TradeSinkHeartBeat | TradeSinkTrade { data TradeSinkMessage = TradeSinkHeartBeat | TradeSinkTrade {
tsAccountId :: T.Text, tsAccountId :: T.Text,

301
src/ATrade/Broker/Server.hs

@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
module ATrade.Broker.Server ( module ATrade.Broker.Server (
startBrokerServer, startBrokerServer,
@ -15,18 +16,19 @@ import ATrade.Broker.Protocol (BrokerServerRequest (..),
ClientIdentity, ClientIdentity,
Notification (..), Notification (..),
NotificationSqnum (NotificationSqnum), NotificationSqnum (NotificationSqnum),
RequestSqnum, RequestId (..),
getNotificationSqnum, getNotificationSqnum,
nextSqnum, requestSqnum) getRequestId, nextSqnum)
import ATrade.Logging (Message (Message), import ATrade.Logging (Message (Message),
Severity (Debug, Warning), Severity (Debug, Warning),
logWith) Severity (Info), logWith)
import ATrade.Logging (Severity (Info))
import ATrade.Types (Order (orderAccountId, orderId), import ATrade.Types (Order (orderAccountId, orderId),
OrderId, OrderId,
ServerSecurityParams (sspCertificate, sspDomain), ServerSecurityParams (sspCertificate, sspDomain),
Trade (tradeOrderId)) Trade (tradeOrderId))
import ATrade.Util (atomicMapIORef) import ATrade.Util (atomicMapIORef)
import ATrade.Utils.MessagePipe (emptyMessagePipe, getMessages,
push)
import Colog (LogAction) import Colog (LogAction)
import Control.Concurrent (MVar, ThreadId, forkIO, import Control.Concurrent (MVar, ThreadId, forkIO,
killThread, myThreadId, killThread, myThreadId,
@ -34,107 +36,89 @@ import Control.Concurrent (MVar, ThreadId, forkIO,
readMVar, threadDelay, readMVar, threadDelay,
tryReadMVar, yield) tryReadMVar, yield)
import Control.Concurrent.BoundedChan (BoundedChan, newBoundedChan, import Control.Concurrent.BoundedChan (BoundedChan, newBoundedChan,
tryReadChan, tryWriteChan) readChan, tryReadChan,
import Control.Exception (finally) tryWriteChan)
import Control.Monad (unless) import Control.Exception (bracket, finally)
import Control.Monad (unless, void, when)
import Control.Monad.Extra (forever)
import Control.Monad.Loops (whileM_) import Control.Monad.Loops (whileM_)
import Data.Aeson (eitherDecode, encode) import Data.Aeson (eitherDecode, encode)
import qualified Data.Bimap as BM import qualified Data.Bimap as BM
import qualified Data.ByteString as B hiding (putStrLn) import qualified Data.ByteString as B hiding (putStrLn)
import qualified Data.ByteString.Lazy as BL hiding (putStrLn) import qualified Data.ByteString.Lazy as BL hiding (putStrLn)
import Data.IORef (IORef, atomicModifyIORef', import Data.IORef (IORef, atomicModifyIORef',
newIORef, readIORef) newIORef, readIORef,
writeIORef)
import qualified Data.List as L import qualified Data.List as L
import Data.List.NonEmpty (NonEmpty ((:|))) import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.Map as M import qualified Data.Map as M
import Data.Maybe (isJust, isNothing) import Data.Maybe (isJust, isNothing)
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as E import qualified Data.Text.Encoding as E
import qualified Data.Text.Lazy as TL
import Data.Time.Clock () import Data.Time.Clock ()
import Safe (lastMay) import Language.Haskell.Printf
import Network.Socket (Family (AF_INET),
SockAddr (SockAddrInet),
Socket, SocketType (Stream),
accept, bind, defaultProtocol,
listen, socket)
import Network.Socket.ByteString (recv, sendAll)
import Safe (lastMay, readMay)
import System.Timeout () import System.Timeout ()
import System.ZMQ4 (Context, Event (In), import System.ZMQ4 hiding (Socket, Stream, bind,
Poll (Sock), Pub (..), socket)
Router (..), Socket,
Switch (On), bind, close, poll,
receiveMulti, restrict,
sendMulti, setCurveServer,
setLinger, setTcpKeepAlive,
setTcpKeepAliveCount,
setTcpKeepAliveIdle,
setTcpKeepAliveInterval,
setZapDomain, socket)
import System.ZMQ4.ZAP (zapApplyCertificate)
type PeerId = B.ByteString type PeerId = B.ByteString
data FullOrderId = FullOrderId ClientIdentity OrderId data FullOrderId = FullOrderId ClientIdentity OrderId
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
data BrokerServerState = BrokerServerState { data ClientState = ClientState {
bsSocket :: Socket Router, cThreadId :: ThreadId,
bsNotificationsSocket :: Socket Pub, cSocket :: Socket,
orderToBroker :: M.Map FullOrderId BrokerBackend, cClientIdentity :: ClientIdentity,
orderMap :: BM.Bimap FullOrderId OrderId, cEgressQueue :: BoundedChan B.ByteString
lastPacket :: M.Map PeerId (RequestSqnum, BrokerServerResponse), }
pendingNotifications :: M.Map ClientIdentity [Notification],
notificationSqnum :: M.Map ClientIdentity NotificationSqnum,
brokers :: [BrokerBackend],
completionMvar :: MVar (),
killMvar :: MVar (),
orderIdCounter :: OrderId,
tradeSink :: BoundedChan Trade,
initialSqnum :: NotificationSqnum
data BrokerServerState = BrokerServerState {
orderToBroker :: M.Map FullOrderId BrokerBackend,
orderMap :: BM.Bimap FullOrderId OrderId,
pendingNotifications :: M.Map ClientIdentity [Notification],
notificationSqnum :: M.Map ClientIdentity NotificationSqnum,
brokers :: [BrokerBackend],
completionMvar :: MVar (),
killMvar :: MVar (),
orderIdCounter :: OrderId,
tradeSink :: BoundedChan Trade,
initialSqnum :: NotificationSqnum
} }
data BrokerServerHandle = BrokerServerHandle ThreadId ThreadId (MVar ()) (MVar ()) data BrokerServerHandle = BrokerServerHandle
{
bhServerTid :: ThreadId
, bhClients :: IORef (M.Map ClientIdentity ClientState)
, bhKillMVar :: MVar ()
, bhCompletionMVar :: MVar ()
}
type TradeSink = Trade -> IO () type TradeSink = Trade -> IO ()
startBrokerServer :: [BrokerBackend] -> startBrokerServer :: [BrokerBackend] ->
Context -> Context ->
T.Text -> T.Text ->
T.Text ->
NotificationSqnum -> NotificationSqnum ->
[TradeSink] -> [TradeSink] ->
ServerSecurityParams ->
LogAction IO Message -> LogAction IO Message ->
IO BrokerServerHandle IO BrokerServerHandle
startBrokerServer brokers c ep notificationsEp initialSqnum tradeSinks params logger = do startBrokerServer brokers c ep initialSqnum tradeSinks logger = do
sock <- socket c Router
notificationsSock <- socket c Pub
setLinger (restrict 0) sock
setLinger (restrict 0) notificationsSock
case sspDomain params of
Just domain -> do
setZapDomain (restrict $ E.encodeUtf8 domain) sock
setZapDomain (restrict $ E.encodeUtf8 domain) notificationsSock
Nothing -> return ()
case sspCertificate params of
Just cert -> do
setCurveServer True sock
zapApplyCertificate cert sock
setCurveServer True notificationsSock
zapApplyCertificate cert notificationsSock
Nothing -> return ()
bind sock (T.unpack ep)
setTcpKeepAlive On notificationsSock
setTcpKeepAliveCount (restrict 5) notificationsSock
setTcpKeepAliveIdle (restrict 60) notificationsSock
setTcpKeepAliveInterval (restrict 10) notificationsSock
bind notificationsSock (T.unpack notificationsEp)
tid <- myThreadId
compMv <- newEmptyMVar compMv <- newEmptyMVar
killMv <- newEmptyMVar killMv <- newEmptyMVar
tsChan <- newBoundedChan 100 tsChan <- newBoundedChan 100
clientsMapRef <- newIORef M.empty
state <- newIORef BrokerServerState { state <- newIORef BrokerServerState {
bsSocket = sock,
bsNotificationsSocket = notificationsSock,
orderMap = BM.empty, orderMap = BM.empty,
orderToBroker = M.empty, orderToBroker = M.empty,
lastPacket = M.empty,
pendingNotifications = M.empty, pendingNotifications = M.empty,
notificationSqnum = M.empty, notificationSqnum = M.empty,
brokers = brokers, brokers = brokers,
@ -144,18 +128,45 @@ startBrokerServer brokers c ep notificationsEp initialSqnum tradeSinks params lo
tradeSink = tsChan, tradeSink = tsChan,
initialSqnum = initialSqnum initialSqnum = initialSqnum
} }
mapM_ (\bro -> setNotificationCallback bro (Just $ notificationCallback state logger)) brokers
log Info "Broker.Server" "Forking broker server thread" let Just (_, port) = parseHostAndPort ep
BrokerServerHandle <$> forkIO (brokerServerThread state logger) <*> forkIO (tradeSinkHandler c state tradeSinks) <*> pure compMv <*> pure killMv serverSocket <- socket AF_INET Stream defaultProtocol
bind serverSocket $ SockAddrInet (fromIntegral port) 0
log Info "Broker.Server" $ TL.toStrict $ [t|Listening on port %?|] $ fromIntegral port
listen serverSocket 1024
serverTid <- forkIO $ forever $ do
(client, addr) <- accept serverSocket
log Debug "Broker.Server" "Incoming connection"
rawRequest <- recv client 4096
case eitherDecode $ BL.fromStrict $ B.init rawRequest of
Left err -> log Warning "Broker.Server" $ "Unable to decode client id: " <> (T.pack . show) rawRequest
Right (RequestSetClientIdentity requestId clientIdentity) -> do
log Info "Broker.Server" $ "Connected socket identity: " <> (T.pack . show) clientIdentity
egressQueue <- newBoundedChan 100
clientTid <- forkIO $ clientThread client egressQueue clientsMapRef state logger
let clientState = ClientState clientTid client clientIdentity egressQueue
atomicModifyIORef' clientsMapRef (\m -> (M.insert clientIdentity clientState m, ()))
_ -> log Warning "Broker.Server" $ "Invalid first message: " <> (T.pack . show) rawRequest
mapM_ (\bro -> setNotificationCallback bro (Just $ notificationCallback state clientsMapRef logger)) brokers
pure $ BrokerServerHandle serverTid clientsMapRef killMv compMv
where where
log = logWith logger log = logWith logger
parseHostAndPort :: T.Text -> Maybe (T.Text, Int)
parseHostAndPort str = case T.splitOn ":" str of
[host, port] ->
case readMay $ T.unpack port of
Just numPort -> Just (host, numPort)
_ -> Nothing
_ -> Nothing
notificationCallback :: IORef BrokerServerState -> notificationCallback :: IORef BrokerServerState ->
IORef (M.Map ClientIdentity ClientState) ->
LogAction IO Message -> LogAction IO Message ->
BrokerBackendNotification -> BrokerBackendNotification ->
IO () IO ()
notificationCallback state logger n = do notificationCallback state clientsMapRef logger n = do
log Debug "Broker.Server" $ "Notification: " <> (T.pack . show) n log Debug "Broker.Server" $ "Notification: " <> (T.pack . show) n
chan <- tradeSink <$> readIORef state chan <- tradeSink <$> readIORef state
case n of case n of
@ -180,8 +191,10 @@ notificationCallback state logger n = do
case M.lookup clientIdentity . pendingNotifications $ s of case M.lookup clientIdentity . pendingNotifications $ s of
Just ns -> s { pendingNotifications = M.insert clientIdentity (n : ns) (pendingNotifications s)} Just ns -> s { pendingNotifications = M.insert clientIdentity (n : ns) (pendingNotifications s)}
Nothing -> s { pendingNotifications = M.insert clientIdentity [n] (pendingNotifications s)}) Nothing -> s { pendingNotifications = M.insert clientIdentity [n] (pendingNotifications s)})
sock <- bsNotificationsSocket <$> readIORef state clients <- readIORef clientsMapRef
sendMulti sock (E.encodeUtf8 clientIdentity :| [BL.toStrict $ encode n]) case M.lookup clientIdentity clients of
Just client -> void $ tryWriteChan (cEgressQueue client) $ BL.toStrict $ encode n
Nothing -> log Warning "Broker.Server" $ TL.toStrict $ [t|Unable to send notification to %?|] clientIdentity
tradeSinkHandler :: Context -> IORef BrokerServerState -> [TradeSink] -> IO () tradeSinkHandler :: Context -> IORef BrokerServerState -> [TradeSink] -> IO ()
tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $ tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $
@ -195,118 +208,108 @@ tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $
wasKilled = isJust <$> (readIORef state >>= tryReadMVar . killMvar) wasKilled = isJust <$> (readIORef state >>= tryReadMVar . killMvar)
brokerServerThread :: IORef BrokerServerState -> clientThread :: Socket ->
LogAction IO Message -> BoundedChan B.ByteString ->
IO () IORef (M.Map ClientIdentity ClientState) ->
brokerServerThread state logger = finally brokerServerThread' cleanup IORef BrokerServerState ->
LogAction IO Message ->
IO ()
clientThread socket egressQueue clients serverState logger =
bracket
(forkIO sendingThread)
(\tid -> do
log Debug "Broker.Server" "Killing sending thread"
killThread tid)
brokerServerThread'
where where
log = logWith logger log = logWith logger
brokerServerThread' = whileM_ (fmap killMvar (readIORef state) >>= fmap isNothing . tryReadMVar) $ do brokerServerThread' _ = do
sock <- bsSocket <$> readIORef state pipeRef <- newIORef emptyMessagePipe
events <- poll 100 [Sock sock [In] Nothing] brokerServerThread'' pipeRef
unless (null . L.head $ events) $ do log Info "Broker.Server" "Client disconnected"
msg <- receiveMulti sock
case msg of brokerServerThread'' pipeRef = do
[peerId, _, payload] -> do rawData <- recv socket 4096
case eitherDecode . BL.fromStrict $ payload of when (B.length rawData > 0) $ do
Right request -> do pipe <- readIORef pipeRef
let sqnum = requestSqnum request let (pipe', chunks) = getMessages (push rawData pipe)
-- Here, we should check if previous packet sequence number is the same writeIORef pipeRef pipe'
-- If it is, we should resend previous response mapM_ (handleChunk egressQueue) chunks
lastPackMap <- lastPacket <$> readIORef state brokerServerThread'' pipeRef
case shouldResend sqnum peerId lastPackMap of
Just response -> do
log Debug "Broker.Server" $ "Resending packet for peerId: " <> (T.pack . show) peerId
sendMessage sock peerId response -- Resend
atomicMapIORef state (\s -> s { lastPacket = M.delete peerId (lastPacket s)})
Nothing -> do
-- Handle incoming request, send response
response <- handleMessage peerId request
sendMessage sock peerId response
-- and store response in case we'll need to resend it
atomicMapIORef state (\s -> s { lastPacket = M.insert peerId (sqnum, response) (lastPacket s)})
Left errmsg -> do
-- If we weren't able to parse request, we should send error
-- but shouldn't update lastPacket
let response = ResponseError $ "Invalid request: " <> T.pack errmsg
sendMessage sock peerId response
_ -> log Warning "Broker.Server" ("Invalid packet received: " <> (T.pack . show) msg)
shouldResend sqnum peerId lastPackMap = case M.lookup peerId lastPackMap of sendingThread = forever $ do
Just (lastSqnum, response) -> if sqnum == lastSqnum packet <- readChan egressQueue
then Just response log Debug "Broker.Server" $ TL.toStrict $ [t|Sending packet: %?|] packet
else Nothing sendAll socket $ B.snoc packet 0
Nothing -> Nothing
cleanup = do enqueueEgressPacket = tryWriteChan egressQueue
sock <- bsSocket <$> readIORef state
close sock
mv <- completionMvar <$> readIORef state handleChunk egressQueue payload = do
putMVar mv () response <- case eitherDecode . BL.fromStrict $ payload of
Right request -> handleMessage request
Left errmsg -> pure $ ResponseError (RequestId 0) $ "Invalid request: " <> T.pack errmsg
enqueueEgressPacket $ BL.toStrict $ encode response
handleMessage :: PeerId -> BrokerServerRequest -> IO BrokerServerResponse handleMessage :: BrokerServerRequest -> IO BrokerServerResponse
handleMessage peerId request = do handleMessage request = do
bros <- brokers <$> readIORef state log Debug "Broker.Server" "Handle message"
bros <- brokers <$> readIORef serverState
case request of case request of
RequestSubmitOrder sqnum clientIdentity order -> do RequestSubmitOrder requestId clientIdentity order -> do
log Debug "Broker.Server" $ "Request: submit order:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: submit order:" <> (T.pack . show) request
case findBrokerForAccount (orderAccountId order) bros of case findBrokerForAccount (orderAccountId order) bros of
Just bro -> do Just bro -> do
globalOrderId <- nextOrderId globalOrderId <- nextOrderId
let fullOrderId = FullOrderId clientIdentity (orderId order) let fullOrderId = FullOrderId clientIdentity (orderId order)
atomicMapIORef state (\s -> s { atomicMapIORef serverState (\s -> s {
orderToBroker = M.insert fullOrderId bro (orderToBroker s), orderToBroker = M.insert fullOrderId bro (orderToBroker s),
orderMap = BM.insert fullOrderId globalOrderId (orderMap s) }) orderMap = BM.insert fullOrderId globalOrderId (orderMap s) })
submitOrder bro order { orderId = globalOrderId } submitOrder bro order { orderId = globalOrderId }
return ResponseOk return $ ResponseOk requestId
Nothing -> do Nothing -> do
log Warning "Broker.Server" $ "Unknown account: " <> (orderAccountId order) log Warning "Broker.Server" $ "Unknown account: " <> orderAccountId order
return $ ResponseError "Unknown account" return $ ResponseError requestId "Unknown account"
RequestCancelOrder sqnum clientIdentity localOrderId -> do RequestCancelOrder requestId clientIdentity localOrderId -> do
log Debug "Broker.Server" $ "Request: cancel order:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: cancel order:" <> (T.pack . show) request
m <- orderToBroker <$> readIORef state m <- orderToBroker <$> readIORef serverState
bm <- orderMap <$> readIORef state bm <- orderMap <$> readIORef serverState
let fullOrderId = FullOrderId clientIdentity localOrderId let fullOrderId = FullOrderId clientIdentity localOrderId
case (M.lookup fullOrderId m, BM.lookup fullOrderId bm) of case (M.lookup fullOrderId m, BM.lookup fullOrderId bm) of
(Just bro, Just globalOrderId) -> do (Just bro, Just globalOrderId) -> do
cancelOrder bro globalOrderId cancelOrder bro globalOrderId
return ResponseOk return $ ResponseOk requestId
_ -> return $ ResponseError "Unknown order" _ -> return $ ResponseError requestId "Unknown order"
RequestNotifications sqnum clientIdentity initialSqnum -> do RequestNotifications requestId clientIdentity initialSqnum -> do
log Debug "Broker.Server" $ "Request: notifications:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: notifications:" <> (T.pack . show) request
maybeNs <- M.lookup clientIdentity . pendingNotifications <$> readIORef state maybeNs <- M.lookup clientIdentity . pendingNotifications <$> readIORef serverState
case maybeNs of case maybeNs of
Just ns -> do Just ns -> do
let filtered = L.filter (\n -> getNotificationSqnum n >= initialSqnum) ns let filtered = L.filter (\n -> getNotificationSqnum n >= initialSqnum) ns
atomicMapIORef state (\s -> s { pendingNotifications = M.insert clientIdentity filtered (pendingNotifications s)}) atomicMapIORef serverState (\s -> s { pendingNotifications = M.insert clientIdentity filtered (pendingNotifications s)})
return $ ResponseNotifications . L.reverse $ filtered return $ ResponseNotifications requestId . L.reverse $ filtered
Nothing -> return $ ResponseNotifications [] Nothing -> return $ ResponseNotifications requestId []
RequestCurrentSqnum sqnum clientIdentity -> do RequestCurrentSqnum requestId clientIdentity -> do
log Debug "Broker.Server" $ "Request: current sqnum:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: current sqnum:" <> (T.pack . show) request
sqnumMap <- notificationSqnum <$> readIORef state sqnumMap <- notificationSqnum <$> readIORef serverState
notifMap <- pendingNotifications <$> readIORef state notifMap <- pendingNotifications <$> readIORef serverState
case M.lookup clientIdentity notifMap of case M.lookup clientIdentity notifMap of
Just [] -> Just [] ->
case M.lookup clientIdentity sqnumMap of case M.lookup clientIdentity sqnumMap of
Just sqnum -> return (ResponseCurrentSqnum sqnum) Just sqnum -> return (ResponseCurrentSqnum requestId sqnum)
_ -> return (ResponseCurrentSqnum (NotificationSqnum 1)) _ -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
Just notifs -> case lastMay notifs of Just notifs -> case lastMay notifs of
Just v -> return (ResponseCurrentSqnum (getNotificationSqnum v)) Just v -> return (ResponseCurrentSqnum requestId (getNotificationSqnum v))
_ -> return (ResponseCurrentSqnum (NotificationSqnum 1)) _ -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
Nothing -> return (ResponseCurrentSqnum (NotificationSqnum 1)) Nothing -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
RequestSetClientIdentity requestId _ -> pure $ ResponseError requestId "Client identity change is not supported"
sendMessage sock peerId resp = sendMulti sock (peerId :| [B.empty, BL.toStrict . encode $ resp])
findBrokerForAccount account = L.find (L.elem account . accounts) findBrokerForAccount account = L.find (L.elem account . accounts)
nextOrderId = atomicModifyIORef' state (\s -> ( s {orderIdCounter = 1 + orderIdCounter s}, orderIdCounter s)) nextOrderId = atomicModifyIORef' serverState (\s -> ( s {orderIdCounter = 1 + orderIdCounter s}, orderIdCounter s))
stopBrokerServer :: BrokerServerHandle -> IO () stopBrokerServer :: BrokerServerHandle -> IO ()
stopBrokerServer (BrokerServerHandle tid tstid compMv killMv) = do stopBrokerServer (BrokerServerHandle tid clients compMv killMv) = do
putMVar killMv () putMVar killMv ()
killThread tstid readIORef clients >>= mapM_ (killThread . cThreadId) . M.elems
yield yield
readMVar compMv readMVar compMv

Loading…
Cancel
Save