Browse Source

Broker: use tcp sockets instead of 0mq

master
Denis Tereshkin 11 months ago
parent
commit
ca36dcc0df
  1. 8
      libatrade.cabal
  2. 335
      src/ATrade/Broker/Client.hs
  3. 99
      src/ATrade/Broker/Protocol.hs
  4. 277
      src/ATrade/Broker/Server.hs

8
libatrade.cabal

@ -30,11 +30,12 @@ library
, ATrade.Broker.TradeSinks.ZMQTradeSink , ATrade.Broker.TradeSinks.ZMQTradeSink
, ATrade.Broker.TradeSinks.GotifyTradeSink , ATrade.Broker.TradeSinks.GotifyTradeSink
, ATrade.Util , ATrade.Util
, ATrade
other-modules: Paths_libatrade other-modules: Paths_libatrade
, ATrade.Utils.MessagePipe
build-depends: base >= 4.7 && < 5 build-depends: base >= 4.7 && < 5
, BoundedChan , BoundedChan
, aeson , aeson
, attoparsec
, bimap , bimap
, binary , binary
, bytestring , bytestring
@ -59,6 +60,8 @@ library
, co-log , co-log
, ansi-terminal , ansi-terminal
, net-mqtt , net-mqtt
, network
, network-run
default-language: Haskell2010 default-language: Haskell2010
@ -80,6 +83,7 @@ test-suite libatrade-test
, tuple , tuple
, time , time
, aeson , aeson
, attoparsec
, text , text
, BoundedChan , BoundedChan
, zeromq4-haskell , zeromq4-haskell
@ -88,6 +92,8 @@ test-suite libatrade-test
, monad-loops , monad-loops
, uuid , uuid
, stm , stm
, network
, network-run
ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wincomplete-patterns -Wno-orphans ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wincomplete-patterns -Wno-orphans
default-language: Haskell2010 default-language: Haskell2010
other-modules: ArbitraryInstances other-modules: ArbitraryInstances

335
src/ATrade/Broker/Client.hs

