#include "util/BlockingChannel.h"

#include "protocol/AggreGateCommand.h"

#include "communication/SocketTimeoutException.h"
#include "communication/SocketException.h"
#include "communication/SocketDisconnectionException.h"

#include "util/Log.h"

#include "IOException.h"

#include <boost/asio/connect.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <boost/lambda/lambda.hpp>
#include <boost/lambda/bind.hpp>
#include <boost/thread/thread.hpp>

#ifdef __GNUC__
#include <unistd.h>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#else
#include <winsock2.h>
#include <Windows.h>
#include <MSTcpIP.h>
#endif

/**
 * Creates a TLS channel (mSsl is always true here) whose stream is allocated
 * but not yet connected.
 *
 * @param timeout stored in mTimeout (not read by any code in this file).
 */
BlockingChannel::BlockingChannel(int64_t timeout)
    : mSslSocket(NULL),
      mCtx(boost::asio::ssl::context::sslv23),
      mIsConnected(false), mIsClosing(true), mReadySocket(false), mSsl(true)
{
    mTimeout = timeout;

    // Trust the platform's default CA locations and refuse the obsolete
    // SSLv2/SSLv3 protocols.
    mCtx.set_default_verify_paths();
    const boost::asio::ssl::context::options disabledProtocols =
        boost::asio::ssl::context::no_sslv2 | boost::asio::ssl::context::no_sslv3;
    mCtx.set_options(disabledProtocols);

    // Heap-allocated stream that owns the TCP socket; deleted in the destructor.
    mSslSocket = new boost::asio::ssl::stream<boost::asio::ip::tcp::socket>(mIos, mCtx);
}

/**
 * Creates a channel that is not connected yet.
 *
 * @param isSsl true for a TLS connection made through Boost.Asio; false for a
 *              plain socket managed with the native socket API.
 */
BlockingChannel::BlockingChannel(bool isSsl) :
    mSslSocket(NULL),
    mCtx(boost::asio::ssl::context::sslv23),mReadySocket(false), mSsl(isSsl)
{
    // connect(), write(), close(), isOpen() and the destructor all read these
    // members, so start from the same "not connected" state the
    // BlockingChannel(int64_t) constructor uses instead of leaving them
    // indeterminate.
    mIsConnected = false;
    mIsClosing = true;

#ifdef __GNUC__
    socket_ = -1;
#else
    socket_ = INVALID_SOCKET;
#endif

    if (mSsl)
    {
        // The SSL code paths dereference mSslSocket unconditionally, so the
        // stream must exist. Configure it exactly like BlockingChannel(int64_t).
        mCtx.set_default_verify_paths();
        mCtx.set_options(boost::asio::ssl::context::no_sslv2 | boost::asio::ssl::context::no_sslv3);
        mSslSocket = new boost::asio::ssl::stream<boost::asio::ip::tcp::socket>(mIos, mCtx);
    }
}

#ifdef __GNUC__
/**
 * Wraps an already-connected native socket descriptor (plain mode, no TLS).
 *
 * connect(ip, port) returns immediately for such a channel because
 * mReadySocket is true. The channel takes ownership: the destructor closes
 * the descriptor via close().
 *
 * @param socket connected POSIX socket descriptor.
 */
BlockingChannel::BlockingChannel(int socket) :
    mSslSocket(NULL),
    mCtx(boost::asio::ssl::context::sslv23), mReadySocket(true), socket_(socket), mSsl(false)
{
    // Give the TLS-state flags defined values instead of leaving them
    // indeterminate (they are only consulted in SSL mode).
    mIsConnected = false;
    mIsClosing = true;
}
#else
/**
 * Wraps an already-connected Winsock socket (plain mode, no TLS).
 *
 * connect(ip, port) returns immediately for such a channel because
 * mReadySocket is true. The channel takes ownership: the destructor closes
 * the socket via close().
 *
 * @param socket connected Winsock socket handle.
 */
BlockingChannel::BlockingChannel(SOCKET socket) :
    mSslSocket(NULL),
    mCtx(boost::asio::ssl::context::sslv23), mReadySocket(true), socket_(socket), mSsl(false)
{
    // Give the TLS-state flags defined values instead of leaving them
    // indeterminate (they are only consulted in SSL mode).
    mIsConnected = false;
    mIsClosing = true;
}
#endif


