Browse Source

Simplifed database access & added tests

master
Denis Tereshkin 8 years ago
parent
commit
889aaa01ce
  1. 21
      mds.cabal
  2. 154
      src/ATrade/MDS/Database.hs
  3. 34
      src/ATrade/MDS/HistoryServer.hs
  4. 23
      src/ATrade/MDS/Protocol.hs
  5. 6
      src/Lib.hs
  6. 5
      stack.yaml
  7. 77
      test/Integration/Database.hs
  8. 12
      test/Integration/Spec.hs
  9. 11
      test/Spec.hs

21
mds.cabal

@ -15,10 +15,12 @@ cabal-version: >=1.10
library library
hs-source-dirs: src hs-source-dirs: src
ghc-options: -Wall -Werror
exposed-modules: ATrade.MDS.Database exposed-modules: ATrade.MDS.Database
, ATrade.MDS.HistoryServer
build-depends: base >= 4.7 && < 5 build-depends: base >= 4.7 && < 5
, HDBC , HDBC
, HDBC-postgresql , HDBC-sqlite3
, configurator , configurator
, text , text
, vector , vector
@ -28,12 +30,16 @@ library
, monad-loops , monad-loops
, text-format , text-format
, zeromq4-haskell , zeromq4-haskell
, aeson
, safe
, bytestring
default-language: Haskell2010 default-language: Haskell2010
other-modules: ATrade.MDS.Protocol
executable mds-exe executable mds-exe
hs-source-dirs: app hs-source-dirs: app
main-is: Main.hs main-is: Main.hs
ghc-options: -threaded -rtsopts -with-rtsopts=-N ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wall -Werror
build-depends: base build-depends: base
, mds , mds
default-language: Haskell2010 default-language: Haskell2010
@ -44,8 +50,19 @@ test-suite mds-test
main-is: Spec.hs main-is: Spec.hs
build-depends: base build-depends: base
, mds , mds
, libatrade
, temporary
, datetime
, vector
, text
, time
, tasty
, tasty-hunit
ghc-options: -threaded -rtsopts -with-rtsopts=-N ghc-options: -threaded -rtsopts -with-rtsopts=-N
default-language: Haskell2010 default-language: Haskell2010
other-modules: Integration.Spec
, Integration.Database
extensions: OverloadedStrings
source-repository head source-repository head
type: git type: git

154
src/ATrade/MDS/Database.hs