@ -1,4 +1,6 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module ATrade.Broker.Client ( module ATrade.Broker.Client (
startBrokerClient, startBrokerClient,
@ -14,14 +16,19 @@ import ATrade.Broker.Protocol (BrokerServerRequest (..),
BrokerServerResponse (..), BrokerServerResponse (..),
ClientIdentity, Notification, ClientIdentity, Notification,
NotificationSqnum (NotificationSqnum), NotificationSqnum (NotificationSqnum),
RequestSqnum, RequestId (..),
getNotificationSqnum, getNotificationSqnum,
getRequestId,
getResponseRequestId,
nextSqnum) nextSqnum)
import ATrade.Logging (Message, import ATrade.Logging (Message,
Severity (Debug, Info, Warning), Severity (Debug, Info, Warning),
logWith) logWith)
import ATrade.Types (ClientSecurityParams (cspCertificate, cspServerCertificate), import ATrade.Types (ClientSecurityParams (cspCertificate, cspServerCertificate),
Order, OrderId) Order, OrderId)
import ATrade.Util (atomicMapIORef)
import ATrade.Utils.MessagePipe (MessagePipe, emptyMessagePipe,
getMessages, push)
import Colog (LogAction) import Colog (LogAction)
import Control.Concurrent (MVar, ThreadId, forkIO, import Control.Concurrent (MVar, ThreadId, forkIO,
killThread, newEmptyMVar, killThread, newEmptyMVar,
@ -29,12 +36,14 @@ import Control.Concurrent (MVar, ThreadId, forkIO,
threadDelay, tryReadMVar, threadDelay, tryReadMVar,
yield) yield)
import Control.Concurrent.BoundedChan () import Control.Concurrent.BoundedChan ()
import Control.Concurrent.MVar () import Control.Concurrent.MVar (tryPutMVar)
import Control.Exception (SomeException, finally, handle, import Control.Exception (SomeException, bracket, catch,
throwIO) finally, handle, throwIO)
import Control.Monad (forM_, when) import Control.Monad (forM_, forever, void, when)
import Control.Monad.Loops (andM, whileM_) import Control.Monad.Loops (andM, whileM_)
import Data.Aeson (decode, encode) import Data.Aeson (decode, encode)
import Data.Attoparsec.Text (char, decimal, maybeResult,
parseOnly)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import Data.Int (Int64) import Data.Int (Int64)
@ -43,25 +52,22 @@ import Data.IORef (IORef, atomicModifyIORef',
readIORef, writeIORef) readIORef, writeIORef)
import qualified Data.List as L import qualified Data.List as L
import Data.List.NonEmpty () import Data.List.NonEmpty ()
import Data.Maybe (isNothing) import Data.Maybe (isNothing, mapMaybe)
import qualified Data.Text as T import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8) import Data.Text.Encoding (decodeUtf8)
import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as TL
import Data.Time (UTCTime, diffUTCTime,
getCurrentTime)
import Language.Haskell.Printf
import Network.Socket (Family (AF_INET),
SockAddr (SockAddrInet),
Socket, SocketType (Stream),
connect, defaultProtocol,
socket, tupleToHostAddress)
import Network.Socket.ByteString (recv, sendAll)
import Safe (lastMay) import Safe (lastMay)
import System.Timeout (timeout) import System.Timeout (timeout)
import System.ZMQ4 (Context, Event (In),
Poll (Sock), Req (Req),
Sub (Sub), Switch (On),
connect, poll, receive,
receiveMulti, restrict, send,
setLinger, setReqCorrelate,
setReqRelaxed, setTcpKeepAlive,
setTcpKeepAliveCount,
setTcpKeepAliveIdle,
setTcpKeepAliveInterval,
subscribe, withSocket)
import System.ZMQ4.ZAP (zapApplyCertificate,
zapSetServerCertificate)
type NotificationCallback = Notification -> IO () type NotificationCallback = Notification -> IO ()
@ -72,177 +78,129 @@ data BrokerClientHandle = BrokerClientHandle {
submitOrder :: Order -> IO (Either T.Text ()), submitOrder :: Order -> IO (Either T.Text ()),
cancelOrder :: OrderId -> IO (Either T.Text ()), cancelOrder :: OrderId -> IO (Either T.Text ()),
getNotifications :: IO (Either T.Text [Notification]), getNotifications :: IO (Either T.Text [Notification]),
cmdVar :: MVar (BrokerServerRequest, MVar BrokerServerResponse), cmdVar :: MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime),
lastKnownNotificationRef :: IORef NotificationSqnum, lastKnownNotificationRef :: IORef NotificationSqnum,
notificationCallback :: [NotificationCallback], notificationCallback :: [NotificationCallback],
notificationThreadId :: ThreadId notificationThreadId :: ThreadId
} }
brokerClientThread :: B.ByteString -> data BrokerClientEvent = IncomingResponse BrokerServerResponse
Context -> | IncomingNotification Notification
brokerClientThread :: T.Text ->
T.Text -> T.Text ->
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
MVar () -> MVar () ->
MVar () -> MVar () ->
ClientSecurityParams -> [NotificationCallback] ->
LogAction IO Message -> LogAction IO Message ->
IO () IO ()
brokerClientThread socketIdentity ctx ep cmd comp killMv secParams logger = finally brokerClientThread' cleanup brokerClientThread socketIdentity ep cmd comp killMv notificationCallbacks logger = finally brokerClientThread' cleanup
where where
log = logWith logger log = logWith logger
cleanup = log Info "Broker.Client" "Quitting broker client thread" >> putMVar comp () cleanup = log Info "Broker.Client" "Quitting broker client thread" >> putMVar comp ()
brokerClientThread' = whileM_ (isNothing <$> tryReadMVar killMv) $ do brokerClientThread' :: IO ()
log Debug "Broker.Client" "Starting event loop" brokerClientThread' = do
handle (\e -> do pendingResp <- newEmptyMVar
log Warning "Broker.Client" $ "Broker client: exception: " <> (T.pack . show) (e :: SomeException) <> "; isZMQ: " <> (T.pack . show) (isZMQError e) pipeRef <- newIORef emptyMessagePipe
if isZMQError e case parseHostAndPort ep of
then do Right (host, port) -> forever $ do
log Debug "Broker.Client" "Rethrowing exception" clientSocket <- socket AF_INET Stream defaultProtocol
throwIO e flip catch (\(_ :: SomeException) -> log Warning "Broker.Client" "Connection error") $ forever $ do
else do connect clientSocket $ SockAddrInet (fromIntegral port) host
return ()) $ withSocket ctx Req (\sock -> do sendAll clientSocket $ B.snoc (BL.toStrict $ encode (RequestSetClientIdentity (RequestId 0) socketIdentity)) 0
setLinger (restrict 0) sock bracket (forkIO $ sendThread cmd clientSocket pendingResp) killThread $ \_ -> do
setReqCorrelate True sock
setReqRelaxed True sock
case cspCertificate secParams of
Just clientCert -> zapApplyCertificate clientCert sock
Nothing -> return ()
case cspServerCertificate secParams of
Just serverCert -> zapSetServerCertificate serverCert sock
Nothing -> return ()
connect sock $ T.unpack ep
log Debug "Broker.Client" "Connected"
isTimeout <- newIORef False isTimeout <- newIORef False
whileM_ (andM [isNothing <$> tryReadMVar killMv, not <$> readIORef isTimeout]) $ do whileM_ (andM [isNothing <$> tryReadMVar killMv, not <$> readIORef isTimeout]) $ do
(request, resp) <- takeMVar cmd maybeRawData <- timeout 1000000 $ recv clientSocket 4096
send sock [] (BL.toStrict $ encode request) case maybeRawData of
incomingMessage <- timeout 5000000 $ receive sock Just rawData -> do
case incomingMessage of if B.length rawData > 0
Just msg -> case decode . BL.fromStrict $ msg of then do
Just response -> putMVar resp response atomicMapIORef pipeRef (push rawData)
Nothing -> putMVar resp (ResponseError "Unable to decode response") messages <- atomicModifyIORef' pipeRef getMessages
let parsed = mapMaybe decodeEvent messages
mapM_ (handleMessage pendingResp) parsed
else writeIORef isTimeout True
Nothing -> do Nothing -> do
putMVar resp (ResponseError "Response timeout") maybePending <- tryReadMVar pendingResp
writeIORef isTimeout True case maybePending of
threadDelay 1000000) Just (req, respVar, timestamp) -> do
isZMQError e = "ZMQError" `L.isPrefixOf` show e now <- getCurrentTime
when (now `diffUTCTime` timestamp > 5.0) $ do
log Warning "Broker.Client" $ TL.toStrict $ [t|Request timeout: %?|] req
void $ takeMVar pendingResp
putMVar respVar $ ResponseError (getRequestId req) "Timeout"
_ -> pure ()
log Debug "Broker.Client" "Recv thread done"
threadDelay 1000000
Left err -> log Warning "Broker.Client" $ "Error: " <> (T.pack . show) err
notificationThread :: ClientIdentity -> sendThread cmdvar sock pendingResp = forever $ do
[NotificationCallback] -> (req, respVar, timestamp) <- takeMVar cmdvar
Context -> putMVar pendingResp (req, respVar, timestamp)
T.Text -> let json = encode req
IORef RequestSqnum -> log Debug "Broker.Client" $ T.pack $ "sendThread: sending " <> show json
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> sendAll sock $ BL.toStrict $ BL.snoc json 0
MVar () ->
ClientSecurityParams ->
LogAction IO Message ->
IORef NotificationSqnum ->
IO ()
notificationThread clientIdentity callbacks ctx ep idCounter cmdVar killMv secParams logger lastKnownNotificationSqnum = flip finally (return ()) $ do
whileM_ (isNothing <$> tryReadMVar killMv) $
withSocket ctx Sub $ \sock -> do
setLinger (restrict 0) sock
case cspCertificate secParams of
Just clientCert -> zapApplyCertificate clientCert sock
Nothing -> return ()
case cspServerCertificate secParams of
Just serverCert -> zapSetServerCertificate serverCert sock
Nothing -> return ()
setTcpKeepAlive On sock
setTcpKeepAliveCount (restrict 5) sock
setTcpKeepAliveIdle (restrict 60) sock
setTcpKeepAliveInterval (restrict 10) sock
connect sock $ T.unpack ep
log Debug "Broker.Client" $ "Subscribing: [" <> clientIdentity <> "]"
subscribe sock $ T.encodeUtf8 clientIdentity
initialSqnum <- requestCurrentSqnum cmdVar idCounter clientIdentity
log Debug "Broker.Client" $ "Got current sqnum: " <> (T.pack . show) initialSqnum decodeEvent :: B.ByteString -> Maybe BrokerClientEvent
notifSqnumRef <- newIORef initialSqnum decodeEvent raw = case decode $ BL.fromStrict raw :: Maybe Notification of
whileM_ (isNothing <$> tryReadMVar killMv) $ do Just notif -> Just $ IncomingNotification notif
evs <- poll 5000 [Sock sock [In] Nothing] Nothing -> case decode $ BL.fromStrict raw :: Maybe BrokerServerResponse of
if null . L.head $ evs Just response -> Just $ IncomingResponse response
then do Nothing -> Nothing
respVar <- newEmptyMVar
sqnum <- nextId idCounter handleMessage :: MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) -> BrokerClientEvent -> IO ()
notifSqnum <- readIORef notifSqnumRef handleMessage respVar (IncomingResponse resp) = do
putMVar cmdVar (RequestNotifications sqnum clientIdentity notifSqnum, respVar) (req, respVar, _) <- takeMVar respVar
resp <- takeMVar respVar if getRequestId req == getResponseRequestId resp
case resp of then putMVar respVar resp
(ResponseNotifications ns) -> do
forM_ ns $ \notif -> do
lastSqnum <- readIORef notifSqnumRef
when (getNotificationSqnum notif >= lastSqnum) $ do
forM_ callbacks $ \c -> c notif
atomicWriteIORef notifSqnumRef (nextSqnum lastSqnum)
(ResponseError msg) -> log Warning "Broker.Client" $ "ResponseError: " <> msg
_ -> log Warning "Broker.Client" "Unknown error when requesting notifications"
else do else do
msg <- receiveMulti sock log Warning "Broker.Client" $ TL.toStrict $ [t|Request ID mismatch: %?/%?|] (getRequestId req) (getResponseRequestId resp)
case msg of putMVar respVar (ResponseError (getRequestId req) "Request ID mismatch")
[_, payload] -> case decode (BL.fromStrict payload) of handleMessage _ (IncomingNotification notif) = callNotificationCallbacks notif
Just notification -> do
currentSqnum <- readIORef notifSqnumRef
when (getNotificationSqnum notification /= currentSqnum) $ do
log Warning "Broker.Client" $
"Notification sqnum mismatch: " <> (T.pack . show) currentSqnum <> " -> " <> (T.pack . show) (getNotificationSqnum notification)
atomicWriteIORef notifSqnumRef (nextSqnum $ getNotificationSqnum notification)
forM_ callbacks $ \c -> c notification
atomicWriteIORef lastKnownNotificationSqnum currentSqnum
_ -> return ()
_ -> return ()
where
log = logWith logger
requestCurrentSqnum cmdVar idCounter clientIdentity = do
respVar <- newEmptyMVar
sqnum <- nextId idCounter
putMVar cmdVar (RequestCurrentSqnum sqnum clientIdentity, respVar)
resp <- takeMVar respVar
case resp of
(ResponseCurrentSqnum sqnum) -> return sqnum
(ResponseError msg) -> do
log Warning "Broker.Client" $ "ResponseError: " <> msg
return (NotificationSqnum 1)
_ -> do
log Warning "Broker.Client" "Unknown error when requesting notifications"
return (NotificationSqnum 1)
callNotificationCallbacks notif = mapM_ (\cb -> cb notif) notificationCallbacks
parseHostAndPort = parseOnly endpointParser
startBrokerClient :: B.ByteString -- ^ Socket Identity endpointParser = do
-> Context -- ^ ZeroMQ context b1 <- decimal
-> T.Text -- ^ Broker endpoing void $ char '.'
-> T.Text -- ^ Notification endpoing b2 <- decimal
void $ char '.'
b3 <- decimal
void $ char '.'
b4 <- decimal
void $ char ':'
port <- decimal
pure (tupleToHostAddress (b1, b2, b3, b4), port)
startBrokerClient :: T.Text -- ^ Socket Identity
-> T.Text -- ^ Broker endpoint
-> [NotificationCallback] -- ^ List of notification callbacks -> [NotificationCallback] -- ^ List of notification callbacks
-> ClientSecurityParams -- ^
-> LogAction IO Message -> LogAction IO Message
-> IO BrokerClientHandle -> IO BrokerClientHandle
startBrokerClient socketIdentity ctx endpoint notifEndpoint notificationCallbacks secParams logger = do startBrokerClient socketIdentity endpoint notificationCallbacks logger = do
idCounter <- newIORef 1
compMv <- newEmptyMVar compMv <- newEmptyMVar
killMv <- newEmptyMVar killMv <- newEmptyMVar
cmdVar <- newEmptyMVar :: IO (MVar (BrokerServerRequest, MVar BrokerServerResponse)) cmdVar <- newEmptyMVar :: IO (MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime))
tid <- forkIO (brokerClientThread socketIdentity ctx endpoint cmdVar compMv killMv secParams logger) idCounter <- newIORef 0
tid <- forkIO (brokerClientThread socketIdentity endpoint cmdVar compMv killMv notificationCallbacks logger)
notifSqnumRef <- newIORef (NotificationSqnum 0) notifSqnumRef <- newIORef (NotificationSqnum 0)
lastKnownNotification <- newIORef (NotificationSqnum 0) lastKnownNotification <- newIORef (NotificationSqnum 0)
notifThreadId <- forkIO (notificationThread (T.decodeUtf8 socketIdentity) notificationCallbacks ctx notifEndpoint idCounter cmdVar killMv secParams logger
lastKnownNotification)
return BrokerClientHandle { return BrokerClientHandle {
tid = tid, tid = tid,
completionMvar = compMv, completionMvar = compMv,
killMvar = killMv, killMvar = killMv,
submitOrder = bcSubmitOrder (decodeUtf8 socketIdentity) idCounter cmdVar, submitOrder = bcSubmitOrder socketIdentity idCounter cmdVar logger,
cancelOrder = bcCancelOrder (decodeUtf8 socketIdentity) idCounter cmdVar, cancelOrder = bcCancelOrder socketIdentity idCounter cmdVar logger,
getNotifications = bcGetNotifications (decodeUtf8 socketIdentity) idCounter notifSqnumRef cmdVar lastKnownNotification, getNotifications = bcGetNotifications socketIdentity idCounter notifSqnumRef cmdVar lastKnownNotification logger,
cmdVar = cmdVar, cmdVar = cmdVar,
lastKnownNotificationRef = notifSqnumRef, lastKnownNotificationRef = notifSqnumRef,
notificationCallback = [], notificationCallback = []
notificationThreadId = notifThreadId
} }
stopBrokerClient :: BrokerClientHandle -> IO () stopBrokerClient :: BrokerClientHandle -> IO ()
@ -256,45 +214,84 @@ stopBrokerClient handle = do
nextId cnt = atomicModifyIORef' cnt (\v -> (v + 1, v)) nextId cnt = atomicModifyIORef' cnt (\v -> (v + 1, v))
bcSubmitOrder :: ClientIdentity -> IORef Int64 -> MVar (BrokerServerRequest, MVar BrokerServerResponse) -> Order -> IO (Either T.Text ()) bcSubmitOrder ::
bcSubmitOrder clientIdentity idCounter cmdVar order = do ClientIdentity ->
IORef Int64 ->
MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
LogAction IO Message ->
Order ->
IO (Either T.Text ())
bcSubmitOrder clientIdentity idCounter cmdVar logger order = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
putMVar cmdVar (RequestSubmitOrder sqnum clientIdentity order, respVar) now <- getCurrentTime
result <- timeout 3000000 $ do
putMVar cmdVar (RequestSubmitOrder (RequestId sqnum) clientIdentity order, respVar, now)
resp <- takeMVar respVar resp <- takeMVar respVar
case resp of case resp of
ResponseOk -> return $ Right () ResponseOk (RequestId requestId) -> do
(ResponseError msg) -> return $ Left msg if requestId == sqnum
then return $ Right ()
else do
logWith logger Warning "Broker.Client" "SubmitOrder: requestId mismatch"
pure $ Left "requestid mismatch"
(ResponseError (RequestId _) msg) -> return $ Left msg
_ -> return $ Left "Unknown error" _ -> return $ Left "Unknown error"
case result of
Just r -> pure r
_ -> pure $ Left "Request timeout"
bcCancelOrder :: ClientIdentity -> IORef RequestSqnum -> MVar (BrokerServerRequest, MVar BrokerServerResponse) -> OrderId -> IO (Either T.Text ()) bcCancelOrder ::
bcCancelOrder clientIdentity idCounter cmdVar orderId = do ClientIdentity ->
IORef Int64 ->
MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
LogAction IO Message ->
OrderId ->
IO (Either T.Text ())
bcCancelOrder clientIdentity idCounter cmdVar logger orderId = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
putMVar cmdVar (RequestCancelOrder sqnum clientIdentity orderId, respVar) now <- getCurrentTime
result <- timeout 3000000 $ do
putMVar cmdVar (RequestCancelOrder (RequestId sqnum) clientIdentity orderId, respVar, now)
resp <- takeMVar respVar resp <- takeMVar respVar
case resp of case resp of
ResponseOk -> return $ Right () ResponseOk (RequestId requestId) -> do
(ResponseError msg) -> return $ Left msg if requestId == sqnum
then return $ Right ()
else do
logWith logger Warning "Broker.Client" "CancelOrder: requestId mismatch"
pure $ Left "requestid mismatch"
(ResponseError (RequestId _) msg) -> return $ Left msg
_ -> return $ Left "Unknown error" _ -> return $ Left "Unknown error"
case result of
Just r -> pure $ r
_ -> pure $ Left "Request timeout"
bcGetNotifications :: ClientIdentity -> bcGetNotifications :: ClientIdentity ->
IORef RequestSqnum -> IORef Int64 ->
IORef NotificationSqnum -> IORef NotificationSqnum ->
MVar (BrokerServerRequest, MVar BrokerServerResponse) -> MVar (BrokerServerRequest, MVar BrokerServerResponse, UTCTime) ->
IORef NotificationSqnum -> IORef NotificationSqnum ->
LogAction IO Message ->
IO (Either T.Text [Notification]) IO (Either T.Text [Notification])
bcGetNotifications clientIdentity idCounter notifSqnumRef cmdVar lastKnownNotification = do bcGetNotifications clientIdentity idCounter notifSqnumRef cmdVar lastKnownNotification logger = do
respVar <- newEmptyMVar respVar <- newEmptyMVar
sqnum <- nextId idCounter sqnum <- nextId idCounter
notifSqnum <- nextSqnum <$> readIORef notifSqnumRef notifSqnum <- nextSqnum <$> readIORef notifSqnumRef
putMVar cmdVar (RequestNotifications sqnum clientIdentity notifSqnum, respVar) now <- getCurrentTime
putMVar cmdVar (RequestNotifications (RequestId sqnum) clientIdentity notifSqnum, respVar, now)
resp <- takeMVar respVar resp <- takeMVar respVar
case resp of case resp of
(ResponseNotifications ns) -> do (ResponseNotifications (RequestId requestId) ns) ->
if (requestId == sqnum)
then do
case lastMay ns of case lastMay ns of
Just n -> atomicWriteIORef notifSqnumRef (getNotificationSqnum n) Just n -> atomicWriteIORef notifSqnumRef (getNotificationSqnum n)
Nothing -> readIORef lastKnownNotification >>= atomicWriteIORef notifSqnumRef Nothing -> readIORef lastKnownNotification >>= atomicWriteIORef notifSqnumRef
return $ Right ns return $ Right ns
(ResponseError msg) -> return $ Left msg else do
logWith logger Warning "Broker.Client" "GetNotifications: requestId mismatch"
return $ Left "requestId mismatch"
(ResponseError (RequestId requestId) msg) -> return $ Left msg
_ -> return $ Left "Unknown error" _ -> return $ Left "Unknown error"

