#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import optparse, sys, time
import statistics
from qpid.messaging import *

# Unit multipliers used by the timing code in this script.
SECOND = 1000
TIME_SEC = 1000000000

# Command-line interface.  optparse adds -h/--help on its own, so no explicit
# help option is registered here.
op = optparse.OptionParser(
    usage="usage: %prog [options]",
    description="Drains messages from the specified address",
)

# --- Connection and source -------------------------------------------------
op.add_option(
    "-b", "--broker",
    type="str",
    default="localhost:5672",
    help="url of broker to connect to",
)
op.add_option(
    "-a", "--address",
    type="str",
    help="address to receive from",
)
op.add_option(
    "--connection-options",
    default={},
    help="options for the connection",
)

# --- Termination conditions ------------------------------------------------
op.add_option(
    "-m", "--messages",
    type="int",
    default=0,
    help="stop after N messages have been received, 0 means no limit",
)
op.add_option(
    "--timeout",
    type="int",
    default=0,
    help="timeout in seconds to wait before exiting",
)
op.add_option(
    "-f", "--forever",
    action="store_true",
    default=False,
    help="ignore timeout and wait forever",
)

# --- Sequence-number checks ------------------------------------------------
op.add_option(
    "--ignore-duplicates",
    action="store_true",
    default=False,
    help="Detect and ignore duplicates (by checking 'sn' header)",
)
op.add_option(
    "--verify-sequence",
    action="store_true",
    default=False,
    help="Verify there are no gaps in the message sequence (by checking 'sn' header)",
)
op.add_option(
    "--check-redelivered",
    action="store_true",
    default=False,
    help="Fails with exception if a duplicate is not marked as redelivered (only relevant when ignore-duplicates is selected)",
)

# --- Flow control, acknowledgement and transactions ------------------------
op.add_option(
    "--capacity",
    type="int",
    default=1000,
    help="size of the senders outgoing message queue",
)
op.add_option(
    "--ack-frequency",
    type="int",
    default=100,
    help="Ack frequency (0 implies none of the messages will get accepted)",
)
op.add_option(
    "--tx",
    type="int",
    default=0,
    help="batch size for transactions (0 implies transaction are not used)",
)
op.add_option(
    "--rollback-frequency",
    type="int",
    default=0,
    help="rollback frequency (0 implies no transaction will be rolledback)",
)

# --- Output ----------------------------------------------------------------
op.add_option(
    "--print-content",
    type="str",
    default="yes",
    help="print out message content",
)
op.add_option(
    "--print-headers",
    type="str",
    default="no",
    help="print out message headers",
)

# --- Failover and reporting ------------------------------------------------
op.add_option(
    "--failover-updates",
    action="store_true",
    default=False,
    help="Listen for membership updates distributed via amq.failover",
)
op.add_option(
    "--report-total",
    action="store_true",
    default=False,
    help="Report total throughput statistics",
)
op.add_option(
    "--report-every",
    type="int",
    default=0,
    help="Report throughput statistics every N messages",
)
op.add_option(
    "--report-header",
    type="str",
    default="yes",
    help="Headers on report",
)

# --- Coordination and pacing -----------------------------------------------
op.add_option(
    "--ready-address",
    type="str",
    help="send a message to this address when ready to receive",
)
op.add_option(
    "--receive-rate",
    type="int",
    default=0,
    help="Receive at rate of N messages/second. 0 means receive as fast as possible",
)

def getTimeout(timeout, forever):
    """Return the timeout value to hand to ``receiver.fetch()``.

    ``timeout`` is the ``--timeout`` option, which is documented as a number
    of seconds; ``forever`` is the ``--forever`` flag.

    Returns ``None`` when ``forever`` is set (block until a message arrives),
    otherwise ``timeout`` unchanged.  The qpid.messaging fetch timeout is
    expressed in seconds, so the value must not be scaled to milliseconds the
    way a millisecond-based duration would be; doing so made ``--timeout 5``
    wait 5000 seconds.
    """
    if forever:
        return None
    return timeout
    

# Message content that marks the end of the stream.
EOS = "eos"
# Name of the message property carrying the sequence number.
SN = "sn"


# Check for duplicate or dropped messages by sequence number
class SequenceTracker:
    """Track the 'sn' property of received messages.

    Depending on the parsed options this detects gaps in the sequence
    (``verify_sequence``), filters out duplicates (``ignore_duplicates``) and
    optionally insists that duplicates are flagged as redelivered
    (``check_redelivered``).
    """

    def __init__(self, opts):
        """Remember the parsed options; no sequence number has been seen yet."""
        self.opts = opts
        # Highest sequence number accepted so far.
        self.lastSn = 0

    def track(self, message):
        """Return True if the message should be processed, False if ignored.

        When neither ``verify_sequence`` nor ``ignore_duplicates`` is enabled
        the message is accepted without looking at its properties.

        Raises:
            Exception: if ``verify_sequence`` is set and the sequence number
                skips ahead of the last one seen, or if an ignored duplicate
                is not marked redelivered while ``check_redelivered`` is set.
            KeyError: if tracking is enabled and the message has no 'sn'
                property.
        """
        # Tracking is only needed when at least one of the checks is enabled.
        if not (self.opts.verify_sequence or self.opts.ignore_duplicates):
            return True
        sn = message.properties[SN]
        duplicate = sn <= self.lastSn
        dropped = sn > self.lastSn + 1
        if self.opts.verify_sequence and dropped:
            raise Exception("Gap in sequence numbers %s-%s" % (self.lastSn, sn))
        ignore = duplicate and self.opts.ignore_duplicates
        if ignore and self.opts.check_redelivered and not message.redelivered:
            raise Exception("duplicate sequence number received, message not marked as redelivered!")
        if not duplicate:
            self.lastSn = sn
        return not ignore


