"""Randomly generated boolean trading signals evaluated over a price series.

Each ``Signal`` subclass turns a series into one boolean per bar via
``calculate(series)`` and describes itself via ``get_text()``.  Each
``*SignalGenerator`` builds a randomly parameterised signal via ``generate()``.

The series argument is duck-typed: the code below only calls
``series.length()`` and ``series.get_open/get_high/get_low/get_close(i)``.
Index 0 is assumed to be the oldest bar, since bars are fed to talib in index
order and shifts look back by subtracting from the index. TODO confirm this
against the series implementation.
"""
import random

import numpy
import talib


def _price_array(series, getter_name):
    """Return ``series.<getter_name>(i)`` for every bar as a float64 array.

    talib's functions only accept double-precision input, so integer (or
    other numeric) prices are converted here rather than failing in talib.
    """
    getter = getattr(series, getter_name)
    return numpy.array([getter(i) for i in range(series.length())],
                       dtype=numpy.float64)


class Signal:
    """Interface for a boolean signal computed bar-by-bar over a series."""

    def __init__(self):
        pass

    def calculate(self, series):
        """Return a list with one boolean per bar of *series*."""
        raise NotImplementedError

    def get_text(self):
        """Return a human-readable description of the signal."""
        raise NotImplementedError


class PriceComparisonSignalGenerator:
    """Builds random ``PriceComparisonSignal`` instances."""

    def __init__(self):
        pass

    def generate(self):
        """Return a signal comparing two randomly chosen shifted prices.

        Each side picks one of open/high/low/close and a look-back shift in
        [0, 10].  NOTE: both sides can come out identical, which gives a
        signal that is always False (x < x).
        """
        # Draw order is kept fixed so a seeded ``random`` reproduces signals.
        lhs = random.randint(PriceComparisonSignal.OPEN, PriceComparisonSignal.CLOSE)
        lhs_shift = random.randint(0, 10)
        rhs = random.randint(PriceComparisonSignal.OPEN, PriceComparisonSignal.CLOSE)
        rhs_shift = random.randint(0, 10)
        return PriceComparisonSignal(lhs, lhs_shift, rhs, rhs_shift)


class PriceComparisonSignal(Signal):
    """True on bars where ``lhs[lhs_shift] < rhs[rhs_shift]``.

    ``lhs``/``rhs`` are one of OPEN/HIGH/LOW/CLOSE; a shift of ``k`` reads the
    bar ``k`` positions before the current index.
    """

    OPEN = 0
    HIGH = 1
    LOW = 2
    CLOSE = 3

    # Component id -> name of the series accessor that reads it.
    _GETTERS = {OPEN: 'get_open', HIGH: 'get_high', LOW: 'get_low', CLOSE: 'get_close'}
    # Component id -> label used by get_text().
    _NAMES = {OPEN: 'open', HIGH: 'high', LOW: 'low', CLOSE: 'close'}

    def __init__(self, lhs, lhs_shift, rhs, rhs_shift):
        self.lhs = lhs
        self.lhs_shift = lhs_shift
        self.rhs = rhs
        self.rhs_shift = rhs_shift

    def calculate(self, series):
        """Return ``calculate_at_index`` for every bar of *series*."""
        return [self.calculate_at_index(series, i) for i in range(series.length())]

    def calculate_at_index(self, series, index):
        """Evaluate the comparison at bar *index*.

        Returns False when either shifted bar falls outside the series
        (not enough history).

        Raises:
            Exception: if ``lhs`` or ``rhs`` is not OPEN/HIGH/LOW/CLOSE.
        """
        try:
            lhs = self._price_at(series, self.lhs, index - self.lhs_shift, 'lhs')
            rhs = self._price_at(series, self.rhs, index - self.rhs_shift, 'rhs')
        except IndexError:
            return False
        return lhs < rhs

    @staticmethod
    def _price_at(series, component, shifted_index, side):
        """Read *component* of the bar at *shifted_index*.

        A negative index means the shift reaches before the first bar.  It is
        rejected with IndexError explicitly because a series using Python-style
        indexing would otherwise wrap around to the newest bars (look-ahead).
        """
        getter_name = PriceComparisonSignal._GETTERS.get(component)
        if getter_name is None:
            raise Exception('Invalid ' + side + ' type')
        if shifted_index < 0:
            raise IndexError('shifted index is before the first bar')
        return getattr(series, getter_name)(shifted_index)

    def get_text(self):
        """Describe the signal, e.g. ``open[2] < close[0]``."""
        return '{}[{}] < {}[{}]'.format(
            self.component_to_str(self.lhs), self.lhs_shift,
            self.component_to_str(self.rhs), self.rhs_shift)

    def component_to_str(self, component):
        """Return the label for *component*, or ``"??"`` if it is unknown."""
        return self._NAMES.get(component, '??')


