import JSBI from 'jsbi';
import { SwapResult, toDecimal, ZERO, ceilingDivision, SwapExactOutputResult } from '../utils';
import Decimal from 'decimal.js';
import { Fraction } from '..';
import { calculateFeeAmount } from './fees';

/**
 * Swap-quote calculator for a two-token constant-product pool
 * (`tokenAmounts[0] * tokenAmounts[1] = invariant`), using JSBI integer arithmetic.
 *
 * The total fee is the trader fee plus the owner fee, each computed separately by
 * `calculateFeeAmount`. When `feesOnInput` is true (the default) fees are taken from
 * the input amount; otherwise they are taken from the output amount.
 */
export class TokenSwapConstantProduct {
  /**
   * @param traderFee - Trader fee fraction, applied via `calculateFeeAmount`.
   * @param ownerFee - Owner fee fraction, applied via `calculateFeeAmount`.
   * @param feesOnInput - Charge fees on the input amount (true) or on the output amount (false).
   */
  constructor(
    private readonly traderFee: Fraction,
    private readonly ownerFee: Fraction,
    private readonly feesOnInput: boolean = true,
  ) {}

  /**
   * Quotes the result of swapping an exact input amount.
   *
   * @param tokenAmounts - Pool reserves; entries 0 and 1 are the two tokens.
   * @param inputTradeAmount - Amount of the input token supplied, before fees.
   * @param outputIndex - Index of the output token; the other index (0 or 1) is the input.
   * @returns `expectedOutputAmount` (net of fees when fees are charged on output), `fees`
   *   (in the input token when `feesOnInput`, otherwise in the output token) and
   *   `priceImpact` relative to the no-slippage quote.
   */
  public exchange(tokenAmounts: JSBI[], inputTradeAmount: JSBI, outputIndex: number): SwapResult {
    const inputIndex = outputIndex === 0 ? 1 : 0;
    const newInputTradeAmount = this.feesOnInput ? this.getAmountLessFees(inputTradeAmount) : inputTradeAmount;

    const grossOutputAmount = this.getExpectedOutputAmount(tokenAmounts, newInputTradeAmount, inputIndex, outputIndex);

    const fees = this.getFees(this.feesOnInput ? inputTradeAmount : grossOutputAmount);

    // When fees come out of the output, `fees` was computed on exactly `grossOutputAmount`,
    // so this subtraction equals getAmountLessFees(grossOutputAmount) without recomputing it.
    const expectedOutputAmount = this.feesOnInput ? grossOutputAmount : JSBI.subtract(grossOutputAmount, fees);

    return {
      priceImpact: this.getPriceImpact(
        tokenAmounts,
        newInputTradeAmount,
        expectedOutputAmount,
        inputIndex,
        outputIndex,
      ),
      fees,
      expectedOutputAmount,
    };
  }

  /**
   * Quotes the input required to receive an exact output amount.
   *
   * @param tokenAmounts - Pool reserves; entries 0 and 1 are the two tokens.
   * @param outputTradeAmount - Amount of the output token the trader wants to receive.
   * @param outputIndex - Index of the output token; the other index (0 or 1) is the input.
   * @returns `expectedInputAmount` (including fees when fees are charged on input), `fees`
   *   and `priceImpact` relative to the no-slippage quote.
   * @throws Error when the (fee-adjusted) output is not strictly less than the output reserve.
   */
  public exchangeForExactOutput(
    tokenAmounts: JSBI[],
    outputTradeAmount: JSBI,
    outputIndex: number,
  ): SwapExactOutputResult {
    const inputIndex = outputIndex === 0 ? 1 : 0;
    const newOutputTradeAmount = this.feesOnInput ? outputTradeAmount : this.getAmountPlusFees(outputTradeAmount);

    const poolInputAmount = this.getInputAmount(tokenAmounts, newOutputTradeAmount, inputIndex, outputIndex);

    const fees = this.getFees(this.feesOnInput ? poolInputAmount : outputTradeAmount);

    // NOTE(review): fees are added on top of the pre-fee amount (`x + fee(x)`), whereas
    // `exchange` removes them from the gross amount (`g - fee(g)`). If fees grow with the
    // amount, feeding this quote back into `exchange` may yield slightly less than
    // `outputTradeAmount` — confirm against the on-chain program before relying on it.
    const expectedInputAmount = this.feesOnInput ? JSBI.add(poolInputAmount, fees) : poolInputAmount;

    return {
      priceImpact: this.getPriceImpactExactOutput(
        tokenAmounts,
        expectedInputAmount,
        newOutputTradeAmount,
        inputIndex,
        outputIndex,
      ),
      fees,
      expectedInputAmount,
    };
  }