99
src/ATrade/Broker/Protocol.hs

@ -11,11 +11,12 @@ module ATrade.Broker.Protocol (
nextSqnum, nextSqnum,
getNotificationSqnum, getNotificationSqnum,
notificationOrderId, notificationOrderId,
RequestSqnum(..),
requestSqnum,
TradeSinkMessage(..), TradeSinkMessage(..),
mkTradeMessage, mkTradeMessage,
ClientIdentity(..) ClientIdentity(..),
RequestId(..),
getRequestId,
getResponseRequestId
) where ) where
import ATrade.Types import ATrade.Types
@ -32,8 +33,10 @@ import Data.Time.Clock
import Language.Haskell.Printf import Language.Haskell.Printf
import Text.Parsec hiding ((<|>)) import Text.Parsec hiding ((<|>))
data RequestId = RequestId Int64
deriving (Eq, Show, Ord)
type ClientIdentity = T.Text type ClientIdentity = T.Text
type RequestSqnum = Int64
newtype NotificationSqnum = NotificationSqnum { unNotificationSqnum :: Int64 } newtype NotificationSqnum = NotificationSqnum { unNotificationSqnum :: Int64 }
deriving (Eq, Show, Ord) deriving (Eq, Show, Ord)
@ -77,83 +80,101 @@ instance ToJSON Notification where
toJSON (TradeNotification sqnum trade) = object [ "notification-sqnum" .= toJSON (unNotificationSqnum sqnum), "trade" .= toJSON trade ] toJSON (TradeNotification sqnum trade) = object [ "notification-sqnum" .= toJSON (unNotificationSqnum sqnum), "trade" .= toJSON trade ]
data BrokerServerRequest = RequestSubmitOrder RequestSqnum ClientIdentity Order data BrokerServerRequest = RequestSubmitOrder RequestId ClientIdentity Order
| RequestCancelOrder RequestSqnum ClientIdentity OrderId | RequestCancelOrder RequestId ClientIdentity OrderId
| RequestNotifications RequestSqnum ClientIdentity NotificationSqnum | RequestNotifications RequestId ClientIdentity NotificationSqnum
| RequestCurrentSqnum RequestSqnum ClientIdentity | RequestCurrentSqnum RequestId ClientIdentity
| RequestSetClientIdentity RequestId ClientIdentity
deriving (Eq, Show) deriving (Eq, Show)
requestSqnum :: BrokerServerRequest -> RequestSqnum getRequestId :: BrokerServerRequest -> RequestId
requestSqnum (RequestSubmitOrder sqnum _ _) = sqnum getRequestId (RequestSubmitOrder rid _ _) = rid
requestSqnum (RequestCancelOrder sqnum _ _) = sqnum getRequestId (RequestCancelOrder rid _ _) = rid
requestSqnum (RequestNotifications sqnum _ _) = sqnum getRequestId (RequestNotifications rid _ _) = rid
requestSqnum (RequestCurrentSqnum sqnum _) = sqnum getRequestId (RequestCurrentSqnum rid _) = rid
getRequestId (RequestSetClientIdentity rid _) = rid
instance FromJSON BrokerServerRequest where instance FromJSON BrokerServerRequest where
parseJSON = withObject "object" (\obj -> do parseJSON = withObject "object" (\obj -> do
sqnum <- obj .: "request-sqnum"
clientIdentity <- obj .: "client-identity" clientIdentity <- obj .: "client-identity"
parseRequest sqnum clientIdentity obj) requestId <- obj .: "request-id"
parseRequest (RequestId requestId) clientIdentity obj)
where where
parseRequest :: RequestSqnum -> ClientIdentity -> Object -> Parser BrokerServerRequest parseRequest :: RequestId -> ClientIdentity -> Object -> Parser BrokerServerRequest
parseRequest sqnum clientIdentity obj parseRequest requestId clientIdentity obj
| KM.member "order" obj = do | KM.member "order" obj = do
order <- obj .: "order" order <- obj .: "order"
RequestSubmitOrder sqnum clientIdentity <$> parseJSON order RequestSubmitOrder requestId clientIdentity <$> parseJSON order
| KM.member "cancel-order" obj = do | KM.member "cancel-order" obj = do
orderId <- obj .: "cancel-order" orderId <- obj .: "cancel-order"
RequestCancelOrder sqnum clientIdentity <$> parseJSON orderId RequestCancelOrder requestId clientIdentity <$> parseJSON orderId
| KM.member "request-notifications" obj = do | KM.member "request-notifications" obj = do
initialSqnum <- obj .: "initial-sqnum" initialSqnum <- obj .: "initial-sqnum"
return (RequestNotifications sqnum clientIdentity (NotificationSqnum initialSqnum)) return (RequestNotifications requestId clientIdentity (NotificationSqnum initialSqnum))
| KM.member "request-current-sqnum" obj = | KM.member "request-current-sqnum" obj =
return (RequestCurrentSqnum sqnum clientIdentity) return (RequestCurrentSqnum requestId clientIdentity)
| KM.member "set-client-identity" obj =
return (RequestSetClientIdentity requestId clientIdentity)
parseRequest _ _ _ = fail "Invalid request object" parseRequest _ _ _ = fail "Invalid request object"
instance ToJSON BrokerServerRequest where instance ToJSON BrokerServerRequest where
toJSON (RequestSubmitOrder sqnum clientIdentity order) = object ["request-sqnum" .= sqnum, toJSON (RequestSubmitOrder (RequestId rid) clientIdentity order) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"order" .= order ] "order" .= order ]
toJSON (RequestCancelOrder sqnum clientIdentity oid) = object ["request-sqnum" .= sqnum, toJSON (RequestCancelOrder (RequestId rid) clientIdentity oid) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"cancel-order" .= oid ] "cancel-order" .= oid ]
toJSON (RequestNotifications sqnum clientIdentity initialNotificationSqnum) = object ["request-sqnum" .= sqnum, toJSON (RequestNotifications (RequestId rid) clientIdentity initialNotificationSqnum) = object [
"request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"request-notifications" .= ("" :: T.Text), "request-notifications" .= ("" :: T.Text),
"initial-sqnum" .= unNotificationSqnum initialNotificationSqnum] "initial-sqnum" .= unNotificationSqnum initialNotificationSqnum]
toJSON (RequestCurrentSqnum sqnum clientIdentity) = object toJSON (RequestCurrentSqnum (RequestId rid) clientIdentity) = object
["request-sqnum" .= sqnum, ["request-id" .= rid,
"client-identity" .= clientIdentity, "client-identity" .= clientIdentity,
"request-current-sqnum" .= ("" :: T.Text) ] "request-current-sqnum" .= ("" :: T.Text) ]
toJSON (RequestSetClientIdentity (RequestId rid) clientIdentity) = object
["request-id" .= rid,
"client-identity" .= clientIdentity,
"set-client-identity" .= ("" :: T.Text) ]
data BrokerServerResponse = ResponseOk getResponseRequestId :: BrokerServerResponse -> RequestId
| ResponseNotifications [Notification] getResponseRequestId (ResponseOk reqId) = reqId
| ResponseCurrentSqnum NotificationSqnum getResponseRequestId (ResponseNotifications reqId _) = reqId
| ResponseError T.Text getResponseRequestId (ResponseCurrentSqnum reqId _) = reqId
getResponseRequestId (ResponseError reqId _) = reqId
data BrokerServerResponse = ResponseOk RequestId
| ResponseNotifications RequestId [Notification]
| ResponseCurrentSqnum RequestId NotificationSqnum
| ResponseError RequestId T.Text
deriving (Eq, Show) deriving (Eq, Show)
instance FromJSON BrokerServerResponse where instance FromJSON BrokerServerResponse where
parseJSON = withObject "object" (\obj -> parseJSON = withObject "object" (\obj -> do
requestId <- obj .: "request-id"
if | KM.member "result" obj -> do if | KM.member "result" obj -> do
result <- obj .: "result" result <- obj .: "result"
if (result :: T.Text) == "success" if (result :: T.Text) == "success"
then return ResponseOk then return $ ResponseOk (RequestId requestId)
else do else do
msg <- obj .:? "message" .!= "" msg <- obj .:? "message" .!= ""
return (ResponseError msg) return $ (ResponseError (RequestId requestId) msg)
| KM.member "notifications" obj -> do | KM.member "notifications" obj -> do
notifications <- obj .: "notifications" notifications <- obj .: "notifications"
ResponseNotifications <$> parseJSON notifications ResponseNotifications (RequestId requestId) <$> parseJSON notifications
| KM.member "current-sqnum" obj -> do | KM.member "current-sqnum" obj -> do
rawSqnum <- obj .: "current-sqnum" rawSqnum <- obj .: "current-sqnum"
return $ ResponseCurrentSqnum (NotificationSqnum rawSqnum) return $ ResponseCurrentSqnum (RequestId requestId) (NotificationSqnum rawSqnum)
| otherwise -> fail "Unable to parse BrokerServerResponse") | otherwise -> fail "Unable to parse BrokerServerResponse")
instance ToJSON BrokerServerResponse where instance ToJSON BrokerServerResponse where
toJSON ResponseOk = object [ "result" .= ("success" :: T.Text) ] toJSON (ResponseOk (RequestId rid)) = object [ "request-id" .= rid, "result" .= ("success" :: T.Text) ]
toJSON (ResponseNotifications notifications) = object [ "notifications" .= notifications ] toJSON (ResponseNotifications (RequestId rid) notifications) = object [ "request-id" .= rid, "notifications" .= notifications ]
toJSON (ResponseCurrentSqnum sqnum) = object [ "current-sqnum" .= unNotificationSqnum sqnum ] toJSON (ResponseCurrentSqnum (RequestId rid) sqnum) = object [ "request-id" .= rid, "current-sqnum" .= unNotificationSqnum sqnum ]
toJSON (ResponseError errorMessage) = object [ "result" .= ("error" :: T.Text), "message" .= errorMessage ] toJSON (ResponseError (RequestId rid) errorMessage) = object [ "request-id" .= rid, "result" .= ("error" :: T.Text), "message" .= errorMessage ]
data TradeSinkMessage = TradeSinkHeartBeat | TradeSinkTrade { data TradeSinkMessage = TradeSinkHeartBeat | TradeSinkTrade {
tsAccountId :: T.Text, tsAccountId :: T.Text,

277
src/ATrade/Broker/Server.hs

@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
module ATrade.Broker.Server ( module ATrade.Broker.Server (
startBrokerServer, startBrokerServer,
@ -15,18 +16,19 @@ import ATrade.Broker.Protocol (BrokerServerRequest (..),
ClientIdentity, ClientIdentity,
Notification (..), Notification (..),
NotificationSqnum (NotificationSqnum), NotificationSqnum (NotificationSqnum),
RequestSqnum, RequestId (..),
getNotificationSqnum, getNotificationSqnum,
nextSqnum, requestSqnum) getRequestId, nextSqnum)
import ATrade.Logging (Message (Message), import ATrade.Logging (Message (Message),
Severity (Debug, Warning), Severity (Debug, Warning),
logWith) Severity (Info), logWith)
import ATrade.Logging (Severity (Info))
import ATrade.Types (Order (orderAccountId, orderId), import ATrade.Types (Order (orderAccountId, orderId),
OrderId, OrderId,
ServerSecurityParams (sspCertificate, sspDomain), ServerSecurityParams (sspCertificate, sspDomain),
Trade (tradeOrderId)) Trade (tradeOrderId))
import ATrade.Util (atomicMapIORef) import ATrade.Util (atomicMapIORef)
import ATrade.Utils.MessagePipe (emptyMessagePipe, getMessages,
push)
import Colog (LogAction) import Colog (LogAction)
import Control.Concurrent (MVar, ThreadId, forkIO, import Control.Concurrent (MVar, ThreadId, forkIO,
killThread, myThreadId, killThread, myThreadId,
@ -34,49 +36,54 @@ import Control.Concurrent (MVar, ThreadId, forkIO,
readMVar, threadDelay, readMVar, threadDelay,
tryReadMVar, yield) tryReadMVar, yield)
import Control.Concurrent.BoundedChan (BoundedChan, newBoundedChan, import Control.Concurrent.BoundedChan (BoundedChan, newBoundedChan,
tryReadChan, tryWriteChan) readChan, tryReadChan,
import Control.Exception (finally) tryWriteChan)
import Control.Monad (unless) import Control.Exception (bracket, finally)
import Control.Monad (unless, void, when)
import Control.Monad.Extra (forever)
import Control.Monad.Loops (whileM_) import Control.Monad.Loops (whileM_)
import Data.Aeson (eitherDecode, encode) import Data.Aeson (eitherDecode, encode)
import qualified Data.Bimap as BM import qualified Data.Bimap as BM
import qualified Data.ByteString as B hiding (putStrLn) import qualified Data.ByteString as B hiding (putStrLn)
import qualified Data.ByteString.Lazy as BL hiding (putStrLn) import qualified Data.ByteString.Lazy as BL hiding (putStrLn)
import Data.IORef (IORef, atomicModifyIORef', import Data.IORef (IORef, atomicModifyIORef',
newIORef, readIORef) newIORef, readIORef,
writeIORef)
import qualified Data.List as L import qualified Data.List as L
import Data.List.NonEmpty (NonEmpty ((:|))) import Data.List.NonEmpty (NonEmpty ((:|)))
import qualified Data.Map as M import qualified Data.Map as M
import Data.Maybe (isJust, isNothing) import Data.Maybe (isJust, isNothing)
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as E import qualified Data.Text.Encoding as E
import qualified Data.Text.Lazy as TL
import Data.Time.Clock () import Data.Time.Clock ()
import Safe (lastMay) import Language.Haskell.Printf
import Network.Socket (Family (AF_INET),
SockAddr (SockAddrInet),
Socket, SocketType (Stream),
accept, bind, defaultProtocol,
listen, socket)
import Network.Socket.ByteString (recv, sendAll)
import Safe (lastMay, readMay)
import System.Timeout () import System.Timeout ()
import System.ZMQ4 (Context, Event (In), import System.ZMQ4 hiding (Socket, Stream, bind,
Poll (Sock), Pub (..), socket)
Router (..), Socket,
Switch (On), bind, close, poll,
receiveMulti, restrict,
sendMulti, setCurveServer,
setLinger, setTcpKeepAlive,
setTcpKeepAliveCount,
setTcpKeepAliveIdle,
setTcpKeepAliveInterval,
setZapDomain, socket)
import System.ZMQ4.ZAP (zapApplyCertificate)
type PeerId = B.ByteString type PeerId = B.ByteString
data FullOrderId = FullOrderId ClientIdentity OrderId data FullOrderId = FullOrderId ClientIdentity OrderId
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
data ClientState = ClientState {
cThreadId :: ThreadId,
cSocket :: Socket,
cClientIdentity :: ClientIdentity,
cEgressQueue :: BoundedChan B.ByteString
}
data BrokerServerState = BrokerServerState { data BrokerServerState = BrokerServerState {
bsSocket :: Socket Router,
bsNotificationsSocket :: Socket Pub,
orderToBroker :: M.Map FullOrderId BrokerBackend, orderToBroker :: M.Map FullOrderId BrokerBackend,
orderMap :: BM.Bimap FullOrderId OrderId, orderMap :: BM.Bimap FullOrderId OrderId,
lastPacket :: M.Map PeerId (RequestSqnum, BrokerServerResponse),
pendingNotifications :: M.Map ClientIdentity [Notification], pendingNotifications :: M.Map ClientIdentity [Notification],
notificationSqnum :: M.Map ClientIdentity NotificationSqnum, notificationSqnum :: M.Map ClientIdentity NotificationSqnum,
brokers :: [BrokerBackend], brokers :: [BrokerBackend],
@ -85,56 +92,33 @@ data BrokerServerState = BrokerServerState {
orderIdCounter :: OrderId, orderIdCounter :: OrderId,
tradeSink :: BoundedChan Trade, tradeSink :: BoundedChan Trade,
initialSqnum :: NotificationSqnum initialSqnum :: NotificationSqnum
} }
data BrokerServerHandle = BrokerServerHandle ThreadId ThreadId (MVar ()) (MVar ()) data BrokerServerHandle = BrokerServerHandle
{
bhServerTid :: ThreadId
, bhClients :: IORef (M.Map ClientIdentity ClientState)
, bhKillMVar :: MVar ()
, bhCompletionMVar :: MVar ()
}
type TradeSink = Trade -> IO () type TradeSink = Trade -> IO ()
startBrokerServer :: [BrokerBackend] -> startBrokerServer :: [BrokerBackend] ->
Context -> Context ->
T.Text -> T.Text ->
T.Text ->
NotificationSqnum -> NotificationSqnum ->
[TradeSink] -> [TradeSink] ->
ServerSecurityParams ->
LogAction IO Message -> LogAction IO Message ->
IO BrokerServerHandle IO BrokerServerHandle
startBrokerServer brokers c ep notificationsEp initialSqnum tradeSinks params logger = do startBrokerServer brokers c ep initialSqnum tradeSinks logger = do
sock <- socket c Router
notificationsSock <- socket c Pub
setLinger (restrict 0) sock
setLinger (restrict 0) notificationsSock
case sspDomain params of
Just domain -> do
setZapDomain (restrict $ E.encodeUtf8 domain) sock
setZapDomain (restrict $ E.encodeUtf8 domain) notificationsSock
Nothing -> return ()
case sspCertificate params of
Just cert -> do
setCurveServer True sock
zapApplyCertificate cert sock
setCurveServer True notificationsSock
zapApplyCertificate cert notificationsSock
Nothing -> return ()
bind sock (T.unpack ep)
setTcpKeepAlive On notificationsSock
setTcpKeepAliveCount (restrict 5) notificationsSock
setTcpKeepAliveIdle (restrict 60) notificationsSock
setTcpKeepAliveInterval (restrict 10) notificationsSock
bind notificationsSock (T.unpack notificationsEp)
tid <- myThreadId
compMv <- newEmptyMVar compMv <- newEmptyMVar
killMv <- newEmptyMVar killMv <- newEmptyMVar
tsChan <- newBoundedChan 100 tsChan <- newBoundedChan 100
clientsMapRef <- newIORef M.empty
state <- newIORef BrokerServerState { state <- newIORef BrokerServerState {
bsSocket = sock,
bsNotificationsSocket = notificationsSock,
orderMap = BM.empty, orderMap = BM.empty,
orderToBroker = M.empty, orderToBroker = M.empty,
lastPacket = M.empty,
pendingNotifications = M.empty, pendingNotifications = M.empty,
notificationSqnum = M.empty, notificationSqnum = M.empty,
brokers = brokers, brokers = brokers,
@ -144,18 +128,45 @@ startBrokerServer brokers c ep notificationsEp initialSqnum tradeSinks params lo
tradeSink = tsChan, tradeSink = tsChan,
initialSqnum = initialSqnum initialSqnum = initialSqnum
} }
mapM_ (\bro -> setNotificationCallback bro (Just $ notificationCallback state logger)) brokers
log Info "Broker.Server" "Forking broker server thread" let Just (_, port) = parseHostAndPort ep
BrokerServerHandle <$> forkIO (brokerServerThread state logger) <*> forkIO (tradeSinkHandler c state tradeSinks) <*> pure compMv <*> pure killMv serverSocket <- socket AF_INET Stream defaultProtocol
bind serverSocket $ SockAddrInet (fromIntegral port) 0
log Info "Broker.Server" $ TL.toStrict $ [t|Listening on port %?|] $ fromIntegral port
listen serverSocket 1024
serverTid <- forkIO $ forever $ do
(client, addr) <- accept serverSocket
log Debug "Broker.Server" "Incoming connection"
rawRequest <- recv client 4096
case eitherDecode $ BL.fromStrict $ B.init rawRequest of
Left err -> log Warning "Broker.Server" $ "Unable to decode client id: " <> (T.pack . show) rawRequest
Right (RequestSetClientIdentity requestId clientIdentity) -> do
log Info "Broker.Server" $ "Connected socket identity: " <> (T.pack . show) clientIdentity
egressQueue <- newBoundedChan 100
clientTid <- forkIO $ clientThread client egressQueue clientsMapRef state logger
let clientState = ClientState clientTid client clientIdentity egressQueue
atomicModifyIORef' clientsMapRef (\m -> (M.insert clientIdentity clientState m, ()))
_ -> log Warning "Broker.Server" $ "Invalid first message: " <> (T.pack . show) rawRequest
mapM_ (\bro -> setNotificationCallback bro (Just $ notificationCallback state clientsMapRef logger)) brokers
pure $ BrokerServerHandle serverTid clientsMapRef killMv compMv
where where
log = logWith logger log = logWith logger
parseHostAndPort :: T.Text -> Maybe (T.Text, Int)
parseHostAndPort str = case T.splitOn ":" str of
[host, port] ->
case readMay $ T.unpack port of
Just numPort -> Just (host, numPort)
_ -> Nothing
_ -> Nothing
notificationCallback :: IORef BrokerServerState -> notificationCallback :: IORef BrokerServerState ->
IORef (M.Map ClientIdentity ClientState) ->
LogAction IO Message -> LogAction IO Message ->
BrokerBackendNotification -> BrokerBackendNotification ->
IO () IO ()
notificationCallback state logger n = do notificationCallback state clientsMapRef logger n = do
log Debug "Broker.Server" $ "Notification: " <> (T.pack . show) n log Debug "Broker.Server" $ "Notification: " <> (T.pack . show) n
chan <- tradeSink <$> readIORef state chan <- tradeSink <$> readIORef state
case n of case n of
@ -180,8 +191,10 @@ notificationCallback state logger n = do
case M.lookup clientIdentity . pendingNotifications $ s of case M.lookup clientIdentity . pendingNotifications $ s of
Just ns -> s { pendingNotifications = M.insert clientIdentity (n : ns) (pendingNotifications s)} Just ns -> s { pendingNotifications = M.insert clientIdentity (n : ns) (pendingNotifications s)}
Nothing -> s { pendingNotifications = M.insert clientIdentity [n] (pendingNotifications s)}) Nothing -> s { pendingNotifications = M.insert clientIdentity [n] (pendingNotifications s)})
sock <- bsNotificationsSocket <$> readIORef state clients <- readIORef clientsMapRef
sendMulti sock (E.encodeUtf8 clientIdentity :| [BL.toStrict $ encode n]) case M.lookup clientIdentity clients of
Just client -> void $ tryWriteChan (cEgressQueue client) $ BL.toStrict $ encode n
Nothing -> log Warning "Broker.Server" $ TL.toStrict $ [t|Unable to send notification to %?|] clientIdentity
tradeSinkHandler :: Context -> IORef BrokerServerState -> [TradeSink] -> IO () tradeSinkHandler :: Context -> IORef BrokerServerState -> [TradeSink] -> IO ()
tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $ tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $
@ -195,118 +208,108 @@ tradeSinkHandler c state tradeSinks = unless (null tradeSinks) $
wasKilled = isJust <$> (readIORef state >>= tryReadMVar . killMvar) wasKilled = isJust <$> (readIORef state >>= tryReadMVar . killMvar)
brokerServerThread :: IORef BrokerServerState -> clientThread :: Socket ->
BoundedChan B.ByteString ->
IORef (M.Map ClientIdentity ClientState) ->
IORef BrokerServerState ->
LogAction IO Message -> LogAction IO Message ->
IO () IO ()
brokerServerThread state logger = finally brokerServerThread' cleanup clientThread socket egressQueue clients serverState logger =
bracket
(forkIO sendingThread)
(\tid -> do
log Debug "Broker.Server" "Killing sending thread"
killThread tid)
brokerServerThread'
where where
log = logWith logger log = logWith logger
brokerServerThread' = whileM_ (fmap killMvar (readIORef state) >>= fmap isNothing . tryReadMVar) $ do brokerServerThread' _ = do
sock <- bsSocket <$> readIORef state pipeRef <- newIORef emptyMessagePipe
events <- poll 100 [Sock sock [In] Nothing] brokerServerThread'' pipeRef
unless (null . L.head $ events) $ do log Info "Broker.Server" "Client disconnected"
msg <- receiveMulti sock
case msg of brokerServerThread'' pipeRef = do
[peerId, _, payload] -> do rawData <- recv socket 4096
case eitherDecode . BL.fromStrict $ payload of when (B.length rawData > 0) $ do
Right request -> do pipe <- readIORef pipeRef
let sqnum = requestSqnum request let (pipe', chunks) = getMessages (push rawData pipe)
-- Here, we should check if previous packet sequence number is the same writeIORef pipeRef pipe'
-- If it is, we should resend previous response mapM_ (handleChunk egressQueue) chunks
lastPackMap <- lastPacket <$> readIORef state brokerServerThread'' pipeRef
case shouldResend sqnum peerId lastPackMap of
Just response -> do
log Debug "Broker.Server" $ "Resending packet for peerId: " <> (T.pack . show) peerId
sendMessage sock peerId response -- Resend
atomicMapIORef state (\s -> s { lastPacket = M.delete peerId (lastPacket s)})
Nothing -> do
-- Handle incoming request, send response
response <- handleMessage peerId request
sendMessage sock peerId response
-- and store response in case we'll need to resend it
atomicMapIORef state (\s -> s { lastPacket = M.insert peerId (sqnum, response) (lastPacket s)})
Left errmsg -> do
-- If we weren't able to parse request, we should send error
-- but shouldn't update lastPacket
let response = ResponseError $ "Invalid request: " <> T.pack errmsg
sendMessage sock peerId response
_ -> log Warning "Broker.Server" ("Invalid packet received: " <> (T.pack . show) msg)
shouldResend sqnum peerId lastPackMap = case M.lookup peerId lastPackMap of sendingThread = forever $ do
Just (lastSqnum, response) -> if sqnum == lastSqnum packet <- readChan egressQueue
then Just response log Debug "Broker.Server" $ TL.toStrict $ [t|Sending packet: %?|] packet
else Nothing sendAll socket $ B.snoc packet 0
Nothing -> Nothing
cleanup = do enqueueEgressPacket = tryWriteChan egressQueue
sock <- bsSocket <$> readIORef state
close sock
mv <- completionMvar <$> readIORef state handleChunk egressQueue payload = do
putMVar mv () response <- case eitherDecode . BL.fromStrict $ payload of
Right request -> handleMessage request
Left errmsg -> pure $ ResponseError (RequestId 0) $ "Invalid request: " <> T.pack errmsg
enqueueEgressPacket $ BL.toStrict $ encode response
handleMessage :: PeerId -> BrokerServerRequest -> IO BrokerServerResponse handleMessage :: BrokerServerRequest -> IO BrokerServerResponse
handleMessage peerId request = do handleMessage request = do
bros <- brokers <$> readIORef state log Debug "Broker.Server" "Handle message"
bros <- brokers <$> readIORef serverState
case request of case request of
RequestSubmitOrder sqnum clientIdentity order -> do RequestSubmitOrder requestId clientIdentity order -> do
log Debug "Broker.Server" $ "Request: submit order:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: submit order:" <> (T.pack . show) request
case findBrokerForAccount (orderAccountId order) bros of case findBrokerForAccount (orderAccountId order) bros of
Just bro -> do Just bro -> do
globalOrderId <- nextOrderId globalOrderId <- nextOrderId
let fullOrderId = FullOrderId clientIdentity (orderId order) let fullOrderId = FullOrderId clientIdentity (orderId order)
atomicMapIORef state (\s -> s { atomicMapIORef serverState (\s -> s {
orderToBroker = M.insert fullOrderId bro (orderToBroker s), orderToBroker = M.insert fullOrderId bro (orderToBroker s),
orderMap = BM.insert fullOrderId globalOrderId (orderMap s) }) orderMap = BM.insert fullOrderId globalOrderId (orderMap s) })
submitOrder bro order { orderId = globalOrderId } submitOrder bro order { orderId = globalOrderId }
return ResponseOk return $ ResponseOk requestId
Nothing -> do Nothing -> do
log Warning "Broker.Server" $ "Unknown account: " <> (orderAccountId order) log Warning "Broker.Server" $ "Unknown account: " <> orderAccountId order
return $ ResponseError "Unknown account" return $ ResponseError requestId "Unknown account"
RequestCancelOrder sqnum clientIdentity localOrderId -> do RequestCancelOrder requestId clientIdentity localOrderId -> do
log Debug "Broker.Server" $ "Request: cancel order:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: cancel order:" <> (T.pack . show) request
m <- orderToBroker <$> readIORef state m <- orderToBroker <$> readIORef serverState
bm <- orderMap <$> readIORef state bm <- orderMap <$> readIORef serverState
let fullOrderId = FullOrderId clientIdentity localOrderId let fullOrderId = FullOrderId clientIdentity localOrderId
case (M.lookup fullOrderId m, BM.lookup fullOrderId bm) of case (M.lookup fullOrderId m, BM.lookup fullOrderId bm) of
(Just bro, Just globalOrderId) -> do (Just bro, Just globalOrderId) -> do
cancelOrder bro globalOrderId cancelOrder bro globalOrderId
return ResponseOk return $ ResponseOk requestId
_ -> return $ ResponseError "Unknown order" _ -> return $ ResponseError requestId "Unknown order"
RequestNotifications sqnum clientIdentity initialSqnum -> do RequestNotifications requestId clientIdentity initialSqnum -> do
log Debug "Broker.Server" $ "Request: notifications:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: notifications:" <> (T.pack . show) request
maybeNs <- M.lookup clientIdentity . pendingNotifications <$> readIORef state maybeNs <- M.lookup clientIdentity . pendingNotifications <$> readIORef serverState
case maybeNs of case maybeNs of
Just ns -> do Just ns -> do
let filtered = L.filter (\n -> getNotificationSqnum n >= initialSqnum) ns let filtered = L.filter (\n -> getNotificationSqnum n >= initialSqnum) ns
atomicMapIORef state (\s -> s { pendingNotifications = M.insert clientIdentity filtered (pendingNotifications s)}) atomicMapIORef serverState (\s -> s { pendingNotifications = M.insert clientIdentity filtered (pendingNotifications s)})
return $ ResponseNotifications . L.reverse $ filtered return $ ResponseNotifications requestId . L.reverse $ filtered
Nothing -> return $ ResponseNotifications [] Nothing -> return $ ResponseNotifications requestId []
RequestCurrentSqnum sqnum clientIdentity -> do RequestCurrentSqnum requestId clientIdentity -> do
log Debug "Broker.Server" $ "Request: current sqnum:" <> (T.pack . show) request log Debug "Broker.Server" $ "Request: current sqnum:" <> (T.pack . show) request
sqnumMap <- notificationSqnum <$> readIORef state sqnumMap <- notificationSqnum <$> readIORef serverState
notifMap <- pendingNotifications <$> readIORef state notifMap <- pendingNotifications <$> readIORef serverState
case M.lookup clientIdentity notifMap of case M.lookup clientIdentity notifMap of
Just [] -> Just [] ->
case M.lookup clientIdentity sqnumMap of case M.lookup clientIdentity sqnumMap of
Just sqnum -> return (ResponseCurrentSqnum sqnum) Just sqnum -> return (ResponseCurrentSqnum requestId sqnum)
_ -> return (ResponseCurrentSqnum (NotificationSqnum 1)) _ -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
Just notifs -> case lastMay notifs of Just notifs -> case lastMay notifs of
Just v -> return (ResponseCurrentSqnum (getNotificationSqnum v)) Just v -> return (ResponseCurrentSqnum requestId (getNotificationSqnum v))
_ -> return (ResponseCurrentSqnum (NotificationSqnum 1)) _ -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
Nothing -> return (ResponseCurrentSqnum (NotificationSqnum 1)) Nothing -> return (ResponseCurrentSqnum requestId (NotificationSqnum 1))
RequestSetClientIdentity requestId _ -> pure $ ResponseError requestId "Client identity change is not supported"
sendMessage sock peerId resp = sendMulti sock (peerId :| [B.empty, BL.toStrict . encode $ resp])
findBrokerForAccount account = L.find (L.elem account . accounts) findBrokerForAccount account = L.find (L.elem account . accounts)
nextOrderId = atomicModifyIORef' state (\s -> ( s {orderIdCounter = 1 + orderIdCounter s}, orderIdCounter s)) nextOrderId = atomicModifyIORef' serverState (\s -> ( s {orderIdCounter = 1 + orderIdCounter s}, orderIdCounter s))
stopBrokerServer :: BrokerServerHandle -> IO () stopBrokerServer :: BrokerServerHandle -> IO ()
stopBrokerServer (BrokerServerHandle tid tstid compMv killMv) = do stopBrokerServer (BrokerServerHandle tid clients compMv killMv) = do
putMVar killMv () putMVar killMv ()
killThread tstid readIORef clients >>= mapM_ (killThread . cThreadId) . M.elems
yield yield
readMVar compMv readMVar compMv

Loading…
Cancel
Save