/*
 Copyright (C) 2003 Niels Elken Sønderby

 This file is part of QuantLib for Mathematica, a Mathematica extension for
 QuantLib, a free-software/open-source financial C++ library
 http://www.nielses.dk/quantlib/mma
 http://quantlib.org/

 QuantLib for Mathematica is free software: you can redistribute it and/or
 modify it under the terms of the QuantLib license.  You should have received
 a copy of the license along with this program; if not, please email
 ferdinando@ametrano.net The license is also available online at
 http://quantlib.org/html/license.html

 This program is distributed in the hope that it will be useful, but WITHOUT
 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 FOR A PARTICULAR PURPOSE.  See the license for more details.
*/

#include "mmaoptions.hpp"

void qlEuropeanOption(int type, double underlying, double strike,
                        double dividendYield, double riskFreeRate,
                        double maturity, double volatility)
{
    try {
        EuropeanOption option((Option::Type)type, underlying, strike,
                              dividendYield, riskFreeRate, maturity, volatility);

        // Make all calculations first, so that if an error is thrown,
        // it can safely be send to Mathematica. (More elegant solution exists?)
        double NPV = option.value();
        double delta = option.delta();
        double gamma = option.gamma();
        double theta = option.theta();
        double vega = option.vega();
        double rho = option.rho();
        double dividendRho = option.dividendRho();

        MLPutFunction(stdlink, "List", 7);
        QLMLPutRule(stdlink, "Value", NPV);
        QLMLPutRule(stdlink, "Delta", delta);
        QLMLPutRule(stdlink, "Gamma", gamma);
        QLMLPutRule(stdlink, "Theta", theta);
        QLMLPutRule(stdlink, "Vega", vega);
        QLMLPutRule(stdlink, "Rho", rho);
        QLMLPutRule(stdlink, "DividendRho", dividendRho);
    } QLML_HANDLE_EXCEPTIONS
}

void qlEuropeanOptionMC(int type, double underlying, double strike,
                        double dividendYield, double riskFreeRate,
                        double maturity, double volatility, double samples,
                        int antithetic)
{
    try {
        // make sure seed is different between calls
        // otherwise it will fail if called twice in one second
        static long seed = 0;
        if (seed == 0)
            seed = QL_TIME(0);
        else
            seed++;

        McEuropean eur((Option::Type)type, underlying, strike, dividendYield,
           riskFreeRate, maturity, volatility, (bool)antithetic, seed);

        // samples should be long, but there's no MathLink type for it.
        double val = eur.valueWithSamples((long)samples);
        double err = eur.errorEstimate();

        MLPutFunction(stdlink, "List", 2);
        QLMLPutRule(stdlink, "Value", val);
        QLMLPutRule(stdlink, "StandardError", err);
    } QLML_HANDLE_EXCEPTIONS
}

void qlAmericanOptionFD(int type, double underlying, double strike,
                        double dividendYield, double riskFreeRate,
                        double maturity, double volatility, double timeSteps,
                        double gridPoints)
{
    try {
        FdAmericanOption fdam((Option::Type)type, underlying, strike,
                              dividendYield, riskFreeRate,
                              maturity, volatility, timeSteps, gridPoints);

        double val = fdam.value();

        MLPutFunction(stdlink, "List", 1);
        QLMLPutRule(stdlink, "Value", val);
    } QLML_HANDLE_EXCEPTIONS
}


void qlImpliedVolatility(double value, int type, double underlying,
                     double strike, double dividendYield, double riskFreeRate,
                     double maturity)
{
    try {
        EuropeanOption option((Option::Type)type, underlying, strike,
                               dividendYield, riskFreeRate, maturity, 0.1);

        double impliedVol = option.impliedVolatility(value);

        MLPutReal(stdlink, impliedVol);
    } QLML_HANDLE_EXCEPTIONS
}

#ifdef NESQUANT

