#include <vos/vutil/timer.hh>
#include <vos/vutil/log.hh>
#include <vos/vip/lowlatencyproto.hh>
#include <vos/vutil/rand.hh>

using namespace VUtil;
using namespace VIP;

// Sequence window constant.  NOTE(review): not referenced anywhere in this
// file.  0xFF equals the largest value of the one-byte per-message sequence
// offset that getNextChunk() packs as uint8_t, so it was presumably meant to
// bound how far past startSeqNum a packed message may be -- confirm before
// relying on it.
#define VIP_SEQWIN  0xFF

/* Construct the protocol state for connection `c`.  Our outgoing sequence
   numbers start at a random value from rand_locked(); nothing has been
   acknowledged yet, so lastAckNum starts equal to that value, and we have
   not received anything (lastRecvNum = 0) until a SYN sets it. */
LowLatencyProto::LowLatencyProto(Connection* c)
    : Protocol(c),
      sentSeqCounter(rand_locked()),
      msgidCounter(0),
      heartbeatTime(.1),
      lastHeartbeat(0),
      needSendAck(false),
      haveSentEver(false)
{
    // Assigned in the body rather than the initializer list: lastAckNum
    // reads sentSeqCounter, and member declaration order is not visible here.
    lastRecvNum = 0;
    lastAckNum = sentSeqCounter;
}

/* Append message `m` to the outgoing queue under a fresh sequence number.

   A message with msgid 0 is new and receives the next msgid.  A non-zero
   msgid replaces any queued message with the same msgid: the old entry is
   removed, and its bytesSent count is carried over to the new entry.

   Returns the (possibly newly assigned) msgid. */
int LowLatencyProto::queueData(Message* m)
{
    boost::mutex::scoped_lock guard(llp_mutex);

    unsigned int carriedBytes = 0;

    if(m->msgid == 0) {
        m->msgid = ++msgidCounter;
    } else {
        for(unsigned int idx = 0; idx < outgoingqueue.size(); ++idx) {
            if(outgoingqueue[idx].m->msgid != m->msgid) continue;

            carriedBytes = outgoingqueue[idx].bytesSent;
            outgoingqueue.erase(outgoingqueue.begin() + idx);
            break;
        }
    }

    // lastSentTime == 0 marks the entry as never sent, so getNextChunk()
    // picks it up on its next pass.
    OutgoingMsg entry;
    entry.seqnum = ++sentSeqCounter;
    entry.m.assign(m, true);
    entry.bytesSent = carriedBytes;
    entry.lastSentTime = 0;
    outgoingqueue.push_back(entry);

    LOG("LLP", 4, "Queued msg seqnum " << sentSeqCounter << " msgid "
        << m->msgid << " queue size is " << outgoingqueue.size());

    return m->msgid;
}

/* Parse one incoming LLP chunk from `buf`.

   Layout: a 14-byte header "clllc" (proto, acknum, startseqnum, endseqnum,
   nummsgs), followed by up to nummsgs messages.  Each message has a 6-byte
   header "clc" (seqoff, msgid, msgsize) followed by msgsize payload bytes.

   The peer's acknum releases every queued outgoing message up to and
   including it.  New messages are delivered via connection->doCallback().
   On normal completion, *sz is set to the number of bytes consumed, and is
   never more than the value passed in.

   NOTE(review): the early returns (short header, sequence gap) leave *sz
   unchanged; confirm callers treat that as "whole buffer consumed". */
