diff --git a/src/ATrade/Broker/Client.hs b/src/ATrade/Broker/Client.hs index 053e93f..6854c2d 100644 --- a/src/ATrade/Broker/Client.hs +++ b/src/ATrade/Broker/Client.hs @@ -5,7 +5,8 @@ module ATrade.Broker.Client ( stopBrokerClient, submitOrder, cancelOrder, - getNotifications + getNotifications, + NotificationCallback(..) ) where import ATrade.Broker.Protocol @@ -26,11 +27,14 @@ import Data.List.NonEmpty import Data.Maybe import qualified Data.Text as T import Data.Text.Encoding +import qualified Data.Text.Encoding as T import System.Log.Logger import System.Timeout import System.ZMQ4 import System.ZMQ4.ZAP +type NotificationCallback = Notification -> IO () + data BrokerClientHandle = BrokerClientHandle { tid :: ThreadId, completionMvar :: MVar (), @@ -39,7 +43,9 @@ data BrokerClientHandle = BrokerClientHandle { cancelOrder :: OrderId -> IO (Either T.Text ()), getNotifications :: IO (Either T.Text [Notification]), cmdVar :: MVar (BrokerServerRequest, MVar BrokerServerResponse), - lastKnownNotificationRef :: IORef NotificationSqnum + lastKnownNotificationRef :: IORef NotificationSqnum, + notificationCallback :: [NotificationCallback], + notificationThreadId :: ThreadId } brokerClientThread :: B.ByteString -> Context -> T.Text -> MVar (BrokerServerRequest, MVar BrokerServerResponse) -> MVar () -> MVar () -> ClientSecurityParams -> IO () @@ -83,14 +89,35 @@ brokerClientThread socketIdentity ctx ep cmd comp killMv secParams = finally bro threadDelay 1000000) isZMQError e = "ZMQError" `L.isPrefixOf` show e -startBrokerClient :: B.ByteString -> Context -> T.Text -> ClientSecurityParams -> IO BrokerClientHandle -startBrokerClient socketIdentity ctx endpoint secParams = do + +notificationThread :: ClientIdentity -> [NotificationCallback] -> Context -> T.Text -> MVar () -> IO () +notificationThread clientIdentity callbacks ctx ep killMv = flip finally (return ()) $ do + whileM_ (isNothing <$> tryReadMVar killMv) $ + withSocket ctx Sub $ \sock -> do + setTcpKeepAlive On sock + setTcpKeepAliveCount (restrict 5) sock + setTcpKeepAliveIdle (restrict 60) sock + setTcpKeepAliveInterval (restrict 10) sock + connect sock $ T.unpack ep + subscribe sock $ T.encodeUtf8 clientIdentity + whileM_ (isNothing <$> tryReadMVar killMv) $ do + msg <- receiveMulti sock + case msg of + [_, payload] -> case decode (BL.fromStrict payload) of + Just notification -> forM_ callbacks $ \c -> c notification + _ -> return () + _ -> return () + + +startBrokerClient :: B.ByteString -> Context -> T.Text -> T.Text -> [NotificationCallback] -> ClientSecurityParams -> IO BrokerClientHandle +startBrokerClient socketIdentity ctx endpoint notifEndpoint notificationCallbacks secParams = do idCounter <- newIORef 1 compMv <- newEmptyMVar killMv <- newEmptyMVar cmdVar <- newEmptyMVar :: IO (MVar (BrokerServerRequest, MVar BrokerServerResponse)) tid <- forkIO (brokerClientThread socketIdentity ctx endpoint cmdVar compMv killMv secParams) notifSqnumRef <- newIORef (NotificationSqnum 0) + notifThreadId <- forkIO (notificationThread (T.decodeUtf8 socketIdentity) notificationCallbacks ctx notifEndpoint killMv) return BrokerClientHandle { tid = tid, @@ -100,11 +127,19 @@ startBrokerClient socketIdentity ctx endpoint secParams = do cancelOrder = bcCancelOrder (decodeUtf8 socketIdentity) idCounter cmdVar, getNotifications = bcGetNotifications (decodeUtf8 socketIdentity) idCounter notifSqnumRef cmdVar, cmdVar = cmdVar, - lastKnownNotificationRef = notifSqnumRef + lastKnownNotificationRef = notifSqnumRef, + notificationCallback = [], + notificationThreadId = notifThreadId } stopBrokerClient :: BrokerClientHandle -> IO () -stopBrokerClient handle = putMVar (killMvar handle) () >> yield >> killThread (tid handle) >> readMVar (completionMvar handle) +stopBrokerClient handle = do + putMVar (killMvar handle) () + yield + killThread (tid handle) + killThread (notificationThreadId handle) + yield + readMVar (completionMvar handle) nextId cnt = atomicModifyIORef' cnt (\v -> (v + 1, v)) diff --git a/test/MockBroker.hs b/test/MockBroker.hs index 88aa8f3..7fe2cd2 100644 --- a/test/MockBroker.hs +++ b/test/MockBroker.hs @@ -39,7 +39,12 @@ mockCancelOrder :: IORef MockBrokerState -> OrderId -> IO () mockCancelOrder state oid = do ors <- orders <$> readIORef state case L.find (\o -> orderId o == oid) ors of - Just order -> atomicModifyIORef' state (\s -> (s { cancelledOrders = order : cancelledOrders s}, ())) + Just order -> do + atomicModifyIORef' state (\s -> (s { cancelledOrders = order : cancelledOrders s}, ())) + maybeCb <- notificationCallback <$> readIORef state + case maybeCb of + Just cb -> cb $ BackendOrderNotification (orderId order) Cancelled + Nothing -> return () Nothing -> return () mockStopBroker :: IORef MockBrokerState -> IO () diff --git a/test/TestBrokerClient.hs b/test/TestBrokerClient.hs index 52bfc79..484c655 100644 --- a/test/TestBrokerClient.hs +++ b/test/TestBrokerClient.hs @@ -55,21 +55,36 @@ defaultOrder = mkOrder { orderOperation = Buy } +makeNotificationCallback :: IO (IORef [Notification], NotificationCallback) +makeNotificationCallback = do + ref <- newIORef [] + return (ref, \n -> atomicModifyIORef' ref (\s -> (n : s, ()))) + testBrokerClientStartStop = testCase "Broker client: submit order" $ withContext (\ctx -> do (ep, notifEp) <- makeEndpoints + (ref, callback) <- makeNotificationCallback (mockBroker, broState) <- mkMockBroker ["demo"] bracket (startBrokerServer [mockBroker] ctx ep notifEp [] defaultServerSecurityParams) stopBrokerServer (\broS -> - bracket (startBrokerClient "foo" ctx ep defaultClientSecurityParams) stopBrokerClient (\broC -> do - oid <- submitOrder broC defaultOrder - case oid of + bracket (startBrokerClient "foo" ctx ep notifEp [callback] defaultClientSecurityParams) stopBrokerClient (\broC -> do + result <- submitOrder broC defaultOrder + case result of Left err -> assertFailure "Invalid response" - Right _ -> return ()))) + Right _ -> do + threadDelay 10000 -- Wait for callback + notifs <- readIORef ref + case head notifs of + OrderNotification _ oid newState -> do + newState @=? Submitted + oid @=? orderId defaultOrder + _ -> assertFailure "Invalid notification" + ))) testBrokerClientCancelOrder = testCase "Broker client: submit and cancel order" $ withContext (\ctx -> do (ep, notifEp) <- makeEndpoints + (ref, callback) <- makeNotificationCallback (mockBroker, broState) <- mkMockBroker ["demo"] bracket (startBrokerServer [mockBroker] ctx ep notifEp [] defaultServerSecurityParams) stopBrokerServer (\broS -> - bracket (startBrokerClient "foo" ctx ep defaultClientSecurityParams) stopBrokerClient (\broC -> do + bracket (startBrokerClient "foo" ctx ep notifEp [callback] defaultClientSecurityParams) stopBrokerClient (\broC -> do maybeOid <- submitOrder broC defaultOrder case maybeOid of Left err -> assertFailure "Invalid response" @@ -82,9 +97,10 @@ testBrokerClientCancelOrder = testCase "Broker client: submit and cancel order" testBrokerClientGetNotifications = testCase "Broker client: get notifications" $ withContext (\ctx -> do (ep, notifEp) <- makeEndpoints + (ref, callback) <- makeNotificationCallback (mockBroker, broState) <- mkMockBroker ["demo"] bracket (startBrokerServer [mockBroker] ctx ep notifEp [] defaultServerSecurityParams) stopBrokerServer (\broS -> - bracket (startBrokerClient "foo" ctx ep defaultClientSecurityParams) stopBrokerClient (\broC -> do + bracket (startBrokerClient "foo" ctx ep notifEp [callback] defaultClientSecurityParams) stopBrokerClient (\broC -> do maybeOid <- submitOrder broC defaultOrder case maybeOid of Left err -> assertFailure "Invalid response"