def _print_headers(msg):
    """Print the populated header fields of msg, its properties and a blank line."""
    optional_fields = (
        ("Subject", msg.subject),
        ("ReplyTo", msg.reply_to),
        ("CorrelationId", msg.correlation_id),
        ("UserId", msg.user_id),
        ("TTL", msg.ttl),
        ("Priority", msg.priority),
    )
    for label, value in optional_fields:
        if value is not None:
            print("%s: %s" % (label, value))
    if msg.durable:
        print("Durable: true")
    if msg.redelivered:
        print("Redelivered: true")
    print("Properties: %s" % (msg.properties,))
    print("")


def main():
    """Receive messages from the configured address until told to stop.

    Parses the command line, connects to the broker and fetches messages in a
    loop.  The loop ends when an end-of-stream message arrives, when the
    requested number of messages has been received, or when a fetch times
    out.  Along the way messages are optionally printed, acknowledged or
    committed/rolled back in batches, echoed to their reply-to address and
    paced to the requested receive rate.

    Errors raised while talking to the broker are printed rather than
    propagated; the connection is closed in every case.

    Raises:
        Exception: if no address was given on the command line.
    """
    opts, _ = op.parse_args()
    if not opts.address:
        raise Exception("Address must be specified!")

    # NOTE(review): --connection-options is declared without a type, so a
    # value given on the command line arrives as a string and cannot be
    # expanded with ** -- only the default {} works here.  Confirm the
    # intended syntax before adding parsing.
    connection = Connection(opts.broker, **opts.connection_options)

    try:
        connection.open()
        if opts.failover_updates:
            # Not part of what "from qpid.messaging import *" brings in.
            from qpid.messaging.util import auto_fetch_reconnect_urls
            auto_fetch_reconnect_urls(connection)
        session = connection.session(transactional=(opts.tx > 0))
        receiver = session.receiver(opts.address)
        if opts.capacity > 0:
            receiver.capacity = opts.capacity
        count = 0
        txCount = 0
        sequenceTracker = SequenceTracker(opts)
        timeout = getTimeout(opts.timeout, opts.forever)
        done = False
        stats = statistics.ThroughputAndLatency()
        reporter = statistics.Reporter(opts.report_every, opts.report_header == "yes", stats)

        # Tell whoever is waiting that this receiver is ready.
        if opts.ready_address is not None:
            session.sender(opts.ready_address).send(Message())
        if opts.tx > 0:
            session.commit()

        # For receive rate calculation.  Everything is kept in seconds
        # because that is the unit time.time() and time.sleep() work in.
        start = time.time()
        interval = 0.0
        if opts.receive_rate > 0:
            interval = 1.0 / opts.receive_rate

        replyTo = {}  # a dictionary of reply-to address -> sender mapping

        while not done:
            try:
                msg = receiver.fetch(timeout=timeout)
            except Empty:  # no message fetched => leave the receive loop
                break
            reporter.message(msg)
            if sequenceTracker.track(msg):
                if msg.content == EOS:
                    done = True
                else:
                    count += 1
                    if opts.print_headers == "yes":
                        _print_headers(msg)
                    if opts.print_content == "yes":
                        print(msg.content)
                    if (opts.messages > 0) and (count >= opts.messages):
                        done = True
            # Batch boundary: finish the transaction or acknowledge.
            if (opts.tx > 0) and (count % opts.tx == 0):
                txCount += 1
                if (opts.rollback_frequency > 0) and (txCount % opts.rollback_frequency == 0):
                    session.rollback()
                else:
                    session.commit()
            elif (opts.ack_frequency > 0) and (count % opts.ack_frequency == 0):
                session.acknowledge()
            if msg.reply_to is not None:  # Echo message back to reply-to address.
                if msg.reply_to not in replyTo:
                    replyTo[msg.reply_to] = session.sender(msg.reply_to)
                    replyTo[msg.reply_to].capacity = opts.capacity
                replyTo[msg.reply_to].send(msg)
            if opts.receive_rate > 0:
                # Sleep until the moment the next message is due.
                delay = start + count * interval - time.time()
                if delay > 0:
                    time.sleep(delay)

        if opts.report_total:
            reporter.report()
        # Settle whatever is still outstanding after the loop.
        if opts.tx > 0:
            txCount += 1
            if opts.rollback_frequency and (txCount % opts.rollback_frequency == 0):
                session.rollback()
            else:
                session.commit()
        else:
            session.acknowledge()
        session.close()
    except Exception as e:
        print(e)
    finally:
        # Also reached on KeyboardInterrupt, the usual way out of --forever.
        connection.close()

# Run the receiver only when this file is executed as a script.
if __name__ == "__main__": main()
