diff --git a/src/naiback/broker/broker.py b/src/naiback/broker/broker.py index 2a03759..9d70015 100644 --- a/src/naiback/broker/broker.py +++ b/src/naiback/broker/broker.py @@ -21,14 +21,23 @@ class Broker: def add_position(self, ticker, price, amount): volume = abs(price * amount) if amount > 0: - if volume * (1 + self.commission_percentage) > self.cash_: + if volume * (1 + 0.01 * self.commission_percentage) > self.cash_: return None pos = Position(ticker) pos.enter(price, amount) - self.cash_ -= (volume + volume * self.commission_percentage) + self.cash_ -= price * amount + self.cash_ -= volume * 0.01 * self.commission_percentage self.positions.append(pos) return pos + def close_position(self, pos, price): + volume = abs(price * pos.size()) + size = pos.size() + pos.exit(price) + + self.cash_ += price * size + self.cash_ -= volume * 0.01 * self.commission_percentage + def set_commission(self, percentage): self.commission_percentage = percentage diff --git a/tests/test_broker.py b/tests/test_broker.py index df08f9b..ce14234 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -41,3 +41,21 @@ def test_broker_all_position(): assert pos in broker.all_positions() +def test_broker_close_position(): + broker = Broker(initial_cash=100) + + pos = broker.add_position('FOO', price=10, amount=1) + broker.close_position(pos, price=12) + + assert broker.cash() == 100 + 2 + +def test_broker_close_position_with_commission(): + broker = Broker(initial_cash=100) + + broker.set_commission(percentage=1) # 1% + + pos = broker.add_position('FOO', price=10, amount=1) + broker.close_position(pos, price=12) + + assert broker.cash() == 100 + 2 - (10 + 12) * 0.01 +