void LowLatencyProto::handleChunk(uint8_t* buf, unsigned int* sz)
{
    boost::mutex::scoped_lock lk(llp_mutex);

    uint8_t proto, nummsgs;
    uint32_t acknum;
    uint32_t startseqnum;
    uint32_t endseqnum;

    LOG("LLP", 5, "Entered handleChunk with " << *sz << " bytes to handle");

#if 0
    for(int i = 0; i < *sz; i++) {
        printf("%2.2x", (int)buf[i]);
    }
    printf("\n");
#endif

    // Too short to even hold the chunk header.
    if(*sz < 14) return;

    unsigned int ptr = structunpack(buf, *sz, "clllc", &proto, &acknum,
                                    &startseqnum, &endseqnum, &nummsgs);

    LOG("LLP", 4, "handleChunk header is proto: " << (int)proto
        << " acknum: " << acknum
        << " start num: " << startseqnum
        << " end num: " << endseqnum
        << " num msg: " << (int)nummsgs);

    uint32_t oldLastRecvNum;
    {
        lastAckNum = acknum;

        LOG("LLP", 5, "lastAckNum is " << lastAckNum);

        unsigned int acked_payload = 0;

        // Pop every queued message the peer has now acknowledged, and notify
        // its completion callback.
        while(!outgoingqueue.empty()
              && VIP_SEQNUMLTE(outgoingqueue.front().seqnum, lastAckNum))
        {
            LOG("LLP", 4, "removing ack'd head of queue, seqnum "
                << outgoingqueue.front().seqnum
                << " lastAckNum is " << lastAckNum);

            acked_payload += outgoingqueue.front().bytesSent;
            if(outgoingqueue.front().m->completedCallback) {
                outgoingqueue.front().m->completedCallback->notifySendingCompleted(outgoingqueue.front().m);
            }
            outgoingqueue.pop_front();
        }

        if(haveSentEver) connection->ackBytes(14 + acked_payload, false);

        if(nummsgs > 0) needSendAck = true;

        // if the last seq num we got is less than the this starting seq
        // num - 1, then drop the packet because there is a gap (should
        // maybe handle this better in the future)
        if(VIP_SEQNUMLT(lastRecvNum, startseqnum-1)) return;

        // Otherwise pick up the new sequence numbers...
        oldLastRecvNum = lastRecvNum;
        if(VIP_SEQNUMLT(lastRecvNum, endseqnum)) lastRecvNum = endseqnum;
    }

    for(unsigned int msgnum = 0; ptr < *sz && msgnum < nummsgs; msgnum++) {
        // Every message starts with a 6-byte header.  If fewer bytes remain,
        // the chunk is truncated: swallow the leftover bytes and stop, rather
        // than unpacking past the end of the buffer.
        if(*sz - ptr < 6) {
            ptr = *sz;
            break;
        }

        uint8_t seqoff, msgsize;
        uint32_t msgid;

        ptr += structunpack(buf+ptr, *sz - ptr,
                            "clc", &seqoff, &msgid, &msgsize);

        // Clamp a truncated payload to what is actually present.  This must
        // happen before the "already seen" skip below too; otherwise ptr can
        // run past *sz and we would report consuming more than we were given.
        if(msgsize > (*sz-ptr)) msgsize = (*sz-ptr);

        LOG("LLP", 4, "looking at message with seqnum " << (startseqnum + seqoff)
            << ", our oldLastRecvNum is " << oldLastRecvNum
            << ", msgid is " << msgid);

        if(VIP_SEQNUMLTE(startseqnum + seqoff, oldLastRecvNum))
        {
            ptr += msgsize;
            LOG("LLP", 5, "Not delivering seqnum " << (startseqnum + seqoff)
                << ", already seen it");

            continue;
        }

        vRef<Message> msg(new Message, false);
        msg->connection.assign(connection, true);
        msg->msgid = msgid;
        msg->ToS = VIP_LOWLATENCY;
        msg->payload.assign((char*)(buf+ptr), (int)msgsize);

        LOG("LLP", 5, "delivering " << msgid << " " << (startseqnum + seqoff));

        connection->doCallback(msg);

        ptr += msgsize;
    }

    *sz = ptr;
}

/* Build our SYN into `outgoing`.  It is a bare "clllc" LLP header with no
   messages: lastRecvNum goes in the ack field, and sentSeqCounter is used as
   both the start and end sequence number.  On entry *sz is the capacity of
   `outgoing`; on return it holds the value reported by structpack(). */
void LowLatencyProto::makeSYN(uint8_t* outgoing, unsigned int* sz)
{
    boost::mutex::scoped_lock guard(llp_mutex);

    const unsigned int capacity = *sz;
    const unsigned int packed = structpack(outgoing, capacity, "clllc",
                                           VIP_LOWLATENCY,
                                           lastRecvNum,
                                           sentSeqCounter,
                                           sentSeqCounter,
                                           0);
    *sz = packed;
}


bool LowLatencyProto::replySYN(uint8_t* received, unsigned int recvsz,
                               uint8_t* reply, unsigned int* replysz)

{
    boost::mutex::scoped_lock lk(llp_mutex);

    uint8_t proto, nummsgs;
    uint32_t acknum;
    uint32_t startseqnum;
    uint32_t endseqnum;

    structunpack(received, recvsz, "clllc",
                 &proto, &acknum, &startseqnum,
                 &endseqnum, &nummsgs);

    if(lastRecvNum != endseqnum)
    {
        LOG("LLP", 2, "not in setup and syn was "
            << endseqnum << ", expecting " << lastRecvNum);
        return false;
    }

    LOG("LLP", 5, "Got SYN -> " << acknum << " " << endseqnum);

    *replysz = structpack(reply, *replysz, "clllc",
                          VIP_LOWLATENCY,
                          lastRecvNum,
                          sentSeqCounter,
                          sentSeqCounter,
                          0);

    return true;
}

