From 6192f4435c0c499ed703b270b338d0117ab4938c Mon Sep 17 00:00:00 2001 From: Denis Tereshkin Date: Mon, 19 Mar 2018 19:03:00 +0700 Subject: [PATCH] Strategy: enter/exit methods --- examples/multiple_assets.py | 2 +- examples/single_asset.py | 2 +- src/naiback/broker/broker.py | 13 ++ src/naiback/broker/position.py | 3 + src/naiback/data/bars.py | 15 +- src/naiback/data/feeds/genericcsvfeed.py | 19 +- src/naiback/strategy/strategy.py | 159 ++++++++++++- tests/test_bars.py | 2 +- tests/test_genericcsvfeed.py | 4 +- tests/test_strategy.py | 270 +++++++++++++++++++++++ 10 files changed, 473 insertions(+), 16 deletions(-) create mode 100644 tests/test_strategy.py diff --git a/examples/multiple_assets.py b/examples/multiple_assets.py index bee8d66..8058eed 100644 --- a/examples/multiple_assets.py +++ b/examples/multiple_assets.py @@ -3,7 +3,7 @@ from naiback.strategy import Strategy from naiback.data.feeds import FinamCSVFeed from naiback.indicators import SMA, RSI -class MyStrategy(BarStrategy): +class MyStrategy(Strategy): def __init__(self): super().__init__() diff --git a/examples/single_asset.py b/examples/single_asset.py index 01764c0..025c6bb 100644 --- a/examples/single_asset.py +++ b/examples/single_asset.py @@ -13,7 +13,7 @@ class MyStrategy(SingleAssetStrategy): sma = SMA(self.bars.close, 200) rsi = RSI(self.bars.close, 2) stop = 0 - for i, bar in self.bars[200:]: + for i in self.bars.index[200:]: if self.last_position_is_active(): if not self.exit_at_stop(i, self.last_position(), stop): if self.bars.close[i] < exit_sma[i]: diff --git a/src/naiback/broker/broker.py b/src/naiback/broker/broker.py index 9d70015..515f419 100644 --- a/src/naiback/broker/broker.py +++ b/src/naiback/broker/broker.py @@ -37,10 +37,23 @@ class Broker: self.cash_ += price * size self.cash_ -= volume * 0.01 * self.commission_percentage + return True def set_commission(self, percentage): self.commission_percentage = percentage + def last_position(self): + return self.positions[-1] + def all_positions(self): return self.positions + def last_position_is_active(self): + if len(self.positions) == 0: + return False + + if self.last_position().exit_price() is None: + return True + + return False + diff --git a/src/naiback/broker/position.py b/src/naiback/broker/position.py index 3742fe6..8d5f6d8 100644 --- a/src/naiback/broker/position.py +++ b/src/naiback/broker/position.py @@ -19,6 +19,9 @@ class Position: def size(self): return self.size_ + def is_long(self): + return self.size_ > 0 + def entry_commission(self): return self.entry_metadata['commission'] diff --git a/src/naiback/data/bars.py b/src/naiback/data/bars.py index bdee2aa..e53f34e 100644 --- a/src/naiback/data/bars.py +++ b/src/naiback/data/bars.py @@ -7,7 +7,8 @@ class Bars: Basic bar series structure """ - def __init__(self): + def __init__(self, ticker): + self.ticker = ticker self.index = [] self.open = [] self.high = [] @@ -28,11 +29,19 @@ class Bars: self.volume.append(volume) self.timestamp.append(timestamp) + def insert_bar(self, index, open_, high, low, close, volume, timestamp): + self.open.insert(index, open_) + self.high.insert(index, high) + self.low.insert(index, low) + self.close.insert(index, close) + self.volume.insert(index, volume) + self.timestamp.insert(index, timestamp) + @classmethod - def from_feed(feed): + def from_feed(cls, feed): if feed.type() != 'bars': raise NaibackException('Invalid feed type: "{}", should be "bars"'.format(feed.type())) - bars = Bars() + bars = Bars(feed.ticker()) for bar in feed.items(): bars.append_bar(bar.open, bar.high, bar.low, bar.close, bar.volume, bar.timestamp) return bars diff --git a/src/naiback/data/feeds/genericcsvfeed.py b/src/naiback/data/feeds/genericcsvfeed.py index 5532bb0..a77d9f2 100644 --- a/src/naiback/data/feeds/genericcsvfeed.py +++ b/src/naiback/data/feeds/genericcsvfeed.py @@ -8,16 +8,17 @@ class GenericCSVFeed(Feed): def __init__(self, fp): self.bars = [] + self.ticker_ = None reader = csv.reader(fp, delimiter=',') next(reader) - next(reader) for row in reader: try: - open_ = row[4] - high = row[5] - low = row[6] - close = row[7] - volume = row[8] + self.ticker_ = row[0] + open_ = float(row[4]) + high = float(row[5]) + low = float(row[6]) + close = float(row[7]) + volume = int(row[8]) date = row[2] time = row[3] dt = datetime.datetime.strptime(date + "_" + time, "%Y%m%d_%H%M%S") @@ -25,5 +26,11 @@ class GenericCSVFeed(Feed): except IndexError: pass + def type(self): + return 'bars' + def items(self): return self.bars + + def ticker(self): + return self.ticker_ diff --git a/src/naiback/strategy/strategy.py b/src/naiback/strategy/strategy.py index 50dff0f..69e073e 100644 --- a/src/naiback/strategy/strategy.py +++ b/src/naiback/strategy/strategy.py @@ -1,12 +1,17 @@ from abc import abstractmethod +from naiback.broker.position import Position +from naiback.broker.broker import Broker +from naiback.data.bars import Bars class Strategy: """ - Internal base class for strategies. User should use it's subclasses (i.e. SingleAssetStrategy) """ def __init__(self): self.feeds = [] + self.all_bars = [] + self.all_positions = [] + self.broker = Broker() def add_feed(self, feed): """ @@ -25,4 +30,156 @@ class Strategy: """ By default, just calls execute. """ + self._prepare_bars() self.execute() + + def _prepare_bars(self): + if len(self.feeds) == 0: + raise NaibackException('No feeds added to strategy') + + self.all_bars.clear() + for feed in self.feeds: + self.all_bars.append(Bars.from_feed(feed)) + + all_dates = list(sorted(self._combine_dates())) + + for bars in self.all_bars: + self._synchronize_bars(bars, all_dates) + + def _get_bars(self, ticker): + for bars in self.all_bars: + if bars.ticker == ticker: + return bars + + return None + + def last_position(self): + return self.broker.last_position() + + def all_positions(self): + return self.broker.all_positions() + + def last_position_is_active(self): + return self.broker.last_position_is_active() + + def _synchronize_bars(self, bars, all_dates): + bar_pos = 0 + for dt in all_dates: + if bars.timestamp[bar_pos] > dt: + open_ = bars.open[bar_pos] + high = bars.high[bar_pos] + low = bars.low[bar_pos] + close = bars.close[bar_pos] + volume = bars.volume[bar_pos] + + bars.insert_bar(bar_pos, open_, high, low, close, volume, dt) + + def _combine_dates(self): + dates = set() + for bars in self.all_bars: + dates.update(bars.timestamp) + + return dates + + def buy_at_open(self, bar, ticker): + bars = self._get_bars(ticker) + return self.broker.add_position(ticker, bars.open[bar], 1) + + def buy_at_limit(self, bar, price, ticker): + bars = self._get_bars(ticker) + if bars.low[bar] <= price: + if bars.open[bar] > price: + return self.broker.add_position(ticker, price, 1) + else: + return self.broker.add_position(ticker, bars.open[bar], 1) + else: + return None + + def buy_at_stop(self, bar, price, ticker): + bars = self._get_bars(ticker) + if bars.high[bar] >= price: + if bars.open[bar] < price: + return self.broker.add_position(ticker, price, 1) + else: + return self.broker.add_position(ticker, bars.open[bar], 1) + else: + return None + + def buy_at_close(self, bar, ticker): + bars = self._get_bars(ticker) + return self.broker.add_position(ticker, bars.close[bar], 1) + + def short_at_open(self, bar, ticker): + bars = self._get_bars(ticker) + return self.broker.add_position(ticker, bars.open[bar], -1) + + def short_at_limit(self, bar, price, ticker): + bars = self._get_bars(ticker) + if bars.high[bar] >= price: + if bars.open[bar] < price: + return self.broker.add_position(ticker, price, -1) + else: + return self.broker.add_position(ticker, bars.open[bar], -1) + else: + return None + + def short_at_stop(self, bar, price, ticker): + bars = self._get_bars(ticker) + if bars.low[bar] <= price: + if bars.open[bar] > price: + return self.broker.add_position(ticker, price, -1) + else: + return self.broker.add_position(ticker, bars.open[bar], -1) + else: + return None + + def short_at_close(self, bar, ticker): + bars = self._get_bars(ticker) + return self.broker.add_position(ticker, bars.close[bar], -1) + + def exit_at_open(self, bar, pos): + bars = self._get_bars(pos.ticker) + return self.broker.close_position(pos, bars.open[bar]) + + def exit_at_limit(self, bar, price, pos): + bars = self._get_bars(pos.ticker) + if pos.is_long(): + if bars.high[bar] >= price: + if bars.open[bar] < price: + return self.broker.close_position(pos, price) + else: + return self.broker.close_position(pos, bars.open[bar]) + else: + return False + else: + if bars.low[bar] <= price: + if bars.open[bar] > price: + return self.broker.close_position(pos, price) + else: + return self.broker.close_position(pos, bars.open[bar]) + else: + return False + + def exit_at_stop(self, bar, price, pos): + bars = self._get_bars(pos.ticker) + if pos.is_long(): + if bars.low[bar] <= price: + if bars.open[bar] > price: + return self.broker.close_position(pos, price) + else: + return self.broker.close_position(pos, bars.open[bar]) + else: + return False + else: + if bars.high[bar] >= price: + if bars.open[bar] < price: + return self.broker.close_position(pos, price) + else: + return self.broker.close_position(pos, bars.open[bar]) + else: + return False + + def exit_at_close(self, bar, pos): + bars = self._get_bars(pos.ticker) + return self.broker.close_position(pos, bars.close[bar]) + diff --git a/tests/test_bars.py b/tests/test_bars.py index 3a6dcb1..3f38985 100644 --- a/tests/test_bars.py +++ b/tests/test_bars.py @@ -5,7 +5,7 @@ import datetime from naiback.data.bars import Bars def test_bar_append(): - bars = Bars() + bars = Bars('FOO') bars.append_bar(10, 20, 5, 11, 100, datetime.datetime(2017, 1, 1)) assert bars.open[0] == 10 diff --git a/tests/test_genericcsvfeed.py b/tests/test_genericcsvfeed.py index 7de5cb5..aa1724e 100644 --- a/tests/test_genericcsvfeed.py +++ b/tests/test_genericcsvfeed.py @@ -1,5 +1,4 @@ - import pytest import datetime import io @@ -9,8 +8,7 @@ from naiback.data.feeds.genericcsvfeed import GenericCSVFeed @pytest.fixture def sample(): - return ''' -,,,