From a649929475f37501dd028cf1fe7c53d4b468fd32 Mon Sep 17 00:00:00 2001 From: Denis Tereshkin Date: Thu, 20 Dec 2018 21:03:15 +0700 Subject: [PATCH] EquityAnalyzer: optimized algorithm --- src/naiback/analyzers/equityanalyzer.py | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/naiback/analyzers/equityanalyzer.py b/src/naiback/analyzers/equityanalyzer.py index 3cb447b..b0798a1 100644 --- a/src/naiback/analyzers/equityanalyzer.py +++ b/src/naiback/analyzers/equityanalyzer.py @@ -7,6 +7,7 @@ class EquityAnalyzer(Analyzer): def __init__(self, strategy): self.strategy = strategy + self.bar_to_pos = [] def get_result(self): positions = self.strategy.broker.retired_positions() # TODO also add open positions @@ -15,7 +16,6 @@ class EquityAnalyzer(Analyzer): return equity def calc_equity(self, positions, bars): - timestamp = bars.timestamp close = bars.close cumulative_pnl = 0 @@ -23,16 +23,16 @@ class EquityAnalyzer(Analyzer): equity = [] prev_p = 0 - print(len(close), len(timestamp)) - for (p,ts) in zip(close, timestamp): - active_positions = self.positions_for_timestamp(positions, ts) + self.calculate_lookup_table(positions, len(close)) + for (bar_num, p) in enumerate(close): + active_positions = self.positions_for_bar(positions, bar_num) for pos in active_positions: - if pos.entry_time() == ts: + if pos.entry_bar() == bar_num: if pos.is_long(): cumulative_pnl += (p - pos.entry_price()) else: cumulative_pnl -= (p - pos.entry_price()) - elif pos.exit_time() == ts: + elif pos.exit_bar() == bar_num: if pos.is_long(): cumulative_pnl += (pos.exit_price() - prev_p) else: @@ -47,12 +47,17 @@ class EquityAnalyzer(Analyzer): prev_p = p return equity - def positions_for_timestamp(self, positions, timestamp): + def positions_for_bar(self, positions, bar_num): result = [] - for p in positions: - if p.entry_time() <= timestamp and p.exit_time() >= timestamp: - result.append(p) + for pos_index in self.bar_to_pos[bar_num]: + pos = positions[pos_index] + result.append(pos) return result - - + def calculate_lookup_table(self, positions, length): + self.bar_to_pos = [] + for i in range(0, length): + self.bar_to_pos.append([]) + for pos_index, pos in enumerate(positions): + for i in range(pos.entry_bar(), pos.exit_bar() + 1): + self.bar_to_pos[i].append(pos_index)