/**
 * Releases the connection resources.
 *
 * Plain mode closes the native socket through close(). SSL mode deletes the
 * heap-allocated stream, which owns the underlying TCP socket.
 */
BlockingChannel::~BlockingChannel()
{
    if (!mSsl)
    {
        close();
        return;
    }

    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete mSslSocket;
}

void BlockingChannel::connect(const AgString &ip, unsigned short port)
{
    if (mSsl)
    {
        // connect socket
        try {
            boost::asio::ip::tcp::resolver resolver(mIos);
            boost::asio::ip::tcp::resolver::query query(ip.toUtf8(), AgString::fromInt(port).toUtf8());
            boost::asio::ip::tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);

            connect(endpoint_iterator);
        }
        catch(boost::system::system_error e) {
            throw IOException(AgString("Connect exception: ")/*+e.what()*/);
        }
    }
    else
    {
        if (mReadySocket)
            return;

#ifdef __GNUC__
        socket_ = 0;
        socket_ = socket(AF_INET , SOCK_STREAM , 0);
        if (socket_ == -1)
        {
            throw SocketException("Could not create socket!");
        }

        server_.sin_addr.s_addr = inet_addr(ip.toUtf8().c_str());
        server_.sin_family = AF_INET;
        server_.sin_port = htons(port);

        if (::connect(socket_, (struct sockaddr *)&server_, sizeof(server_)) < 0)
        {
            close();
            throw SocketException("Connect failed!");
        }
#else
        socket_ = INVALID_SOCKET;

        result = NULL;
        ptr = NULL;

        ZeroMemory(&hints, sizeof(hints));
        hints.ai_family = AF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_protocol = IPPROTO_TCP;

        char strPort[256];
        itoa(port, strPort, 10);
        int iResult = getaddrinfo(ip.toUtf8().c_str(), strPort, &hints, &result);
        if (iResult != 0)
        {
            LOG_PROTOCOL_DEBUG("Could not getaddrinfo!");
            throw SocketException("Could not getaddrinfo!");
        }

        for(ptr = result; ptr != NULL; ptr = ptr->ai_next)
        {
            socket_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);

            if (socket_ == INVALID_SOCKET)
            {
                freeaddrinfo(result);
                LOG_PROTOCOL_DEBUG("Could not create socket!");
                throw SocketException("Could not create socket!");
            }

            int iResult = ::connect(socket_, ptr->ai_addr, (int) ptr->ai_addrlen);
            if (iResult == SOCKET_ERROR)
            {
                close();
                LOG_PROTOCOL_DEBUG("Connect failed!");
                throw SocketException("Connect failed!");
            }
        }

        freeaddrinfo(result);
#endif
    }
}

/**
 * Establishes the TLS connection (SSL mode).
 *
 * Does nothing in any of these cases:
 * - the channel is already connected;
 * - it has not been closed since the last connection (mIsClosing == false);
 * - the endpoint range is empty.
 *
 * Otherwise it tries every resolved endpoint in order until a TCP connection
 * succeeds, sets the socket buffer sizes and performs the client-side TLS
 * handshake.
 *
 * @param endpoint first entry of a resolver result range.
 * @throws boost::system::system_error if no endpoint accepts the connection or
 *         the handshake fails. The TCP socket is closed before the exception
 *         propagates, so a later attempt does not fail with "already
 *         connected".
 */