@ -2,129 +2,93 @@
module ATrade.MDS.Database ( module ATrade.MDS.Database (
DatabaseConfig(..), DatabaseConfig(..),
DatabaseInterface(..), MdsHandle,
startDatabase, initDatabase,
stopDatabase closeDatabase,
getData,
putData,
TimeInterval(..),
Timeframe(..),
timeframeDaily,
timeframeHour,
timeframeMinute
) where ) where
import qualified Data.Configurator as C
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import Data.Text.Format
import qualified Data.Vector as V import qualified Data.Vector as V
import ATrade.Types import ATrade.Types
import Data.Time.Clock import Data.Time.Clock
import Data.Time.Clock.POSIX import Data.Time.Clock.POSIX
import Data.Maybe import Data.Maybe
import Control.Concurrent.MVar
import Control.Concurrent
import System.Log.Logger
import Database.HDBC import Database.HDBC
import Database.HDBC.PostgreSQL import Database.HDBC.Sqlite3
import Control.Monad import Control.Monad
import Control.Monad.Loops
data TimeInterval = TimeInterval UTCTime UTCTime data TimeInterval = TimeInterval UTCTime UTCTime
data Timeframe = Timeframe Int data Timeframe = Timeframe Int
timeframeDaily = Timeframe 86400 timeframeDaily :: Int -> Timeframe
timeframeHour = Timeframe 3600 timeframeDaily days = Timeframe (days * 86400)
timeframeMinute = Timeframe 60
timeframeHour :: Int -> Timeframe
timeframeHour hours = Timeframe (hours * 3600)
timeframeMinute :: Int -> Timeframe
timeframeMinute mins = Timeframe (mins * 60)
data DatabaseCommand = DBGet TickerId TimeInterval Timeframe | DBPut TickerId TimeInterval Timeframe (V.Vector Bar)
data DatabaseResponse = DBOk | DBData [(TimeInterval, V.Vector Bar)] | DBError T.Text
data DatabaseConfig = DatabaseConfig { data DatabaseConfig = DatabaseConfig {
dbHost :: T.Text, dbPath :: T.Text,
dbDatabase :: T.Text, dbDatabase :: T.Text,
dbUser :: T.Text, dbUser :: T.Text,
dbPassword :: T.Text dbPassword :: T.Text
} deriving (Show, Eq) } deriving (Show, Eq)
data DatabaseInterface = DatabaseInterface { type MdsHandle = Connection
tid :: ThreadId,
getData :: TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)],
putData :: TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
}
startDatabase :: DatabaseConfig -> IO DatabaseInterface initDatabase :: DatabaseConfig -> IO MdsHandle
startDatabase config = do initDatabase config = do
conn <- connectPostgreSQL (mkConnectionString config) conn <- connectSqlite3 (T.unpack $ dbPath config)
makeSchema conn makeSchema conn
cmdVar <- newEmptyMVar return conn
respVar <- newEmptyMVar where
compVar <- newEmptyMVar makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, timeframe INTEGER, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);"
tid <- forkFinally (dbThread conn cmdVar respVar) (cleanup conn cmdVar respVar compVar)
return DatabaseInterface { closeDatabase :: MdsHandle -> IO ()
tid = tid, closeDatabase = disconnect
getData = doGetData cmdVar respVar,
putData = doPutData cmdVar respVar } getData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)]
getData db tickerId interval@(TimeInterval start end) (Timeframe tfSec) = do
rows <- quickQuery' db "SELECT timestamp, timeframe, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC;" [(toSql. T.unpack) tickerId, toSql tfSec, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end]
return [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)]
where where
makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);" barFromResult ticker [ts, _, open, high, low, close, vol] = Just Bar {
mkConnectionString config = TL.unpack $ format "User ID={};Password={};Host={};Port=5432;Database={}" (dbUser config, dbPassword config, dbHost config, dbDatabase config) barSecurity = ticker,
dbThread conn cmdVar respVar = forever $ do barTimestamp = fromSql ts,
cmd <- readMVar cmdVar barOpen = fromDouble $ fromSql open,
handleCmd conn cmd >>= putMVar respVar barHigh = fromDouble $ fromSql high,
whileM_ (isJust <$> tryReadMVar respVar) yield barLow = fromDouble $ fromSql low,
takeMVar cmdVar barClose = fromDouble $ fromSql close,
cleanup conn cmdVar respVar compVar _ = disconnect conn >> putMVar compVar () barVolume = fromSql vol
handleCmd conn cmd = case cmd of }
DBPut tickerId (TimeInterval start end) tf@(Timeframe timeframeSecs) bars -> do
delStmt <- prepare conn "DELETE FROM bars WHERE timestamp > ? AND timestamp < ? AND ticker == ? AND timeframe == ?;"
execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) timeframeSecs]
stmt <- prepare conn $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++
" values (?, ?, ?, ?, ?, ?, ?, ?); "
executeMany stmt (map (barToSql tf) $ V.toList bars)
return DBOk
DBGet tickerId interval@(TimeInterval start end) (Timeframe timeframeSecs) -> do
rows <- quickQuery' conn "SELECT timestamp, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp > ? AND timestamp < ?;" [(toSql. T.unpack) tickerId, toSql timeframeSecs, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end]
return $ DBData [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)]
barFromResult ticker [ts, open, high, low, close, volume] = Just Bar {
barSecurity = ticker,
barTimestamp = fromSql ts,
barOpen = fromRational $ fromSql open,
barHigh = fromRational $ fromSql high,
barLow = fromRational $ fromSql low,
barClose = fromRational $ fromSql close,
barVolume = fromSql volume
}
barFromResult _ _ = Nothing barFromResult _ _ = Nothing
putData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
putData db tickerId (TimeInterval start end) tf@(Timeframe tfSec) bars = do
delStmt <- prepare db "DELETE FROM bars WHERE timestamp >= ? AND timestamp <= ? AND ticker == ? AND timeframe == ?;"
void $ execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) tfSec]
stmt <- prepare db $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++
" values (?, ?, ?, ?, ?, ?, ?, ?); "
executeMany stmt (map (barToSql tf) $ V.toList bars)
where
barToSql :: Timeframe -> Bar -> [SqlValue] barToSql :: Timeframe -> Bar -> [SqlValue]
barToSql (Timeframe timeframeSecs) bar = [(SqlString . T.unpack . barSecurity) bar, barToSql (Timeframe timeframeSecs) bar = [(SqlString . T.unpack . barSecurity) bar,
(SqlInteger . toInteger) timeframeSecs, (SqlInteger . toInteger) timeframeSecs,
(SqlRational . toRational . barOpen) bar, (SqlPOSIXTime . utcTimeToPOSIXSeconds . barTimestamp) bar,
(SqlRational . toRational . barHigh) bar, (SqlDouble . toDouble . barOpen) bar,
(SqlRational . toRational . barLow) bar, (SqlDouble . toDouble . barHigh) bar,
(SqlRational . toRational . barClose) bar, (SqlDouble . toDouble . barLow) bar,
(SqlRational . toRational . barVolume) bar ] (SqlDouble . toDouble . barClose) bar,
(SqlInteger . barVolume) bar ]
stopDatabase :: MVar () -> DatabaseInterface -> IO ()
stopDatabase compVar db = killThread (tid db) >> readMVar compVar
doGetData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)]
doGetData cmdVar respVar tickerId timeInterval timeframe = do
putMVar cmdVar (DBGet tickerId timeInterval timeframe)
resp <- takeMVar respVar
case resp of
DBData x -> return x
DBError err -> do
warningM "DB.Client" $ "Error while calling getData: " ++ show err
return []
_ -> do
warningM "DB.Client" "Unexpected response"
return []
doPutData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
doPutData cmdVar respVar tickerId timeInterval timeframe bars = do
putMVar cmdVar (DBPut tickerId timeInterval timeframe bars)
resp <- takeMVar respVar
case resp of
DBOk -> return ()
DBError err -> do
warningM "DB.Client" $ "Error while calling putData: " ++ show err
return ()
_ -> do
warningM "DB.Client" "Unexpected response"
return ()

