#include <vos/vutil/refcount.hh>
#include <vos/vutil/log.hh>
#include <vos/vutil/timer.hh>
#include <vos/vutil/rand.hh>

#include <vos/vip/stdprotocol.hh>

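// Original development plan: first make this work without worrying about
// in-flight packets and congestion control -- that is, send one packet,
// wait for the reply, then send the next packet, and so on.  The code
// below has since grown go-back-N retransmission on timeout and
// duplicate-ACK fast recovery, so treat this note as historical.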

using namespace VIP;
using namespace VUtil;

#define CHANNEL_HEADER_SIZE   6

#define VIP_SEQWIN  0xFFFF

StandardProtocol::StandardProtocol(Connection* c)
    : Protocol(c), haveSentEver(false)
{
    // Initialise the per-channel state.
    for(int ch = 0; ch < VIP_STDPROTO_CHANNELS; ++ch) {
        // Nothing received yet on this channel and no ACK owed to the peer.
        lastRecvSeqnum[ch] = 0;
        needAcks[ch] = 0;

        // Retransmission timer disarmed; RTO backoff multiplier and the
        // duplicate-ACK counter start at their reset values.
        retransTimer[ch] = 0;
        backoff[ch] = 1;
        ackCount[ch] = 1;

        // Outgoing sequence space starts at a value from rand_locked();
        // the last-acknowledged marker starts at the same value, so no
        // bytes count as sent-but-unacknowledged.
        maxSentSeqnum[ch] = rand_locked();
        lastRecvAcks[ch] = maxSentSeqnum[ch];
    }
}

StandardProtocol::~StandardProtocol()
{
    // No explicit cleanup is performed here.
}

int StandardProtocol::getProtocolNum()
{
    // The same tag is written as the first byte of every packet this
    // protocol builds (see the "c" field packed in makeSYN/getNextChunk).
    const int protocolNumber = VIP_STDPROTO;
    return protocolNumber;
}

void StandardProtocol::makeSYN(uint8_t* buf, unsigned int* bufsize)
{
    // Builds a SYN describing every channel into buf:
    //   1 byte   protocol number (VIP_STDPROTO)
    //   2 bytes  ack bitmap, all bits set
    //   4 bytes  per channel: last sequence number received (lastRecvSeqnum)
    //   2 bytes  sequence bitmap, all bits set
    //   4 bytes  per channel: current maxSentSeqnum (sequence number of the
    //            next byte packChannel will send)
    // On entry *bufsize is the capacity of buf; on return it is the SYN
    // length, or 0 if buf cannot hold a complete SYN.  The size check
    // covers the whole SYN so a small buffer is rejected outright rather
    // than producing a truncated one.
    const unsigned int synSize = 3 + (VIP_STDPROTO_CHANNELS * 4)
                               + 2 + (VIP_STDPROTO_CHANNELS * 4);
    if(*bufsize < synSize) {
        *bufsize = 0;
        return;
    }

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    unsigned int ptr = structpack(buf, *bufsize, "cs", (char)VIP_STDPROTO, 0xFFFF);
    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        ptr += structpack(buf + ptr, *bufsize - ptr, "l", lastRecvSeqnum[i]);
    }

    ptr += structpack(buf + ptr, *bufsize - ptr, "s", 0xFFFF);
    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        LOG("stdproto", 6, connection << ": packing seqnum " << maxSentSeqnum[i] << " into channel "
            << i << " at offset " << ptr << " " << *bufsize << " " << ptr);
        ptr += structpack(buf + ptr, *bufsize - ptr, "l", maxSentSeqnum[i]);
        // Dump the four bytes just written (previously byte -2 was never
        // shown and byte -3 was printed twice).
        LOG("stdproto", 6, connection << ": numbers are "
                << (int)*(buf + ptr - 4) << " "
                << (int)*(buf + ptr - 3) << " "
                << (int)*(buf + ptr - 2) << " "
                << (int)*(buf + ptr - 1));
    }
    LOG("stdproto", 4, connection << ": set up StdProto SYN of len " << ptr);

    *bufsize = ptr;
}

// Parses the acknowledgement section at the start of buf: a protocol byte,
// a 16-bit bitmap of channels that carry an ack, then one 32-bit cumulative
// ack for each set bit (lowest channel first).
//
// syn == false: an ack newer than lastRecvAcks[i] reports the newly
//   acknowledged bytes to the connection, removes them from the channel's
//   send queue and advances lastRecvAcks[i]; repeated old acks are counted
//   and eventually trigger connection->fastRecovery().
// syn == true: every ack present only resets the channel's retransmission
//   timer, backoff and duplicate-ACK counter; sequence state is untouched.
//
// Returns the offset just past the ack section, or bufsize if an ack entry
// is truncated.
unsigned int StandardProtocol::readAcks(uint8_t* buf, unsigned int bufsize, bool syn)
{
    uint8_t proto = 0;
    uint16_t ackbits = 0;

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    LOG("stdproto", 6, connection << ": in readAcks! " << bufsize);

    // 1-byte protocol number + 16-bit ack bitmap.
    unsigned int ptr = structunpack(buf, bufsize, "cs", &proto, &ackbits);

    // Report the 3 header bytes to the connection (second argument false).
    // NOTE(review): presumably byte accounting for congestion control --
    // confirm against Connection::ackBytes.
    if(! syn && haveSentEver) connection->ackBytes(3, false);

    for(int i = 0; ptr < bufsize && i < VIP_STDPROTO_CHANNELS; i++) {
        if(ackbits & (1 << i)) {
            uint32_t ack = 0;
            // Truncated ack entry: consume the rest of the buffer and stop.
            if((bufsize - ptr) < 4) {
                ptr = bufsize;
                return ptr;
            }
            ptr += structunpack(buf + ptr, bufsize - ptr, "l", &ack);
            LOG("stdproto", 5, connection << ": stdproto channel " << i << " got ack " << ack);
            if(VIP_SEQNUMGT(ack, lastRecvAcks[i]) || syn) {
                if(! syn) {
                    // New cumulative ack: everything between the previous ack
                    // and this one has been received by the peer.
                    connection->ackBytes(CHANNEL_HEADER_SIZE + VIP_SEQNUMDIFF(ack, lastRecvAcks[i]), true);
                    rmFromQueue(i, VIP_SEQNUMDIFF(ack, lastRecvAcks[i]));
                    lastRecvAcks[i] = ack;
                    if(VIP_SEQNUMGT(lastRecvAcks[i], maxSentSeqnum[i])) {
                        // this happens in the case where we've timed
                        // out and started the "go back N" recovery
                        // algorithm but in fact we timed out too
                        // early and are still getting ACKs back from
                        // the previous send attempt
                        maxSentSeqnum[i] = lastRecvAcks[i];
                    }
                }
                // Progress (or SYN): restart the retransmission timer from
                // now and reset backoff and the duplicate-ACK counter.
                retransTimer[i] = getTimer();
                backoff[i] = 1;
                ackCount[i] = 1;
            } else if(VIP_SEQNUMLTE(ack, lastRecvAcks[i])) {
                // Presumably always true when reached (the complement of the
                // GT test above).  ackCount counts the original ack plus its
                // duplicates, so reaching 3 means the second duplicate.
                LOG("stdproto", 4, connection << ": duplicate ACK! Count is " << (ackCount[i]+1));
                if(++ackCount[i] == 3) {
                    // Treat all outstanding bytes as acknowledged for the
                    // connection's accounting and enter fast recovery.
                    connection->ackBytes(CHANNEL_HEADER_SIZE
                                         + (VIP_SEQNUMDIFF(maxSentSeqnum[i], lastRecvAcks[i])), true);
                    connection->fastRecovery();
                    // A timer stamp of 1 makes the timeout test in
                    // getNextChunk succeed on its next pass (assuming
                    // getTimer() is well past 1 + RTO * backoff), which
                    // triggers the go-back-N resend.
                    retransTimer[i] = 1;
                    ackCount[i] = 0;
                }
            }
        }
    }

    return ptr;
}

bool StandardProtocol::readSYNSeqNums(uint8_t* buf, unsigned int bufsize,
                                      unsigned int* ptr, bool setup)
{
    // Parses the sequence-number half of a SYN starting at offset *ptr:
    // a 16-bit bitmap followed by one 32-bit sequence number for each set
    // bit (lowest channel first).  *ptr is advanced past what was consumed.
    //
    // setup == true:  the peer's numbers are adopted into lastRecvSeqnum.
    // setup == false: each number must equal lastRecvSeqnum for its channel;
    //                 a mismatch makes this return false.
    //
    // A short or truncated buffer just ends parsing (returns true), matching
    // how readAcks treats a truncated ack section.
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    LOG("stdproto", 7, connection << ": ptr is " << *ptr);

    // Need room for the 16-bit bitmap.  All sizes are relative to *ptr so
    // structunpack can never read past the end of buf.
    if(*ptr > bufsize || (bufsize - *ptr) < 2) {
        return true;
    }

    uint16_t seqbits = 0;
    *ptr += structunpack(buf + *ptr, bufsize - *ptr, "s", &seqbits);

    LOG("stdproto", 7, connection << ": ptr is now " << *ptr << " and seqbits is " << seqbits);

#ifdef WTF
    LOG("stdproto", 1, "buffer is: ");
    for(unsigned int i = 0; i < bufsize; i++) {
        fprintf(stderr, "%x", buf[i]);
    }
#endif

    for(int i = 0; *ptr < bufsize && i < VIP_STDPROTO_CHANNELS; i++) {
        if(!(seqbits & (1 << i))) continue;

        if((bufsize - *ptr) < 4) {
            // Truncated entry: consume the rest of the buffer and stop.
            *ptr = bufsize;
            break;
        }

        uint32_t seq = 0;
        *ptr += structunpack(buf + *ptr, bufsize - *ptr, "l", &seq);
        LOG("stdproto", 5, connection << ": stdproto channel " << i << " got seq " << seq
            << " " << *ptr << " " << bufsize);

        if(!setup && lastRecvSeqnum[i] != seq) {
            LOG("stdproto", 2, connection << ": not in setup and syn on channel "
                << i << " was " << seq << ", expecting " << lastRecvSeqnum[i]);
            return false;
        }

        lastRecvSeqnum[i] = seq;
    }
    return true;
}

bool StandardProtocol::replySYN(uint8_t* received, unsigned int recvsz,
                                uint8_t* reply, unsigned int* replysz)
{
    // Validates a received SYN against our current receive state and, if it
    // matches, writes a reply SYN into `reply`.  The reply layout is: protocol
    // byte, all-ones ack bitmap, one 32-bit lastRecvSeqnum per channel,
    // all-ones sequence bitmap, one 32-bit maxSentSeqnum per channel.
    //
    // On entry *replysz is the capacity of reply.  On success it becomes the
    // reply length and true is returned.  On failure (buffer too small, or
    // the peer's sequence numbers do not match ours) it is set to 0 and
    // false is returned, so nothing unwritten can be sent by mistake.
    LOG("stdproto", 6, connection << ": in replySYN " << recvsz << " " << *replysz);

    // Room for the whole reply, not just its first half.
    const unsigned int synSize = 3 + (VIP_STDPROTO_CHANNELS * 4)
                               + 2 + (VIP_STDPROTO_CHANNELS * 4);
    if(*replysz < synSize) {
        *replysz = 0;
        return false;
    }

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    unsigned int readptr = readAcks(received, recvsz, true);
    if(!readSYNSeqNums(received, recvsz, &readptr, false)) {
        *replysz = 0;
        return false;
    }

    unsigned int ptr = structpack(reply, *replysz, "cs", (char)VIP_STDPROTO, 0xFFFF);
    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        ptr += structpack(reply + ptr, *replysz - ptr, "l", lastRecvSeqnum[i]);
    }

    // Remaining space is *replysz - ptr; passing *replysz here would let
    // structpack believe it had room beyond the end of `reply`.
    ptr += structpack(reply + ptr, *replysz - ptr, "s", 0xFFFF);
    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        ptr += structpack(reply + ptr, *replysz - ptr, "l", maxSentSeqnum[i]);
    }

    *replysz = ptr;

    return true;
}

void StandardProtocol::setupWithSYN(uint8_t* received, unsigned int recvsz)
{
    // Initialises receive state from a peer's SYN: the ack section resets
    // per-channel timers (readAcks in SYN mode), and the sequence section's
    // numbers are adopted as lastRecvSeqnum (readSYNSeqNums in setup mode).
    LOG("stdproto", 6, connection << ": in setupWithSYN " << recvsz);

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    // Ignore input too short to hold the header and every per-channel ack.
    const bool longEnough = recvsz >= (3 + (VIP_STDPROTO_CHANNELS * 4));
    if(longEnough) {
        unsigned int offset = readAcks(received, recvsz, true);
        LOG("stdproto", 7, connection << ": reading at offset " << offset);
        // In setup mode readSYNSeqNums never reports a mismatch, so its
        // result is not checked.
        readSYNSeqNums(received, recvsz, &offset, true);
    }
}

int StandardProtocol::queueData(Message* m)
{
    // Appends m to the send queue of its channel.  Messages for a channel
    // number this protocol does not have are ignored.  Always returns 0.
    if(m->channel >= VIP_STDPROTO_CHANNELS) {
        return 0;
    }

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    data[m->channel].push_back(vRef<Message>(m, true));
    LOG("stdproto", 4, connection << ": queued msgid " << m->msgid
        << " on channel " << (int)m->channel
        << " queue size has " << data[m->channel].size()
        << " chunks with " << queuedBytes(m->channel) << " bytes");

    return 0;
}

// Parses one incoming packet:
//   - the acknowledgement section (via readAcks);
//   - a 16-bit bitmap of channels carrying data;
//   - for each set bit, lowest channel first: a chunk header (32-bit
//     sequence number, 16-bit payload length) followed by the payload.
// Payload that continues exactly where this channel left off is wrapped in
// a new Message and handed to connection->doCallback().
// On entry *bufsize is the packet length; on return it is the number of
// bytes consumed.
void StandardProtocol::handleChunk(uint8_t* buf, unsigned int* bufsize)
{
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    LOG("stdproto", 5, connection << ": entered handleChunk " << *bufsize);

    assert(*bufsize < 0xFFFFF000);

    unsigned int ptr = readAcks(buf, *bufsize, false);

    uint16_t channelbits = 0;
    ptr += structunpack(buf + ptr, *bufsize - ptr, "s", &channelbits);

    // Report the 2-byte channel bitmap to the connection (cf. the 3 header
    // bytes reported in readAcks).
    if(haveSentEver) connection->ackBytes(2, false);

    LOG("stdproto", 5, connection << ": channelbits is " << channelbits);

    for(int i = 0; ptr < *bufsize && i < VIP_STDPROTO_CHANNELS; i++) {
        if(channelbits & (1 << i)) {
            uint32_t seqNum = 0;
            uint16_t size = 0;

            ptr += structunpack(buf + ptr, *bufsize - ptr, "ls", &seqNum, &size);

            LOG("stdproto", 4, connection << ": on channel " << i << " got " << size
                << " bytes with seqnum " << seqNum);

            // Clamp a payload length that runs past the end of the packet.
            if((ptr + size) > *bufsize) size = *bufsize - ptr;
            ptr += size;

            // drop the channel chunk if it's not the next one we were
            // expecting...  a terribly lousy policy which will
            // obviously need to be fixed later :-)
            // Accepted only if seqNum <= lastRecvSeqnum[i] < seqNum + size,
            // i.e. the chunk covers the next byte we need.  Gaps and fully
            // duplicated data are dropped, and needAcks is bumped so that
            // getNextChunk re-sends our current ack (one per count).
            if(VIP_SEQNUMLT(lastRecvSeqnum[i], seqNum) || VIP_SEQNUMGTE(lastRecvSeqnum[i], (seqNum + size))) {
                LOG("stdproto", 4, connection << ": dropping " << seqNum << ", lastRecvSeqnum is "
                    << lastRecvSeqnum[i] << " which is not in the range "
                    << seqNum << " - " << seqNum+size);

                // NOTE(review): sequence spaces start at rand_locked()
                // values and are compared with the VIP_SEQNUM* macros
                // (presumably wrap-aware), so 0 looks like a legitimate
                // sequence number; this assert may fire on valid traffic
                // in debug builds -- confirm intent.
                assert(seqNum != 0);

                needAcks[i]++;
                continue;
            }

            needAcks[i] = 1;

            vRef<Message> m(new Message, false);

            m->connection.assign(connection, true);
            m->ToS = VIP_STDPROTO;
            m->channel = i;
            m->msgid = 0;

            // Deliver only bytes not yet delivered: from lastRecvSeqnum[i]
            // to the end of this chunk (any leading overlap is skipped).
            unsigned int diff = (seqNum + size) - lastRecvSeqnum[i];

            //LOG("stdproto", 2, "diff " << seqNum << " " << size << " " << lastRecvSeqnum[i]);

            m->payload.assign((char*)(buf + (ptr - diff)), diff);

            lastRecvSeqnum[i] = seqNum + size;

            LOG("stdproto", 4, connection << ": passing message size " << m->payload.size()
                << " to queue or callback");

            connection->doCallback(m);
        }
    }

    // NOTE(review): despite its name, haveSentEver is set here on the
    // receive path; it gates the header-byte ackBytes() calls above and in
    // readAcks.
    haveSentEver = true;

    *bufsize = ptr;
}

void StandardProtocol::packChannel(int channel, uint8_t* buf, unsigned int* bufsize)
{
    unsigned int ptr = 6;

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    unsigned int sent = VIP_SEQNUMDIFF(maxSentSeqnum[channel], lastRecvAcks[channel]);

    for(unsigned int i = 0; i < data[channel].size() && ptr < *bufsize; i++) {
        if(sent > 0) {
            if(data[channel][i]->payload.size() <= sent) {
                sent -= data[channel][i]->payload.size();
            } else {
                LOG("stdproto", 6, connection << ": channel " << channel << " sent " << sent
                    << " bufsize " << *bufsize
                    << " payload size " << data[channel][i]->payload.size()
                    << " ptr " << ptr);

                unsigned int send = data[channel][i]->payload.size() - sent;
                if(send > (*bufsize - ptr)) send = (*bufsize - ptr);

                memcpy(buf + ptr,
                       data[channel][i]->payload.data() + sent,
                       send);
                ptr += send;
                sent = 0;
            }
        } else {
            if(data[channel][i]->payload.size() <= (*bufsize - ptr)) {
                memcpy(buf + ptr,
                       data[channel][i]->payload.data(),
                       data[channel][i]->payload.size());
                ptr += data[channel][i]->payload.size();
            } else {
                memcpy(buf + ptr, data[channel][i]->payload.data(), *bufsize - ptr);
                ptr += (*bufsize - ptr);
            }
        }
    }

    if(ptr > 6) {
        structpack(buf, 6, "ls", maxSentSeqnum[channel], ptr - 6);

        maxSentSeqnum[channel] += ptr - 6;
        *bufsize = ptr;
    } else {
        *bufsize = 0;
    }
}

// Builds the next outgoing packet into buf:
//   - 1-byte protocol number and a 16-bit bitmap of channels we ack;
//   - one 32-bit ack (lastRecvSeqnum) per channel whose needAcks is nonzero;
//   - a 16-bit bitmap of channels carrying data;
//   - one chunk per selected channel, written by packChannel, channel 0
//     first and then in ascending channel order.
// ackOnly skips data selection entirely (no timeouts checked, no chunks).
// On entry *bufsize is the buffer capacity; on return it is the packet
// length, or 0 if there is neither data nor an ack to send.
void StandardProtocol::getNextChunk(uint8_t* buf, unsigned int* bufsize, bool ackOnly)
{
    // Channels that will carry data in this packet, in ascending order.
    std::deque<int> more;

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    assert(*bufsize < 0xFFFFF000);

    for(int i = 0; i < VIP_STDPROTO_CHANNELS && !ackOnly; i++) {
        // Retransmission timeout: the timer is armed, RTO * backoff has
        // elapsed, and unacknowledged bytes are outstanding.  Rewind to the
        // last ack and resend from there (go-back-N), doubling the backoff.
        if(retransTimer[i] > 0
           && (retransTimer[i] + (connection->getRTO() * backoff[i])) < getTimer()
           && (VIP_SEQNUMDIFF(maxSentSeqnum[i], lastRecvAcks[i])) > 0)
        {
            LOG("stdproto", 4, connection << ": (timeout) going back "
                << (VIP_SEQNUMDIFF(maxSentSeqnum[i], lastRecvAcks[i]))
                << " bytes on channel " << i << " acks were "
                << lastRecvAcks[i] << " " << maxSentSeqnum[i]);
            connection->notifyLoss();

            maxSentSeqnum[i] = lastRecvAcks[i];

            more.push_back(i);
            backoff[i] *= 2;
        } else {
            // send more data
            if(pendingData(i) > 0) {
                LOG("stdproto", 5, connection << ": sending on channel "
                    << i << " acks are " << lastRecvAcks[i]
                    << " " << maxSentSeqnum[i]);
                more.push_back(i);
            } else {
                // Nothing unsent; if nothing is queued at all, disarm the
                // timer and reset the backoff.
                if(data[i].empty()) {
                    retransTimer[i] = 0;
                    backoff[i] = 1;
                }
            }
        }
    }

    LOG("stdproto", 5, connection << ": sending on " << more.size() << " channels");

    uint16_t acks = 0;

    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        if(needAcks[i]) acks |= (1 << i);
    }

    if(more.empty() && acks == 0) {
        *bufsize = 0;
        return;
    }

    LOG("stdproto", 6, connection << ": packed acks is " << acks << " more is " << more.empty()
        << " pending is " << pendingData(0));

    assert(*bufsize < 0xFFFFF000);

    unsigned int ptr = structpack(buf, *bufsize, "cs", VIP_STDPROTO, acks);
    assert(*bufsize >= ptr);

    // One ack per flagged channel; each packet consumes one unit of needAcks.
    for(int i = 0; i < VIP_STDPROTO_CHANNELS; i++) {
        if(needAcks[i]) {
            LOG("stdproto", 4, connection << ": stdproto sending ack "
                << lastRecvSeqnum[i] << " on channel "
                << i << " needAcks is " << needAcks[i]);
            ptr += structpack(buf + ptr, *bufsize - ptr, "l", lastRecvSeqnum[i]);
            needAcks[i]--;
        }
    }

    // NOTE(review): this bitmap is written before any data is packed.  If a
    // channel is later dropped from `more` (space loop below) or packChannel
    // returns 0 for it, the bitmap still advertises it.  Trailing omissions
    // are harmless because handleChunk stops at the end of the packet, but a
    // missing chunk in the middle would make the receiver read the next
    // channel's chunk as this one's -- confirm this cannot happen.
    unsigned int channels = 0;
    for(unsigned int i = 0; i < more.size(); i++) {
        channels |= (1 << more[i]);
    }

    ptr += structpack(buf + ptr, *bufsize - ptr, "s", channels);
    assert(*bufsize >= ptr);

    double now = getTimer();

    // Channel 0 is packed first with priority: all remaining space if it is
    // the only data channel, otherwise half of it.
    if(more.size() > 0 && more[0] == 0) {
        unsigned int csz;
        if(more.size() == 1) csz = *bufsize - ptr;
        else csz = (*bufsize - ptr) / 2;
        packChannel(0, buf + ptr, &csz);
        ptr += csz;

        more.pop_front();

        // (Re)start the timer if it was idle, or after a timeout (backoff > 1).
        if(retransTimer[0] == 0 || backoff[0] > 1) retransTimer[0] = now;
    }

    assert(*bufsize >= ptr);

    // Per-channel byte budgets (header included) for the remaining channels.
    unsigned int allocs[VIP_STDPROTO_CHANNELS];
    bool active[VIP_STDPROTO_CHANNELS];
    unsigned int freespace = (*bufsize - ptr);

    LOG("stdproto", 5, connection << ": freespace is " << freespace);

    memset(allocs, 0, sizeof(allocs));
    memset(active, 0, sizeof(active));

#define PREALLOC     (CHANNEL_HEADER_SIZE+1)    // 6 byte header + 1 byte data

    // Drop the highest-numbered channels until each survivor can get at
    // least a header plus one data byte.
    while(freespace < (more.size() * PREALLOC)) {
        more.pop_back();
    }

    unsigned int numactive = more.size();
    unsigned int total = (numactive * PREALLOC);

    for(unsigned int i = 0; i < more.size(); i++) {
        active[more[i]] = true;
        allocs[more[i]] = PREALLOC;
    }

    // Share the remaining space evenly, round by round.  A channel becomes
    // inactive once its budget covers all its pending data plus header, and
    // its unused share is redistributed in the next round.
    while(numactive > 0 && total < freespace) {
        unsigned int loopchansize = (freespace - total) / numactive;
        if(loopchansize == 0) loopchansize = 1;

        LOG("stdproto", 5, connection << ": loop chansize is " << loopchansize);

        for(unsigned int i = 0; i < more.size() && total < freespace; i++) {
            if(active[more[i]]) {
                unsigned int csz = loopchansize;

                LOG("stdproto", 5, connection << ": csz for channel " << more[i] << " is " << csz
                    << " active is " << numactive);

                LOG("stdproto", 5, connection << ": pending is " << pendingData(more[i])
                    << " allocated so far is " << allocs[more[i]]);

                // Additional bytes this channel could still use.
                // NOTE(review): unsigned arithmetic -- if pendingData() plus
                // the header is smaller than the current budget (e.g.
                // pendingData() == 0), this wraps to a huge value and the
                // channel is never deactivated; confirm that cannot occur
                // for channels in `more`.
                unsigned int dif = (pendingData(more[i])+CHANNEL_HEADER_SIZE) - allocs[more[i]];
                if(dif <= csz) {
                    csz = dif;
                    active[more[i]] = false;
                    numactive--;
                }

                LOG("stdproto", 5, connection << ": csz for channel " << more[i] << " is " << csz
                    << " active is " << numactive);
                allocs[more[i]] += csz;
                total += csz;
                LOG("stdproto", 5, connection << ": total is " << total);
            }
        }
    }

    /*
      for(unsigned int i = 0; i < more.size(); i++) {
      LOG("stdproto", 4, "allocated " << allocs[more[i]] << " bytes for channel " << more[i]);
      }
    */

    // Pack each channel into its budget; packChannel replaces the budget
    // with the number of bytes actually written.
    for(unsigned int i = 0; i < more.size(); i++) {
        packChannel(more[i], buf + ptr, &allocs[more[i]]);
        ptr += allocs[more[i]];

        assert(ptr <= *bufsize);

        if(retransTimer[more[i]] == 0 || backoff[more[i]] > 1) retransTimer[more[i]] = now;
    }

    *bufsize = ptr;
}

void StandardProtocol::rmFromQueue(int channel, uint32_t bytes)
{
    LOG("stdproto", 5, connection << ": removing " << bytes << " bytes from the queue of channel " << channel);
    LOG("stdproto", 5, connection << ": pending data is " << pendingData(channel));

    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    // Discard `bytes` acknowledged bytes from the front of the channel
    // queue.  Each message consumed in full has its completion callback (if
    // any) notified and is popped.  A message that is only partly covered
    // has the covered prefix trimmed off in place.
    uint32_t remaining = bytes;
    while(remaining != 0 && !data[channel].empty()) {
        const bool wholeMessageAcked = data[channel].front()->payload.size() <= remaining;
        if(!wholeMessageAcked) {
            data[channel].front()->payload.replace(0, remaining, "");
            remaining = 0;
            break;
        }

        remaining -= data[channel].front()->payload.size();
        if(data[channel].front()->completedCallback) {
            data[channel].front()->completedCallback->notifySendingCompleted(data[channel].front());
        }
        data[channel].pop_front();
    }

    // Queue drained: nothing left to retransmit, so disarm the timer.
    if(data[channel].empty()) {
        retransTimer[channel] = 0;
    }

    LOG("stdproto", 5, connection << ": pending data is now " << pendingData(channel));
}


unsigned int StandardProtocol::pendingData(int channel)
{
    // Number of queued bytes on `channel` that have not been transmitted
    // yet: total queued payload minus the bytes in flight (sent, unacked).
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    unsigned int inFlight = VIP_SEQNUMDIFF(maxSentSeqnum[channel], lastRecvAcks[channel]);

    // Guard against a wrapped (negative) difference.
    assert(VIP_SEQNUMDIFF(maxSentSeqnum[channel], lastRecvAcks[channel]) < 0xFFFFF000);

    LOG("stdproto", 5, connection << ": sent for channel " << channel
        << " is " << VIP_SEQNUMDIFF(maxSentSeqnum[channel], lastRecvAcks[channel])
        << " with maxSentSeqnum: " << maxSentSeqnum[channel] << " lastRecvAck: " << lastRecvAcks[channel]);

    unsigned int queued = 0;
    for(unsigned int idx = 0; idx != data[channel].size(); ++idx) {
        queued += data[channel][idx]->payload.size();
    }

    return queued - inFlight;
}

double StandardProtocol::desiredWaitTime()
{
    // How long the caller may wait before this protocol has work to do.
    // Returns 0 if any channel has unsent data, owes an ACK, or has an
    // expired retransmission timer.  Otherwise it returns the smallest
    // remaining timeout among channels with an armed timer and queued data,
    // capped at 1000000.
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    double shortest = 1000000;

    for(int ch = 0; ch < VIP_STDPROTO_CHANNELS; ++ch) {
        const double remaining = ((retransTimer[ch] + (connection->getRTO() * backoff[ch])) - getTimer());
        const bool timerArmed = retransTimer[ch] > 0 && !data[ch].empty();

        if(timerArmed && remaining < 0) {
            // NOTE(review): notifyLoss() is also called from getNextChunk
            // when it handles this same timeout; confirm the connection
            // tolerates being notified more than once per loss.
            connection->notifyLoss();
            return 0;
        }

        // Unsent data or an outstanding ACK means there is work right now.
        if(pendingData(ch) > 0 || needAcks[ch]) {
            return 0;
        }

        if(timerArmed && remaining < shortest) {
            shortest = remaining;
        }
    }

    return shortest;
}

bool StandardProtocol::hasQueuedData()
{
    // True if any channel's send queue is non-empty.  Queued messages are
    // only removed once acknowledged, so this means unacked data remains.
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    int ch = 0;
    while(ch < VIP_STDPROTO_CHANNELS && data[ch].empty()) {
        ++ch;
    }

    if(ch >= VIP_STDPROTO_CHANNELS) {
        return false;
    }

    LOG("stdproto", 5, connection << ": channel " << ch << " still has unacked data");
    return true;
}


unsigned int StandardProtocol::queuedBytes(int channel)
{
    // Total payload size of every message in the channel's queue, whether
    // or not it has been transmitted yet.
    boost::recursive_mutex::scoped_lock lk(stdproto_mutex);

    unsigned int total = 0;
    unsigned int k = 0;
    while(k < data[channel].size()) {
        total += data[channel][k]->payload.size();
        ++k;
    }
    return total;
}