void BlockingChannel::connect(boost::asio::ip::tcp::resolver::iterator& endpoint)
{
    if(mIsConnected) return;
    if(!mIsClosing) return;

    boost::asio::ip::tcp::resolver::iterator end;

    if (endpoint == end)
        return;

    mEndPoint = *endpoint;

    try
    {
        // Unlike a single connect(*endpoint), this falls back to the next
        // resolved address (e.g. IPv4 after a refused IPv6 attempt).
        boost::asio::ip::tcp::resolver::iterator connected =
            boost::asio::connect(mSslSocket->lowest_layer(), endpoint, end);
        mEndPoint = *connected;

        //mSslSocket->lowest_layer().set_option( boost::asio::ip::tcp::no_delay( true) );
        mSslSocket->lowest_layer().set_option( boost::asio::socket_base::send_buffer_size( 8192 ) );
        mSslSocket->lowest_layer().set_option( boost::asio::socket_base::receive_buffer_size( 8192 ) );

        // SECURITY NOTE(review): verify_none lets the handshake succeed
        // whatever certificate the server presents, so there is no protection
        // against a man-in-the-middle. Kept for compatibility; consider
        // verify_peer with a trusted CA where deployments allow it.
        mSslSocket->set_verify_mode(boost::asio::ssl::verify_none);
        mSslSocket->set_verify_callback(boost::bind(&BlockingChannel::verify_certificate, this, boost::lambda::_1, boost::lambda::_2));
        mSslSocket->handshake(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>::client);
    }
    catch (...)
    {
        // Do not leave a half-open socket behind: the next connect() would
        // otherwise fail because the socket is already connected.
        boost::system::error_code ignored;
        mSslSocket->lowest_layer().close(ignored);
        throw;
    }

    // we are connected!
    mIsConnected = true;
    mIsClosing = false;
}

/**
 * Closes the connection.
 *
 * In SSL mode it also stops the io_service and marks the channel as not
 * connected. mIsClosing is left true (as close() sets it), which matches the
 * state of a freshly constructed channel. connect(iterator) requires that
 * state, so the channel can be connected again.
 */
void BlockingChannel::disconnect()
{
    // tell socket to close the connection; in SSL mode this also sets mIsClosing
    close();

    if (mSsl)
    {
        // tell the IO service to stop
        mIos.stop();
        mIsConnected = false;
        // Previously mIsClosing was reset to false here. connect(iterator)
        // then returned early ("not closed yet") and never reconnected.
        mIsClosing = true;
    }
}

/**
 * OpenSSL verification callback installed by connect(iterator).
 *
 * OpenSSL calls it once per certificate in the presented chain. No checks are
 * added here: OpenSSL's own verdict is passed through unchanged.
 *
 * @param preverified result of OpenSSL's built-in check for the current certificate.
 * @param ctx         verification context (unused).
 * @return preverified, unchanged.
 */
bool BlockingChannel::verify_certificate(bool preverified, boost::asio::ssl::verify_context& ctx)
{
    // The subject name used to be formatted into a local buffer and then
    // discarded. That was dead work, and it dereferenced the current
    // certificate without a NULL check, so it is no longer computed.
    // Extra checks (e.g. RFC 2818 host-name matching) would go here.
    (void)ctx;
    return preverified;
}

/**
 * Reads from the connection into dst.
 *
 * SSL mode: blocks until AggreGateCommand::END_CHAR has been received. Returns
 * the number of bytes up to and including the delimiter; read_until may leave
 * extra bytes beyond it in dst. Returns -1 on any error.
 *
 * Plain mode: performs a single recv() of up to 10000 bytes and appends the
 * received bytes to dst. Returns:
 * - the byte count on success;
 * - 0 when nothing was read but the connection is still considered alive.
 *   On POSIX that means EAGAIN/EWOULDBLOCK (e.g. a receive timeout, if one is
 *   configured) or EINTR. On Windows it means WSAECONNABORTED.
 *
 * @throws SocketDisconnectionException (plain mode) if the peer closed the
 *         connection or recv() failed for any other reason.
 */
int BlockingChannel::read(boost::asio::streambuf &dst)
{
    if (mSsl)
    {
        boost::system::error_code ec = boost::asio::error::would_block;
        std::size_t bytes_transfered = 0;
        try {

            bytes_transfered = boost::asio::read_until(*mSslSocket, dst, AggreGateCommand::END_CHAR, ec);
            if (ec)
            {
                return -1;
            }
        } catch (boost::system::system_error& e) {
            UNUSED(e);
            return -1;
        }

        return bytes_transfered;
    }
    else
    {
        char tempBuf[10000];
        int bytes_read = recv(socket_, tempBuf, sizeof(tempBuf), 0);
        if (bytes_read > 0)
        {
            std::ostream bufStream(&dst);
            bufStream.write(tempBuf, bytes_read);
            return bytes_read;
        }

        if (bytes_read == 0)
        {
            // recv() returns 0 only after an orderly shutdown by the peer.
            // errno/WSAGetLastError() are not updated in that case, so they
            // must not be consulted.
            throw SocketDisconnectionException("Socket disconnected. Trying to connect again...");
        }

#ifdef __GNUC__
        // Must be '&&': with '||' the test is always true whenever EAGAIN and
        // EWOULDBLOCK differ, so transient errors would count as disconnects.
        if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != EINTR))
            throw SocketDisconnectionException("Socket disconnected. Trying to connect again...");
