You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
5.7 KiB
186 lines
5.7 KiB
''' |
|
''' |
|
import random |
|
|
|
import talib |
|
import numpy |
|
|
|
class Signal:
    """Base interface shared by the signal classes in this module.

    Subclasses override ``calculate`` and ``get_text``; the versions
    defined here do nothing and return ``None``.
    """

    def __init__(self):
        """Create a signal; the base class stores no state."""

    def calculate(self, series):
        """Evaluate the signal over ``series``.

        The base implementation ignores ``series`` and returns ``None``.
        """
        return None

    def get_text(self):
        """Describe the signal as text.

        The base implementation returns ``None``.
        """
        return None
|
|
|
class PriceComparisonSignalGenerator:
    """Factory for randomly parameterised ``PriceComparisonSignal`` objects."""

    def __init__(self):
        """Create the generator; it holds no configuration."""

    def generate(self):
        """Return a new ``PriceComparisonSignal`` with random parameters.

        Each side gets a component drawn uniformly from
        ``PriceComparisonSignal.OPEN`` .. ``PriceComparisonSignal.CLOSE``
        (inclusive) and a shift drawn uniformly from 0 .. 10 (inclusive).
        """
        first_component = PriceComparisonSignal.OPEN
        last_component = PriceComparisonSignal.CLOSE

        # The draw order (lhs, lhs shift, rhs, rhs shift) is significant:
        # it determines which values a seeded ``random`` hands to each field.
        left = random.randint(first_component, last_component)
        left_shift = random.randint(0, 10)
        right = random.randint(first_component, last_component)
        right_shift = random.randint(0, 10)

        return PriceComparisonSignal(left, left_shift, right, right_shift)
|
|
|
class PriceComparisonSignal(Signal):
    """Signal that is true when one price component is below another.

    For the position ``index`` the signal evaluates::

        lhs_component[index - lhs_shift] < rhs_component[index - rhs_shift]

    where each component is one of ``OPEN``, ``HIGH``, ``LOW`` or ``CLOSE``
    and is read from the series through ``get_open`` / ``get_high`` /
    ``get_low`` / ``get_close``.
    """

    # Component identifiers accepted for ``lhs`` and ``rhs``.
    OPEN = 0
    HIGH = 1
    LOW = 2
    CLOSE = 3

    # Component identifier -> (series accessor method name, display name).
    _COMPONENTS = {
        OPEN: ('get_open', 'open'),
        HIGH: ('get_high', 'high'),
        LOW: ('get_low', 'low'),
        CLOSE: ('get_close', 'close'),
    }

    def __init__(self, lhs, lhs_shift, rhs, rhs_shift):
        """Store the comparison parameters.

        Args:
            lhs: Component identifier for the left-hand side.
            lhs_shift: Number of positions subtracted from the index when
                reading the left-hand side.
            rhs: Component identifier for the right-hand side.
            rhs_shift: Number of positions subtracted from the index when
                reading the right-hand side.
        """
        self.lhs = lhs
        self.lhs_shift = lhs_shift
        self.rhs = rhs
        self.rhs_shift = rhs_shift

    def calculate(self, series):
        """Evaluate the signal at every position of ``series``.

        Args:
            series: Object exposing ``length()`` and the ``get_open`` /
                ``get_high`` / ``get_low`` / ``get_close`` accessors.

        Returns:
            A list with one result per position, in order.
        """
        return [self.calculate_at_index(series, i)
                for i in range(series.length())]

    def _accessor(self, series, component, side):
        """Return the bound ``series`` getter for ``component``.

        ``side`` ('lhs' or 'rhs') is only used to build the error message.
        """
        entry = self._COMPONENTS.get(component)
        if entry is None:
            raise ValueError('Invalid ' + side + ' type')
        return getattr(series, entry[0])

    def calculate_at_index(self, series, index):
        """Evaluate the signal at a single position.

        Args:
            series: Object exposing the price accessors (see ``calculate``).
            index: Position being evaluated.

        Returns:
            The result of ``lhs < rhs`` for the shifted positions, or
            ``False`` when either shifted position is not available.

        Raises:
            ValueError: If ``self.lhs`` or ``self.rhs`` is not a known
                component identifier.
        """
        get_lhs = self._accessor(series, self.lhs, 'lhs')
        get_rhs = self._accessor(series, self.rhs, 'rhs')

        lhs_index = index - self.lhs_shift
        rhs_index = index - self.rhs_shift

        # A negative position means "before the start of the series".  It
        # must not be forwarded to the accessors: a sequence-backed series
        # would silently wrap around and return values from the END of the
        # data instead of raising IndexError.
        if lhs_index < 0 or rhs_index < 0:
            return False

        try:
            return get_lhs(lhs_index) < get_rhs(rhs_index)
        except IndexError:
            # Position past the available data: treat as "no signal".
            return False

    def get_text(self):
        """Return the comparison as text, e.g. ``open[1] < close[0]``."""
        return '%s[%s] < %s[%s]' % (
            self.component_to_str(self.lhs), self.lhs_shift,
            self.component_to_str(self.rhs), self.rhs_shift,
        )

    def component_to_str(self, component):
        """Return the display name of ``component``, or ``"??"`` if unknown."""
        entry = self._COMPONENTS.get(component)
        return entry[1] if entry is not None else "??"