void nqAmericanOptionSVJDLSM(int type, double underlying, double strike,
                         double dividendYield, double riskFreeRate,
                         double maturity, double volatility,
                         double volatilityOfVolatility,
                         double steadyStateVolatility, double meanReversionRate,
                         double correlationUnderlyingVolatility,
                         double jumpIntensity, double jumpMean,
                         double jumpStandardDeviation,
                         double timeSteps, double exerciseTimes,
                         double samples, int degree,
                         int antithetic, int controlVariate)
{
    try {
        LSMSVJDEngine engine(samples, timeSteps, exerciseTimes, degree,
                             antithetic, controlVariate);

        // downcast the Arguments struct to the appropriate type...
        SVJDArguments* underArgs =
            dynamic_cast<SVJDArguments*>(engine.arguments());
        QL_ENSURE(underArgs != 0, "dynamic_cast failed");
        // ... and set the values
        underArgs->type = (Option::Type)type;
        underArgs->underlying = underlying;
        underArgs->strike = strike;
        underArgs->dividendYield = dividendYield;
        underArgs->riskFreeRate = riskFreeRate;
        underArgs->maturity = maturity;
        underArgs->volatility = volatility;

        underArgs->volatilityOfVolatility = volatilityOfVolatility;
        underArgs->steadyStateVolatility = steadyStateVolatility;
        underArgs->meanReversionRate = meanReversionRate;
        underArgs->correlationUnderlyingVolatility
            = correlationUnderlyingVolatility;

        underArgs->jumpIntensity = jumpIntensity;
        underArgs->jumpMean = jumpMean;
        underArgs->jumpStandardDeviation = jumpStandardDeviation;

        engine.calculate();

        // now read additional results
        const OptionValue* results =
        dynamic_cast<const OptionValue*>(engine.results());
        QL_ENSURE(results != 0, "dynamic_cast failed");

        MLPutFunction(stdlink, "List", 2);
        QLMLPutRule(stdlink, "Value", results->value);
        QLMLPutRule(stdlink, "StandardError", results->errorEstimate);
    } QLML_HANDLE_EXCEPTIONS
}

void nqAmericanOptionLSM(int type, double underlying, double strike,
                         double dividendYield, double riskFreeRate,
                         double maturity, double volatility,
                         double timeSteps, double samples, int degree,
                         int antithetic, int controlVariate)
{
    try {
        RelinkableHandle<BlackVolTermStructure> volTS =
            RelinkableHandle<BlackVolTermStructure>(
                makeFlatVolatility(volatility));
        RelinkableHandle<TermStructure> riskFreeTS =
            RelinkableHandle<TermStructure>(
                Handle<TermStructure>(new ConstantTS(riskFreeRate)));
        RelinkableHandle<TermStructure> dividendTS =
            RelinkableHandle<TermStructure>(
                Handle<TermStructure>(new ConstantTS(dividendYield)));

        RelinkableHandle<MarketElement> underlyingH(
            Handle<MarketElement>(new SimpleMarketElement(underlying)));

        LSMVanillaEngine engine(samples, timeSteps, degree,
                                (bool)antithetic, (bool)controlVariate);

        // downcast the Arguments struct to the appropriate type...
        VanillaOptionArguments* underArgs =
            dynamic_cast<VanillaOptionArguments*>(engine.arguments());
        QL_ENSURE(underArgs != 0, "dynamic_cast failed");
        // ... and set the values
        underArgs->payoff = Handle<Payoff>(
                          new PlainVanillaPayoff((Option::Type)type, strike));
        underArgs->underlying = underlying;
        underArgs->riskFreeTS = riskFreeTS;
        underArgs->dividendTS = dividendTS;
        underArgs->volTS = volTS;
        underArgs->maturity = maturity;
        underArgs->exerciseType = Exercise::Type::American;

        engine.calculate();

        const OptionValue* results =
            dynamic_cast<const OptionValue*>(engine.results());
        QL_ENSURE(results != 0, "dynamic_cast failed");

        MLPutFunction(stdlink, "List", 2);
        QLMLPutRule(stdlink, "Value", results->value);
        QLMLPutRule(stdlink, "StandardError", results->errorEstimate);
    } QLML_HANDLE_EXCEPTIONS
}