void LowLatencyProto::setupWithSYN(uint8_t* received, unsigned int recvsz)

{
    boost::mutex::scoped_lock lk(llp_mutex);

    uint8_t proto, nummsgs;
    uint32_t acknum;
    uint32_t startseqnum;
    uint32_t endseqnum;

    structunpack(received, recvsz, "clllc",
                 &proto, &acknum, &startseqnum,
                 &endseqnum, &nummsgs);

    lastRecvNum = endseqnum;

    LOG("LLP", 5, "Setting up with SYN -> " << acknum << " " << endseqnum);
}


/* Fill `buf` (capacity *bufsz on entry) with the next outgoing LLP chunk,
   and set *bufsz to the number of bytes written (0 if nothing is to be sent).

   Nothing is produced until the heartbeat interval has elapsed.  After that,
   queued messages that are unsent, or whose last send is older than
   max(heartbeatTime, high RTT), are packed behind a 14-byte header.  Once no
   such messages remain, the heartbeat timer advances and, if an ack is
   owed, a header-only ack chunk is produced.

   NOTE(review): `ackOnly` is not consulted by this implementation. */
void LowLatencyProto::getNextChunk(uint8_t* buf, unsigned int *bufsz,
                                   bool ackOnly)
{
    boost::mutex::scoped_lock lk(llp_mutex);

    double now = getTimer();

    if(now < (lastHeartbeat + heartbeatTime))
    {
        *bufsz = 0;
        return;
    }

#if 0
    if (outgoingqueue.size() == 0 && !needSendAck) {
        *bufsz = 0;
        return;
    }
#endif

    // A sent message becomes due for resending after max(heartbeatTime,
    // high RTT) without an ack.
    double resetTime = heartbeatTime > connection->getHighRTT()
        ? heartbeatTime : connection->getHighRTT();

    unsigned int flushPtr;

    /* Generally speaking, flushPtr should end up at one of three
       places; either the at the beginning (because the oldest
       messages that haven't been ack'd yet will be at the head of the
       queue), at a message that hasn't been sent yet (the
       lastSentTime will be 0 in that case) or at the end (because
       everything queued has been sent and it isn't time to resend
       it).
    */
    for(flushPtr = 0;
        flushPtr < outgoingqueue.size()
            && now < (outgoingqueue[flushPtr].lastSentTime + resetTime);
        flushPtr++);

    LOG("LLP", 4, " flushPtr is " << flushPtr
        << " queue len is " << outgoingqueue.size()
        << " now is " << now
        << " resetTime is " << resetTime);

    if(flushPtr < outgoingqueue.size()) {
        LOG("LLP", 5, "sending some data " << flushPtr << " " << outgoingqueue.size());

        unsigned int i = flushPtr;
        uint32_t ptr = 14;  // room for the chunk header, written last

        // Per-message sequence offsets are relative to startSeqNum.  When
        // resending from the head of the queue, start just past the last
        // sequence number the peer acknowledged.
        uint32_t startSeqNum = (flushPtr == 0 ? lastAckNum+1
                                    : outgoingqueue[flushPtr].seqnum);

        LOG("LLP", 4, "Beginning message pack with startSeqNum " << startSeqNum);

        // The message count goes on the wire as one byte, so pack at most
        // 255 messages.  Without this cap the count would wrap while the end
        // sequence number still covered every packed message, and the peer
        // would ack messages it never parsed.  Leftovers go out on the next
        // call, because lastHeartbeat is not advanced on this path.
        uint8_t n = 0;
        while(i < outgoingqueue.size()
              && n < 0xFF
              && (ptr+outgoingqueue[i].m->payload.size()+6) < *bufsz)
        {
            LOG("LLP", 4, "packing and sending msgid " << outgoingqueue[i].m->msgid << " seqnum " << outgoingqueue[i].seqnum);

            // NOTE(review): the sequence offset and the payload size are both
            // written as single bytes.  A payload over 255 bytes, or a message
            // more than 255 sequence numbers past startSeqNum, is truncated
            // on the wire -- confirm callers never queue such messages.
            ptr += structpack(buf + ptr, 6, "clc",
                              (uint8_t)(outgoingqueue[i].seqnum - startSeqNum),
                              outgoingqueue[i].m->msgid,
                              (uint8_t)outgoingqueue[i].m->payload.size());
            memcpy(buf+ptr, outgoingqueue[i].m->payload.data(),
                   outgoingqueue[i].m->payload.size());

            ptr += outgoingqueue[i].m->payload.size();

            outgoingqueue[i].bytesSent += 6 + outgoingqueue[i].m->payload.size();
            outgoingqueue[i].lastSentTime = now;

            i++;
            n++;
        }

        // Last sequence number covered by this chunk.  Normally this is the
        // last message packed.  If nothing fit, use the entry before
        // flushPtr, or, when flushPtr is 0, startSeqNum-1 (== lastAckNum).
        // Indexing outgoingqueue[i-1] with i == 0 would be out of bounds.
        uint32_t endSeqNum = (i > 0 ? outgoingqueue[i-1].seqnum
                                    : startSeqNum - 1);

        if(n == 0) {
            LOG("LLP", 2, "msgid " << outgoingqueue[flushPtr].m->msgid
                << " with " << outgoingqueue[flushPtr].m->payload.size()
                << " byte payload does not fit in a " << *bufsz
                << " byte chunk");
        }

        structpack(buf, 14, "clllc",
                   VIP_LOWLATENCY,
                   lastRecvNum,
                   startSeqNum,
                   endSeqNum,
                   n);

        LOG("LLP", 4, "getNextChunk header is proto: " << VIP_LOWLATENCY
            << " acknum: " << lastRecvNum
            << " start num: " << startSeqNum
            << " end num: " << endSeqNum
            << " num msg: " << (int)n);

        //std::cout << skt << ": Sending pkt of size " << ptr << std::endl;

        *bufsz = ptr;

        LOG("LLP", 6, "flushing from " << flushPtr << " to "
            << i << " with last seqnum " << endSeqNum);

        needSendAck = false;
    } else {
        /* we're not done with this heartbeat until we've exhausted
           the current outgoing queue, so we don't move the
           lastHeartbeat timer forward until we're really ready.
         */
        lastHeartbeat = now;

        if(needSendAck) {
            LOG("LLP", 5, "sending ACK on " << lastRecvNum);

            unsigned int ptr = 14;

            LOG("LLP", 5, "values are " << lastAckNum << " " << sentSeqCounter);

            // this is asserting that the last thing ACK'd by the
            // other side is the last thing we sent.
            // I don't think this is an error...  not sure why
            // I added the assertion originally.
            //assert(lastAckNum == sentSeqCounter);

            structpack(buf, 14, "clllc",
                       VIP_LOWLATENCY,
                       lastRecvNum,
                       sentSeqCounter,
                       sentSeqCounter,
                       0);

            LOG("LLP", 4, "getNextChunk header is proto: " << VIP_LOWLATENCY
                << " acknum: " << lastRecvNum
                << " start num: " << sentSeqCounter
                << " end num: " << sentSeqCounter
                << " num msg: " << (int)0);

            //std::cout << skt << ": Sending pkt of size " << ptr << std::endl;

            *bufsz = ptr;
            needSendAck = false;
        } else {
            *bufsz = 0;
        }
    }

    haveSentEver = true;

    LOG("LLP", 5, "will send " << *bufsz << " bytes and flushPtr is " << flushPtr);
}