|
|
|
|
|
class RsiSignalGenerator:
    """Factory for randomly parameterised ``RsiSignal`` objects."""

    def __init__(self):
        """Create the generator; it holds no configuration."""

    def generate(self):
        """Return a new ``RsiSignal`` with random parameters.

        The period is drawn from 2 .. 30 (inclusive), the threshold is a
        multiple of 10 from 10 to 80 (inclusive), and the inequality type is
        drawn from ``RsiSignal.LT`` .. ``RsiSignal.GT``.
        """
        # Draw order matters for reproducibility under a seeded ``random``.
        rsi_period = random.randint(2, 30)
        # randrange excludes its upper bound, so the multiplier is 1 .. 8.
        rsi_threshold = random.randrange(1, 9) * 10
        comparison = random.randint(RsiSignal.LT, RsiSignal.GT)

        return RsiSignal(rsi_period, rsi_threshold, comparison)
|
|
|
class RsiSignal(Signal):
    """Signal comparing the RSI of the close prices against a threshold."""

    # Inequality identifiers accepted for ``inequality_type``.
    LT = 0
    GT = 1

    def __init__(self, period, threshold, inequality_type):
        """Store the RSI parameters.

        Args:
            period: Period passed to ``talib.RSI``.
            threshold: Value the RSI is compared against.
            inequality_type: ``RsiSignal.LT`` for "RSI < threshold"; any
                other value is treated as "RSI > threshold".
        """
        self.period = period
        self.threshold = threshold
        self.inequality_type = inequality_type
        self.inequality_sign_str = '<' if inequality_type == RsiSignal.LT else '>'

    def calculate(self, series):
        """Evaluate the signal at every position of ``series``.

        Args:
            series: Object exposing ``length()`` and ``get_close(i)``.

        Returns:
            A list with one comparison result per RSI value.
        """
        bar_count = series.length()
        close_prices = numpy.array(
            [series.get_close(pos) for pos in range(bar_count)]
        )

        # NOTE(review): talib presumably yields NaN while the indicator warms
        # up; comparisons against NaN are False either way - confirm.
        rsi_values = talib.RSI(close_prices, self.period)

        return list(map(self.calc_signal, rsi_values))

    def calc_signal(self, value):
        """Compare one RSI ``value`` against the threshold."""
        if self.inequality_type != RsiSignal.LT:
            return value > self.threshold
        return value < self.threshold

    def get_text(self):
        """Return the rule as text, e.g. ``rsi(c, 14) < 30``."""
        return 'rsi(c, %s) %s %s' % (
            self.period, self.inequality_sign_str, self.threshold,
        )
|
|
|
|
|
|
|
class AtrSignalGenerator:
    """Factory for randomly parameterised ``AtrSignal`` objects."""

    def __init__(self):
        """Create the generator; it holds no configuration."""

    def generate(self):
        """Return a new ``AtrSignal`` with random parameters.

        The period is drawn from 2 .. 30 (inclusive), the threshold factor is
        an integer from 1 .. 30 (inclusive) multiplied by 0.001, and the
        inequality type is drawn from ``AtrSignal.LT`` .. ``AtrSignal.GT``.
        """
        # Draw order matters for reproducibility under a seeded ``random``.
        atr_period = random.randint(2, 30)
        factor = random.randint(1, 30) * 0.001
        comparison = random.randint(AtrSignal.LT, AtrSignal.GT)

        return AtrSignal(atr_period, factor, comparison)
|
|
|
class AtrSignal(Signal):
    """Signal comparing the ATR against a fraction of the close price.

    For each position the signal evaluates ``atr < factor * close`` or
    ``atr > factor * close`` depending on the inequality type.
    """

    # Inequality identifiers accepted for ``inequality_type``.
    LT = 0
    GT = 1

    def __init__(self, period, threshold_factor, inequality_type):
        """Store the ATR parameters.

        Args:
            period: Period passed to ``talib.ATR``.
            threshold_factor: Multiplier applied to the close price to form
                the threshold the ATR is compared against.
            inequality_type: ``AtrSignal.LT`` for "ATR < threshold"; any
                other value is treated as "ATR > threshold".
        """
        self.period = period
        self.threshold_factor = threshold_factor
        self.inequality_type = inequality_type
        # Compare against this class's own constant (the original used
        # RsiSignal.LT, which only worked because both happen to be 0), so
        # the displayed sign always agrees with calc_signal.
        if inequality_type == AtrSignal.LT:
            self.inequality_sign_str = '<'
        else:
            self.inequality_sign_str = '>'

    def calculate(self, series):
        """Evaluate the signal at every position of ``series``.

        Args:
            series: Object exposing ``length()``, ``get_close(i)``,
                ``get_high(i)`` and ``get_low(i)``.

        Returns:
            A list with one comparison result per position.
        """
        positions = range(series.length())
        closes = numpy.array([series.get_close(i) for i in positions])
        highs = numpy.array([series.get_high(i) for i in positions])
        lows = numpy.array([series.get_low(i) for i in positions])

        # NOTE(review): talib presumably yields NaN while the indicator warms
        # up; comparisons against NaN are False either way - confirm.
        atr = talib.ATR(highs, lows, closes, self.period)

        return [self.calc_signal(v, c) for (v, c) in zip(atr, closes)]

    def calc_signal(self, value, close):
        """Compare one ATR ``value`` against ``threshold_factor * close``."""
        threshold = self.threshold_factor * close
        if self.inequality_type == AtrSignal.LT:
            return value < threshold
        return value > threshold

    def get_text(self):
        """Return the rule as text, e.g. ``atr(14) < close[0] * 0.010``."""
        return ("atr(" + str(self.period) + ') ' + self.inequality_sign_str
                + ' close[0] * ' + "{:.3f}".format(self.threshold_factor))
|
|
|
|
|
|
|
|