void nqSVJDOptionMC(int type, double underlying, double strike,
                     double dividendYield, double riskFreeRate,
                     double maturity, double volatility,
                     double volatilityOfVolatility,
                     double steadyStateVolatility, double meanReversionRate,
                     double correlationUnderlyingVolatility,
                     double jumpIntensity, double jumpMean,
                     double jumpStandardDeviation,
                     double timeSteps, double samples,
                     int antithetic, int controlVariate)
{
    try {
        MCSVJDEngine<PseudoRandom> engine(timeSteps / maturity,
                                          (bool)antithetic, (bool)controlVariate,
                                          samples, Null<double>(), samples);

        // downcast the Arguments struct to the appropriate type...
        SVJDArguments* underArgs =
            dynamic_cast<SVJDArguments*>(engine.arguments());
        QL_ENSURE(underArgs != 0, "dynamic_cast failed");
        // ... and set the values
        underArgs->type = (Option::Type)type;
        underArgs->underlying = underlying;
        underArgs->strike = strike;
        underArgs->dividendYield = dividendYield;
        underArgs->riskFreeRate = riskFreeRate;
        underArgs->maturity = maturity;
        underArgs->volatility = volatility;

        underArgs->volatilityOfVolatility = volatilityOfVolatility;
        underArgs->steadyStateVolatility = steadyStateVolatility;
        underArgs->meanReversionRate = meanReversionRate;
        underArgs->correlationUnderlyingVolatility
            = correlationUnderlyingVolatility;

        underArgs->jumpIntensity = jumpIntensity;
        underArgs->jumpMean = jumpMean;
        underArgs->jumpStandardDeviation = jumpStandardDeviation;

        engine.calculate();

        // now read additional results
        const OptionValue* results =
        dynamic_cast<const OptionValue*>(engine.results());
        QL_ENSURE(results != 0, "dynamic_cast failed");

        MLPutFunction(stdlink, "List", 2);
        QLMLPutRule(stdlink, "Value", results->value);
        QLMLPutRule(stdlink, "StandardError", results->errorEstimate);
    } QLML_HANDLE_EXCEPTIONS
}

void nqSVJDOption(int type, double underlying, double strike,
                     double dividendYield, double riskFreeRate,
                     double maturity, double volatility,
                     double volatilityOfVolatility,
                     double steadyStateVolatility, double meanReversionRate,
                     double correlationUnderlyingVolatility,
                     double jumpIntensity, double jumpMean,
                     double jumpStandardDeviation)
{
    try {
        SVJDEngine engine;

        // downcast the Arguments struct to the appropriate type...
        SVJDArguments* underArgs =
            dynamic_cast<SVJDArguments*>(engine.arguments());
        QL_ENSURE(underArgs != 0, "dynamic_cast failed");
        // ... and set the values
        underArgs->type = (Option::Type)type;
        underArgs->underlying = underlying;
        underArgs->strike = strike;
        underArgs->dividendYield = dividendYield;
        underArgs->riskFreeRate = riskFreeRate;
        underArgs->maturity = maturity;
        underArgs->volatility = volatility;

        underArgs->volatilityOfVolatility = volatilityOfVolatility;
        underArgs->steadyStateVolatility = steadyStateVolatility;
        underArgs->meanReversionRate = meanReversionRate;
        underArgs->correlationUnderlyingVolatility
            = correlationUnderlyingVolatility;

        underArgs->jumpIntensity = jumpIntensity;
        underArgs->jumpMean = jumpMean;
        underArgs->jumpStandardDeviation = jumpStandardDeviation;

        engine.calculate();

        // now read additional results
        const OptionValue* results =
        dynamic_cast<const OptionValue*>(engine.results());
        QL_ENSURE(results != 0, "dynamic_cast failed");

        MLPutFunction(stdlink, "List", 1);
        QLMLPutRule(stdlink, "Value", results->value);
    } QLML_HANDLE_EXCEPTIONS
}

void nqSVJDPDF(double x, int type, double underlying, double strike,
                     double dividendYield, double riskFreeRate,
                     double maturity, double volatility,
                     double volatilityOfVolatility,
                     double steadyStateVolatility, double meanReversionRate,
                     double correlationUnderlyingVolatility,
                     double jumpIntensity, double jumpMean,
                     double jumpStandardDeviation)
{
    try {
        SVJDEngine engine;

        // downcast the Arguments struct to the appropriate type...
        SVJDArguments* underArgs =
            dynamic_cast<SVJDArguments*>(engine.arguments());
        QL_ENSURE(underArgs != 0, "dynamic_cast failed");
        // ... and set the values
        underArgs->type = (Option::Type)type;
        underArgs->underlying = underlying;
        underArgs->strike = strike;
        underArgs->dividendYield = dividendYield;
        underArgs->riskFreeRate = riskFreeRate;
        underArgs->maturity = maturity;
        underArgs->volatility = volatility;

        underArgs->volatilityOfVolatility = volatilityOfVolatility;
        underArgs->steadyStateVolatility = steadyStateVolatility;
        underArgs->meanReversionRate = meanReversionRate;
        underArgs->correlationUnderlyingVolatility
            = correlationUnderlyingVolatility;

        underArgs->jumpIntensity = jumpIntensity;
        underArgs->jumpMean = jumpMean;
        underArgs->jumpStandardDeviation = jumpStandardDeviation;

        MLPutReal(stdlink, engine.pdf(x));
    } QLML_HANDLE_EXCEPTIONS
}

#endif