double LowLatencyProto::desiredWaitTime()
{
    boost::mutex::scoped_lock lk(llp_mutex);

    double now = getTimer();
    double resetTime = heartbeatTime > connection->getHighRTT()
        ? heartbeatTime : connection->getHighRTT();

    unsigned int flushPtr;

    for(flushPtr = 0;
        flushPtr < outgoingqueue.size()
            && now < (outgoingqueue[flushPtr].lastSentTime + resetTime);
        flushPtr++);

    LOG("LLP", 6, "flushptr is " << flushPtr
        << " outgoing queue size is " << outgoingqueue.size());

    if(flushPtr < outgoingqueue.size()
       && now >= (lastHeartbeat + heartbeatTime))
    {
        return 0;
    } else {
        return (lastHeartbeat + heartbeatTime) - now;
    }
}

bool LowLatencyProto::hasQueuedData()
{
    boost::mutex::scoped_lock lk(llp_mutex);

    LOG("LLP", 5, "outgoingqueue.empty() = " << outgoingqueue.empty());

    return !outgoingqueue.empty();
}

/* Total payload bytes of every message currently in the outgoing queue.
   Per-message header overhead is not counted.  `channel` is accepted for
   interface compatibility but is not used. */
unsigned int LowLatencyProto::queuedBytes(int channel)
{
    boost::mutex::scoped_lock guard(llp_mutex);

    unsigned int total = 0;
    unsigned int idx = outgoingqueue.size();

    while(idx > 0) {
        --idx;
        total += outgoingqueue[idx].m->payload.size();
    }

    return total;
}