34
src/ATrade/MDS/HistoryServer.hs

@ -1,14 +1,42 @@
module ATrade.MDS.HistoryServer ( module ATrade.MDS.HistoryServer (
HistoryServer,
startHistoryServer
) where ) where
import System.ZMQ4 import System.ZMQ4
import ATrade.MDS.Database import ATrade.MDS.Database
import ATrade.MDS.Protocol
import Control.Concurrent import Control.Concurrent
import Control.Monad
import Data.Aeson
import Data.List.NonEmpty
import qualified Data.Vector as V
import Safe
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
data HistoryServer = HistoryServer ThreadId data HistoryServer = HistoryServer ThreadId
}
startHistoryServer :: DatabaseInterface -> Context -> IO HistoryServer startHistoryServer :: MdsHandle -> Context -> IO HistoryServer
startHistoryServer db ctx = do startHistoryServer db ctx = do
sock <- socket ctx Router
tid <- forkIO $ serve db sock
return $ HistoryServer tid
serve :: (Sender a, Receiver a) => MdsHandle -> Socket a -> IO ()
serve db sock = forever $ do
rq <- receiveMulti sock
let maybeCmd = (BL.fromStrict <$> rq `atMay` 2) >>= decode
case (headMay rq, maybeCmd) of
(Just peerId, Just cmd) -> handleCmd peerId cmd
_ -> return ()
where
handleCmd :: B.ByteString -> MDSRequest -> IO ()
handleCmd peerId cmd = case cmd of
rq -> do
qdata <- getData db (rqTicker rq) (TimeInterval (rqFrom rq) (rqTo rq)) (Timeframe (rqTimeframe rq))
bytes <- serializeBars $ V.concat $ fmap snd qdata
sendMulti sock $ peerId :| B.empty : bytes
serializeBars = undefined

23
src/ATrade/MDS/Protocol.hs