  /**
   * Relative shortfall of the curve output versus the no-slippage output:
   * `(noSlippage - expected) / noSlippage`.
   *
   * Returns 0 for degenerate inputs (zero trade, an empty reserve, or a no-slippage
   * quote that truncates to zero), where the ratio is undefined.
   */
  private getPriceImpact(
    tokenAmounts: JSBI[],
    inputTradeAmountJSBI: JSBI,
    expectedOutputAmountJSBI: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): Decimal {
    if (
      JSBI.equal(inputTradeAmountJSBI, ZERO) ||
      JSBI.equal(tokenAmounts[inputIndex], ZERO) ||
      JSBI.equal(tokenAmounts[outputIndex], ZERO)
    ) {
      return new Decimal(0);
    }

    const noSlippageOutputAmountJSBI = this.getExpectedOutputAmountWithNoSlippage(
      tokenAmounts,
      inputTradeAmountJSBI,
      inputIndex,
      outputIndex,
    );
    // The no-slippage quote uses truncating integer division, so dust trades can round it
    // to zero; dividing by it would produce NaN/Infinity.
    if (JSBI.equal(noSlippageOutputAmountJSBI, ZERO)) {
      return new Decimal(0);
    }

    const noSlippageOutputAmount = toDecimal(noSlippageOutputAmountJSBI);
    const expectedOutputAmount = toDecimal(expectedOutputAmountJSBI);
    return noSlippageOutputAmount.sub(expectedOutputAmount).div(noSlippageOutputAmount);
  }

  /**
   * Relative excess of the required input versus the no-slippage input:
   * `(expected - noSlippage) / noSlippage`.
   *
   * Returns 0 for degenerate inputs (zero output, an empty reserve, or a no-slippage
   * quote that truncates to zero), where the ratio is undefined.
   */
  private getPriceImpactExactOutput(
    tokenAmounts: JSBI[],
    expectedInputTradeAmountJSBI: JSBI,
    outputAmountJSBI: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): Decimal {
    if (
      JSBI.equal(outputAmountJSBI, ZERO) ||
      JSBI.equal(tokenAmounts[inputIndex], ZERO) ||
      JSBI.equal(tokenAmounts[outputIndex], ZERO)
    ) {
      return new Decimal(0);
    }

    const noSlippageInputAmountJSBI = this.getExpectedInputAmountWithNoSlippage(
      tokenAmounts,
      outputAmountJSBI,
      inputIndex,
      outputIndex,
    );
    // Truncating integer division can round the no-slippage input to zero for tiny outputs;
    // dividing by it would produce Infinity.
    if (JSBI.equal(noSlippageInputAmountJSBI, ZERO)) {
      return new Decimal(0);
    }

    const noSlippageInputAmount = toDecimal(noSlippageInputAmountJSBI);
    const expectedInputAmount = toDecimal(expectedInputTradeAmountJSBI);
    return expectedInputAmount.sub(noSlippageInputAmount).div(noSlippageInputAmount);
  }

  /** Total fee on `tradeAmount`: trader fee plus owner fee, each computed independently. */
  private getFees(tradeAmount: JSBI): JSBI {
    const tradingFee = calculateFeeAmount(tradeAmount, this.traderFee);
    const ownerFee = calculateFeeAmount(tradeAmount, this.ownerFee);

    return JSBI.add(tradingFee, ownerFee);
  }

  /** Output produced by the constant-product curve for `inputTradeAmount` (no fee handling). */
  private getExpectedOutputAmount(
    tokenAmounts: JSBI[],
    inputTradeAmount: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): JSBI {
    return this.getOutputAmount(tokenAmounts, inputTradeAmount, inputIndex, outputIndex);
  }