#else
        // NOTE(review): presumably WSAECONNABORTED means the socket was closed
        // locally, so it is not treated as a remote disconnect; confirm intent.
        if (WSAGetLastError() != WSAECONNABORTED)
            throw SocketDisconnectionException("Socket disconnected. Trying to connect again...");
#endif

        return 0;
    }
}

/**
 * Writes the whole of src to the connection.
 *
 * SSL mode: returns 0 if the channel is not connected or a write error occurs;
 * otherwise returns the number of bytes written.
 *
 * Plain mode: loops until every byte has been sent. Transient conditions are
 * retried: EAGAIN/EWOULDBLOCK/EINTR on POSIX.
 *
 * @return number of bytes written. In plain mode on Windows this may be short
 *         if the connection was aborted (WSAECONNABORTED).
 * @throws SocketDisconnectionException (plain mode) on any other send() failure.
 */
int BlockingChannel::write(std::vector<unsigned char>& src)// throws IOException
{
    if (mSsl)
    {
        if(!mIsConnected) return 0;
        if(mIsClosing) return 0;

        boost::system::error_code ec = boost::asio::error::would_block;
        std::size_t bytes_transfered = 0;
        try {

            while (bytes_transfered < src.size())
            {
                bytes_transfered += boost::asio::write(*mSslSocket, boost::asio::buffer(&src[bytes_transfered], src.size() - bytes_transfered), ec);
                if (ec)
                {
                    return 0;
                }
            }

        } catch (boost::system::system_error& e) {
            UNUSED(e);
        }

        return bytes_transfered;
    }
    else
    {
        // Writing to a connection the peer has closed would otherwise raise
        // SIGPIPE, whose default action terminates the process before the
        // disconnect exception below can be thrown. MSG_NOSIGNAL turns it into
        // an EPIPE error instead (where the platform provides the flag).
        int sendFlags = 0;
#ifdef MSG_NOSIGNAL
        sendFlags = MSG_NOSIGNAL;
#endif

        size_t bytes_transfered = 0;
        while (bytes_transfered < src.size())
        {
            int r = send(socket_, (const char *)&src[bytes_transfered], src.size() - bytes_transfered, sendFlags);
            if (r == -1)
            {
#ifdef __GNUC__
                // Must be '&&': with '||' the test is always true whenever
                // EAGAIN and EWOULDBLOCK differ. Transient errors are retried.
                if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != EINTR))
                    throw SocketDisconnectionException("Socket disconnected. Trying to connect again...");
#else
                if (WSAGetLastError() != WSAECONNABORTED)
                    throw SocketDisconnectionException("Socket disconnected. Trying to connect again...");
                // The connection was aborted: every further send() fails the
                // same way, so retrying would loop forever. Report what was sent.
                break;
#endif
            }
            else
            {
                bytes_transfered += r;
            }
        }

        return bytes_transfered;
    }
}

void BlockingChannel::close()
{
    if (mSsl)
    {
        if(mIsClosing) return;
        mIsClosing = true;

        try {
            mSslSocket->lowest_layer().close();
        }catch(boost::system::system_error /*e*/) {
            //TODO: throw exception
            //std::cout << e.code() << std::endl;
        }
    }
    else
    {
#ifdef __GNUC__
        ::close(socket_);
        socket_ = -1;
#else
        closesocket(socket_);
        socket_ = ~0;
#endif
    }
}

/**
 * Reports whether the channel currently holds a usable connection.
 *
 * Plain mode: true while the native socket handle is valid.
 * SSL mode: true once connect() has succeeded and close() has not been called
 * since.
 */
bool BlockingChannel::isOpen()
{
    if (!mSsl)
    {
#ifdef __GNUC__
        const bool hasHandle = (socket_ != -1);
#else
        const bool hasHandle = (socket_ != ~0);
#endif
        return hasHandle;
    }

    return mIsConnected && !mIsClosing;
}