@ -0,0 +1,23 @@
{-# LANGUAGE DeriveGeneric #-}
module ATrade.MDS.Protocol (
MDSRequest(..)
) where
import GHC.Generics
import ATrade.Types
import Data.Aeson
import Data.Time.Clock
data MDSRequest = RequestData {
rqTicker :: TickerId,
rqFrom :: UTCTime,
rqTo :: UTCTime,
rqTimeframe :: Int
} deriving (Generic, Show, Eq)
instance ToJSON MDSRequest
instance FromJSON MDSRequest

6
src/Lib.hs

@ -1,6 +0,0 @@
module Lib
( someFunc
) where
someFunc :: IO ()
someFunc = putStrLn "someFunc"

5
stack.yaml

@ -15,7 +15,7 @@
# resolver: # resolver:
# name: custom-snapshot # name: custom-snapshot
# location: "./custom-snapshot.yaml" # location: "./custom-snapshot.yaml"
resolver: lts-7.4 resolver: lts-11.9
# User packages to be built. # User packages to be built.
# Various formats can be used as shown in the example below. # Various formats can be used as shown in the example below.
@ -38,9 +38,10 @@ resolver: lts-7.4
packages: packages:
- '.' - '.'
- '../libatrade' - '../libatrade'
- '../zeromq4-haskell-zap'
# Dependency packages to be pulled from upstream that are not in the resolver # Dependency packages to be pulled from upstream that are not in the resolver
# (e.g., acme-missiles-0.3) # (e.g., acme-missiles-0.3)
extra-deps: ["HDBC-postgresql-2.3.2.4", "datetime-0.3.1"] extra-deps: ["HDBC-sqlite3-2.3.3.1", "datetime-0.3.1"]
# Override default flag values for local packages and extra-deps # Override default flag values for local packages and extra-deps
flags: {} flags: {}

77
test/Integration/Database.hs

@ -0,0 +1,77 @@
module Integration.Database (
testDatabase
) where
import Test.Tasty
import Test.Tasty.HUnit
import ATrade.MDS.Database
import ATrade.Types
import Control.Exception
import Data.DateTime
import Data.Time.Clock
import qualified Data.Text as T
import qualified Data.Vector as V
import System.IO.Temp
testDatabase :: TestTree
testDatabase = testGroup "Database tests" [ testOpenClose
, testPutGet
, testGetReturnsSorted ]
testOpenClose :: TestTree
testOpenClose = testCase "Open/Close" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
db <- initDatabase dbConfig
closeDatabase db
bar :: UTCTime -> Price -> Price -> Price -> Price -> Integer -> Bar
bar dt o h l c v = Bar { barSecurity = "FOO",
barTimestamp = dt,
barOpen = o,
barHigh = h,
barLow = l,
barClose = c,
barVolume = v }
testPutGet :: TestTree
testPutGet = testCase "Put/Get" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
bracket (initDatabase dbConfig) closeDatabase $ \db -> do
putData db "FOO" interval (timeframeMinute 1) bars
retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1)
assertEqual "Retreived bars are different from saved" bars retrievedBars
where
interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0)
bars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1
]
testGetReturnsSorted :: TestTree
testGetReturnsSorted = testCase "Get returns sorted vector" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
bracket (initDatabase dbConfig) closeDatabase $ \db -> do
putData db "FOO" interval (timeframeMinute 1) bars
retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1)
assertEqual "Retreived bars are not sorted" sortedBars retrievedBars
where
interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0)
bars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1
]
sortedBars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1
]

12
test/Integration/Spec.hs

@ -0,0 +1,12 @@
module Integration.Spec (
integrationTests
) where
import Test.Tasty
import Test.Tasty.HUnit
import Integration.Database
integrationTests = testGroup "Integration tests" [ testDatabase ]

11
test/Spec.hs

@ -1,2 +1,11 @@
import Test.Tasty
import Test.Tasty.HUnit
import Integration.Spec
main :: IO () main :: IO ()
main = putStrLn "Test suite not yet implemented" main = defaultMain tests
tests :: TestTree
tests = testGroup "Tests" [ integrationTests ]

Loading…
Cancel
Save