class RsiSignalGenerator:
    """Builds random ``RsiSignal`` instances."""

    def __init__(self):
        pass

    def generate(self):
        """Return an RSI signal with period in [2, 30] and threshold in 10..80."""
        period = random.randint(2, 30)
        threshold = random.randrange(1, 9) * 10
        ineq_type = random.randint(RsiSignal.LT, RsiSignal.GT)
        return RsiSignal(period, threshold, ineq_type)


class RsiSignal(Signal):
    """True on bars where RSI of the closes is below (LT) or above (GT) a threshold."""

    LT = 0
    GT = 1

    def __init__(self, period, threshold, inequality_type):
        self.period = period
        self.threshold = threshold
        self.inequality_type = inequality_type
        # Any value other than LT behaves as GT, matching calc_signal().
        self.inequality_sign_str = '<' if inequality_type == RsiSignal.LT else '>'

    def calculate(self, series):
        """Return one boolean per bar comparing RSI(close, period) to the threshold."""
        closes = _price_array(series, 'get_close')
        if closes.size == 0:
            # Nothing to evaluate; don't hand talib an empty array.
            return []
        rsi = talib.RSI(closes, timeperiod=self.period)
        return [self.calc_signal(v) for v in rsi]

    def calc_signal(self, value):
        """Compare one RSI value; NaN (e.g. warm-up bars) is False either way."""
        if self.inequality_type == RsiSignal.LT:
            return value < self.threshold
        return value > self.threshold

    def get_text(self):
        """Describe the signal, e.g. ``rsi(c, 14) < 30``."""
        return "rsi(c, " + str(self.period) + ') ' + self.inequality_sign_str + ' ' + str(self.threshold)


class AtrSignalGenerator:
    """Builds random ``AtrSignal`` instances."""

    def __init__(self):
        pass

    def generate(self):
        """Return an ATR signal with period in [2, 30] and factor in 0.001..0.030."""
        period = random.randint(2, 30)
        threshold = random.randint(1, 30) * 0.001
        ineq_type = random.randint(AtrSignal.LT, AtrSignal.GT)
        return AtrSignal(period, threshold, ineq_type)


class AtrSignal(Signal):
    """True on bars where ATR is below (LT) or above (GT) ``threshold_factor * close``."""

    LT = 0
    GT = 1

    def __init__(self, period, threshold_factor, inequality_type):
        self.period = period
        self.threshold_factor = threshold_factor
        self.inequality_type = inequality_type
        # Any value other than LT behaves as GT, matching calc_signal().
        self.inequality_sign_str = '<' if inequality_type == AtrSignal.LT else '>'

    def calculate(self, series):
        """Return one boolean per bar comparing ATR(period) to a fraction of that bar's close."""
        closes = _price_array(series, 'get_close')
        if closes.size == 0:
            # Nothing to evaluate; don't hand talib empty arrays.
            return []
        highs = _price_array(series, 'get_high')
        lows = _price_array(series, 'get_low')
        atr = talib.ATR(highs, lows, closes, timeperiod=self.period)
        return [self.calc_signal(v, c) for v, c in zip(atr, closes)]

    def calc_signal(self, value, close):
        """Compare one ATR value to ``threshold_factor * close``; NaN is False either way."""
        bound = self.threshold_factor * close
        if self.inequality_type == AtrSignal.LT:
            return value < bound
        return value > bound

    def get_text(self):
        """Describe the signal, e.g. ``atr(14) > close[0] * 0.010``."""
        return "atr(" + str(self.period) + ') ' + self.inequality_sign_str + ' close[0] * ' + "{:.3f}".format(self.threshold_factor)