  /**
   * Output at the current spot price, `floor(input * outputReserve / inputReserve)`, with
   * fees subtracted when fees are charged on output. Returns the whole output reserve if
   * the input reserve is empty.
   */
  private getExpectedOutputAmountWithNoSlippage(
    tokenAmounts: JSBI[],
    inputTradeAmount: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): JSBI {
    if (JSBI.equal(tokenAmounts[inputIndex], ZERO)) {
      return tokenAmounts[outputIndex];
    }

    const expectedOutputAmountWithNoSlippage = JSBI.divide(
      JSBI.multiply(inputTradeAmount, tokenAmounts[outputIndex]),
      tokenAmounts[inputIndex],
    );

    return this.feesOnInput
      ? expectedOutputAmountWithNoSlippage
      : this.getAmountLessFees(expectedOutputAmountWithNoSlippage);
  }

  /**
   * Input at the current spot price, `floor(output * inputReserve / outputReserve)`, with
   * fees added when fees are charged on input. Returns the whole input reserve if the
   * output reserve is empty.
   */
  private getExpectedInputAmountWithNoSlippage(
    tokenAmounts: JSBI[],
    outputTradeAmount: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): JSBI {
    if (JSBI.equal(tokenAmounts[outputIndex], ZERO)) {
      return tokenAmounts[inputIndex];
    }

    const expectedInputAmountWithNoSlippage = JSBI.divide(
      JSBI.multiply(outputTradeAmount, tokenAmounts[inputIndex]),
      tokenAmounts[outputIndex],
    );

    return this.feesOnInput
      ? this.getAmountPlusFees(expectedInputAmountWithNoSlippage)
      : expectedInputAmountWithNoSlippage;
  }

  /** `tradeAmount - getFees(tradeAmount)`. */
  private getAmountLessFees(tradeAmount: JSBI): JSBI {
    return JSBI.subtract(tradeAmount, this.getFees(tradeAmount));
  }

  /** `tradeAmount + getFees(tradeAmount)`. */
  private getAmountPlusFees(tradeAmount: JSBI): JSBI {
    return JSBI.add(tradeAmount, this.getFees(tradeAmount));
  }

  /**
   * Curve output: `outputReserve - ceilingDivision(invariant, inputReserve + input)`.
   * NOTE(review): assumes `ceilingDivision` returns the rounded-up quotient as its first
   * element, so the post-trade invariant does not decrease — TODO confirm.
   */
  private getOutputAmount(tokenAmounts: JSBI[], inputTradeAmount: JSBI, inputIndex: number, outputIndex: number): JSBI {
    const [poolInputAmount, poolOutputAmount] = [tokenAmounts[inputIndex], tokenAmounts[outputIndex]];

    const invariant = this.getInvariant(tokenAmounts);

    const [newPoolOutputAmount] = ceilingDivision(invariant, JSBI.add(poolInputAmount, inputTradeAmount));

    return JSBI.subtract(poolOutputAmount, newPoolOutputAmount);
  }

  /**
   * Curve input needed for `outputTradeAmount`:
   * `ceilingDivision(invariant, outputReserve - output) - inputReserve`.
   *
   * @throws Error when `outputTradeAmount` is not strictly less than the output reserve.
   */
  private getInputAmount(tokenAmounts: JSBI[], outputTradeAmount: JSBI, inputIndex: number, outputIndex: number): JSBI {
    const [poolInputAmount, poolOutputAmount] = [tokenAmounts[inputIndex], tokenAmounts[outputIndex]];

    // Draining the whole output reserve (or more) is impossible on a constant-product curve.
    if (JSBI.greaterThanOrEqual(outputTradeAmount, poolOutputAmount)) {
      throw new Error('Insufficient liquidity to provide outputTradeAmount');
    }

    const invariant = this.getInvariant(tokenAmounts);
    const [newPoolInputAmount] = ceilingDivision(invariant, JSBI.subtract(poolOutputAmount, outputTradeAmount));
    return JSBI.subtract(newPoolInputAmount, poolInputAmount);
  }

  /** Constant-product invariant: `tokenAmounts[0] * tokenAmounts[1]`. */
  getInvariant(tokenAmounts: JSBI[]): JSBI {
    return JSBI.multiply(tokenAmounts[0], tokenAmounts[1]);
  }
}
