#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import optparse
import sys
import socket
import locale
from types       import *
from cmd         import Cmd
from shlex       import split
from threading   import Lock
from time        import strftime, gmtime
from qpidtoollibs import Display
from qmf.console import Session, Console, SchemaClass, ObjectId

class Mcli(Cmd):
  """Management command interpreter (the interactive ``qpid:`` shell).

  Every ``do_<command>`` handler receives the raw argument string typed after
  the command name.  Most handlers simply forward that string to the method of
  the same name on ``dataObject`` and report any exception on stdout so that a
  failing command never terminates the interactive loop.
  """

  def __init__(self, dataObject, dispObject):
    """Create the interpreter.

    dataObject -- object implementing the commands (do_list, do_show, ...);
                  it is handed a reference to this interpreter via setCli().
    dispObject -- display helper; only do_setTimeFormat() is used here.
    """
    Cmd.__init__(self)
    self.dataObject = dataObject
    self.dispObject = dispObject
    self.dataObject.setCli(self)
    self.prompt = "qpid: "

  def emptyline(self):
    """Ignore empty input lines (Cmd's default would repeat the last command)."""
    pass

  def setPromptMessage(self, p):
    """Show ``p`` inside the prompt, or restore the plain prompt when p is None."""
    if p is None:
      self.prompt = "qpid: "
    else:
      self.prompt = "qpid[%s]: " % p

  def do_help(self, data):
    """Print the static command summary; the argument is ignored."""
    print("Management Tool for QPID")
    print("")
    print("Commands:")
    print("    agents                          - Print a list of the known Agents")
    print("    list                            - Print summary of existing objects by class")
    print("    list <className>                - Print list of objects of the specified class")
    print("    list <className> active         - Print list of non-deleted objects of the specified class")
    # Not advertised (help lines were disabled in the original source):
    #   show <className>                - Print contents of all objects of specified class
    #   show <className> active         - Print contents of all non-deleted objects of specified class
    print("    show <ID>                       - Print contents of an object (infer className)")
    #   show <className> <list-of-IDs>  - Print contents of one or more objects
    #       list is space-separated, ranges may be specified (i.e. 1004-1010)
    print("    call <ID> <methodName> [<args>] - Invoke a method on an object")
    print("    schema                          - Print summary of object classes seen on the target")
    print("    schema <className>              - Print details of an object class")
    print("    set time-format short           - Select short timestamp format (default)")
    print("    set time-format long            - Select long timestamp format")
    print("    quit or ^D                      - Exit the program")
    print("")

  def complete_set(self, text, line, begidx, endidx):
    """Command completion for the 'set' command.

    Returns the list of candidate completions for ``text`` given the whole
    input ``line``; begidx/endidx are unused.
    """
    tokens = split(line)
    if len(tokens) < 2:
      return ["time-format "]
    if tokens[1] == "time-format":
      if len(tokens) == 2:
        return ["long", "short"]
      if len(tokens) == 3:
        # First value that the partially typed text is a prefix of.
        for choice in ("long", "short"):
          if choice.startswith(text):
            return [choice]
      return []
    if "time-format".startswith(text):
      return ["time-format "]
    return []

  def do_set(self, data):
    """Handle ``set time-format <short|long>``.

    Malformed input (nothing to set, unknown setting, unbalanced quotes) prints
    a usage line instead of being silently ignored or escaping the command
    loop.
    """
    try:
      tokens = split(data)
    except ValueError:
      # shlex raises ValueError on e.g. an unterminated quote.
      tokens = []
    if len(tokens) >= 2 and tokens[0] == "time-format":
      try:
        self.dispObject.do_setTimeFormat(tokens[1])
      except Exception as e:
        print("Exception in do_set: %r" % e)
    else:
      print("Usage: set time-format short|long")

  def _completeClassName(self, text, line):
    """Offer class-name completions, but only for the first argument."""
    if len(split(line)) > 2:
      return []
    return self.dataObject.classCompletions(text)

  def _forward(self, name, data):
    """Call dataObject.<name>(data), printing (not raising) any exception."""
    try:
      getattr(self.dataObject, name)(data)
    except Exception as e:
      print("Exception in %s: %r" % (name, e))

  def _shutdown(self):
    """Best-effort notification to the data object that the CLI is exiting."""
    try:
      self.dataObject.do_exit()
    except Exception:
      pass  # exiting anyway; a failure here must not block the exit
    return True

  def complete_schema(self, text, line, begidx, endidx):
    """Command completion for the 'schema' command."""
    return self._completeClassName(text, line)

  def do_schema(self, data):
    self._forward("do_schema", data)

  def do_agents(self, data):
    self._forward("do_agents", data)

  def do_id(self, data):
    self._forward("do_id", data)

  def complete_list(self, text, line, begidx, endidx):
    """Command completion for the 'list' command."""
    return self._completeClassName(text, line)

  def do_list(self, data):
    self._forward("do_list", data)

  def do_show(self, data):
    self._forward("do_show", data)

  def do_call(self, data):
    self._forward("do_call", data)

  def do_EOF(self, data):
    """Handle end-of-file (^D) like 'quit'; returning True stops the loop."""
    print("quit")
    return self._shutdown()

  def do_quit(self, data):
    """Leave the command loop; returning True stops the loop."""
    return self._shutdown()

  def postcmd(self, stop, line):
    """Pass the handler's stop flag through unchanged."""
    return stop

  def postloop(self):
    """Announce the exit and close the data object's broker connection."""
    print("Exiting...")
    self.dataObject.close()

#======================================================================================================
# QmfData
#======================================================================================================
class QmfData(Console):
  """QMF console that caches managed objects and implements the CLI commands.

  Objects reported through the Console callbacks (objectProps / objectStats)
  are stored in ``self.objects`` keyed by a short numeric display ID obtained
  from an IdRegistry.  ``self.lock`` guards ``self.objects`` and
  ``self.connected``, which are touched both by the callbacks and by the
  command methods.
  """

  # QMF type code -> printable type name (see typeName).
  _TYPE_NAMES = {
    1: "uint8", 2: "uint16", 3: "uint32", 4: "uint64", 5: "bool",
    6: "short-string", 7: "long-string", 8: "abs-time", 9: "delta-time",
    10: "reference", 11: "boolean", 12: "float", 13: "double", 14: "uuid",
    15: "field-table", 16: "int8", 17: "int16", 18: "int32", 19: "int64",
    20: "object", 21: "list", 22: "array",
  }

  # Property access code -> printable name (see accessName).
  _ACCESS_NAMES = {'1': "ReadCreate", '2': "ReadWrite", '3': "ReadOnly"}

  def __init__(self, disp, url, conn_options):
    """Open a managed QMF session to the broker at ``url``.

    disp         -- display helper used for tables and timestamps
    url          -- broker address passed to Session.addBroker()
    conn_options -- dict of extra keyword arguments for Session.addBroker()
    """
    self.disp = disp
    self.url = url
    # Everything the Console callbacks read or write must exist BEFORE the
    # session is created: this object is registered as the session's console,
    # so callbacks can start arriving once the broker is added, and they must
    # neither find attributes missing nor have their updates overwritten by
    # the remainder of this constructor.
    self.lock = Lock()
    self.connected = False
    self.closing = False
    self.first_connect = True
    self.cli = None
    self.idRegistry = IdRegistry()
    self.objects = {}
    self.session = None
    self.broker = None
    self.session = Session(self, manageConnections=True)
    self.broker = self.session.addBroker(self.url, **conn_options)

  #=======================
  # Methods to support CLI
  #=======================
  def setCli(self, cli):
    """Remember the command interpreter that drives this object."""
    self.cli = cli

  def close(self):
    """Detach from the broker; errors are ignored because we are shutting down."""
    self.closing = True
    try:
      if self.session and self.broker:
        self.session.delBroker(self.broker)
    except Exception:
      pass   # we're shutting down - ignore any errors

  def classCompletions(self, text):
    """Return the sorted class names known to the session that start with text.

    Always returns a list (possibly empty): the completers hand the result
    straight back to Cmd, which indexes into it.
    """
    try:
      names = set()
      for package in self.session.getPackages():
        for key in self.session.getClasses(package):
          name = key.getClassName()
          if name.startswith(text):
            names.add(name)
      return sorted(names)
    except Exception:
      # Completion is a convenience; never let a session error escape here.
      return []

  def do_schema(self, data):
    """'schema' prints the class summary; 'schema <name>' prints details."""
    if data == "":
      self.schemaSummary()
    else:
      self.schemaTable(data)

  def do_agents(self, data):
    """Print a table of the agents known to the session."""
    rows = []
    for agent in self.session.getAgents():
      version = 1
      if agent.isV2:
        version = 2
      name = "%d.%s" % (agent.getBrokerBank(), agent.getAgentBank())
      rows.append((name, agent.label, agent.epoch, version))
    self.disp.table("QMF Agents:", ("Agent Name", "Label", "Epoch", "QMF Version"), rows)

  def do_id(self, data):
    """Translate the given display IDs (or all of them) to object identities."""
    tokens = data.split()
    for token in tokens:
      if not token.isdigit():
        print("Value %s is non-numeric" % token)
        return
    title = "Translation of Display IDs:"
    heads = ('DisplayID', 'Epoch', 'Agent', 'ObjectName')
    if len(tokens) == 0:
      tokens = self.idRegistry.getDisplayIds()
    rows = [self.idRegistry.getIdInfo(int(token)) for token in tokens]
    self.disp.table(title, heads, rows)

  def do_list(self, data):
    """'list' summarises objects by class; 'list <class> [active]' lists them."""
    tokens = data.split()
    if len(tokens) == 0:
      self.listClasses()
    else:
      self.listObjects(tokens)

  def do_show(self, data):
    """Show objects selected by class name or a single object by display ID.

    A token matching a known class is routed to showObjectsByKey() (currently
    a no-op); otherwise a numeric token is treated as a display ID.  Any other
    token produces no output.
    """
    tokens = data.split()
    if len(tokens) == 0:
      print("Missing Class or ID")
      return
    keys = self.classKeysByToken(tokens[0])
    if keys:
      self.showObjectsByKey(keys)
    elif tokens[0].isdigit():
      self.showObjectById(int(tokens[0]))

  def _build_object_name(self, obj):
    """Build '<package>:<class>:<key>' from the object's class key and properties.

    The key part is the comma-joined values of the properties whose ``index``
    equals 1, skipping 'vhostRef' and substituting a fixed broker name for
    'brokerRef'.
    """
    values = []
    for p, v in obj.getProperties():
      if p.name != "vhostRef" and p.index == 1:
        if p.name == "brokerRef": # reference to broker
          values.append('org.apache.qpid.broker:broker:amqp-broker')
        else:
          values.append(str(v))

    object_key = ",".join(values)
    class_key = obj.getClassKey()
    return class_key.getPackageName() + ":" + class_key.getClassName() + ":" + object_key

  def _looksLikeLiteral(self, arg):
    """Return True if a 'call' argument should be evaluated rather than sent as text.

    True for arguments starting with '{', '[' or a quote, for the words True
    and False, and for numbers: digits with at most one '.' and at most one
    '-', which must then be the first character.
    """
    if arg[0] in ('{', '[', '"', '\'') or arg == "True" or arg == "False":
      return True
    dashes = arg.count('-')
    sign_ok = dashes == 0 or (dashes == 1 and arg[0] == '-')
    return (arg.count('.') < 2 and sign_ok and
            arg.replace('.', '').replace('-', '').isdigit())

  def do_call(self, data):
    """Invoke '<ID> <methodName> [<args>]' on a cached object.

    The response is delivered asynchronously through methodResponse().
    A non-numeric ID raises ValueError.
    """
    tokens = data.split()
    if len(tokens) < 2:
      print("Not enough arguments supplied")
      return
    displayId = int(tokens[0])
    methodName = tokens[1]
    args = []
    for arg in tokens[2:]:
      ##
      ## If the argument is a map, list, boolean, integer, or floating (one decimal point),
      ## run it through the Python evaluator so it is converted to the correct type.
      ##
      ## TODO: use a regex for this instead of this convoluted logic,
      ## or even consider passing all args through eval() [which would
      ## be a minor change to the interface as string args would then
      ## always need to be quoted as strings within a map/list would
      ## now]
      ##
      ## NOTE(review): eval() executes whatever expression the user typed.  The
      ## input comes from the local interactive user only, so it is kept, but a
      ## literal-only parser would be the safer choice if that ever changes.
      if self._looksLikeLiteral(arg):
        args.append(eval(arg))
      else:
        args.append(arg)

    with self.lock:
      if displayId not in self.objects:
        print("Unknown ID")
        return
      obj = self.objects[displayId]

    object_id = obj.getObjectId()
    if not object_id.isV2 and obj.getAgent().isV2:
      # A V1-style object ID reported by a V2 agent: rebuild the ID from the
      # agent name and a name derived from the object's index properties.
      object_name = self._build_object_name(obj)
      object_id = ObjectId.create(object_id.agentName, object_name)

    self.session._sendMethodRequest(self.broker, obj.getClassKey(), object_id, methodName, args)

  def do_exit(self):
    """Hook called when the CLI exits; nothing to do."""
    pass

  #====================
  # Sub-Command Methods
  #====================
  def schemaSummary(self, package_filter=None):
    """Print the table (object) classes, optionally limited to one package."""
    rows = []
    for package in self.session.getPackages():
      if package_filter and package_filter != package:
        continue
      for key in self.session.getClasses(package):
        schema = self.session.getSchema(key)
        #
        # Don't display event schemata.  This will be a future feature.
        #
        if schema and schema.kind == SchemaClass.CLASS_KIND_TABLE:
          rows.append((package, key.getClassName(), "object"))
    self.disp.table("QMF Classes:", ("Package", "Name", "Kind"), rows)

  def schemaTable(self, text):
    """Print schema details for a package name, class name or 'package:class'."""
    packages = self.session.getPackages()
    if text in packages:
      self.schemaSummary(package_filter=text)
    for package in packages:
      for key in self.session.getClasses(package):
        cname = key.getClassName()
        if text != cname and text != package + ":" + cname:
          continue
        schema = self.session.getSchema(key)
        if not schema:
          continue   # no schema available for this class; nothing to print
        if schema.kind == SchemaClass.CLASS_KIND_TABLE:
          self.schemaObject(schema)
        else:
          self.schemaEvent(schema)

  def schemaObject(self, schema):
    """Print the properties, statistics and methods of an object class."""
    rows = []
    title = "Object Class: %s" % repr(schema)
    heads = ("Element", "Type", "Access", "Unit", "Notes", "Description")
    for prop in schema.getProperties():
      notes = ""
      if prop.index:
        notes += "index "
      if prop.optional:
        notes += "optional "
      rows.append((prop.name, self.typeName(prop.type), self.accessName(prop.access),
                   self.notNone(prop.unit), notes, self.notNone(prop.desc)))
    for stat in schema.getStatistics():
      rows.append((stat.name, self.typeName(stat.type), "",
                   self.notNone(stat.unit), "", self.notNone(stat.desc)))
    self.disp.table(title, heads, rows)

    # One additional table per method, each preceded by a blank line.
    heads = ("Argument", "Type", "Direction", "Unit", "Description")
    for method in schema.methods:
      title = "  Method: %s" % method.name
      rows = [(arg.name, self.typeName(arg.type), arg.dir,
               self.notNone(arg.unit), self.notNone(arg.desc))
              for arg in method.arguments]
      print("")
      self.disp.table(title, heads, rows)

  def schemaEvent(self, schema):
    """Print the arguments of an event class."""
    title = "Event Class: %s" % repr(schema)
    heads = ("Element", "Type", "Unit", "Description")
    rows = [(arg.name, self.typeName(arg.type), self.notNone(arg.unit), self.notNone(arg.desc))
            for arg in schema.arguments]
    self.disp.table(title, heads, rows)

  def listClasses(self):
    """Print the number of active and deleted cached objects per class."""
    title = "Summary of Objects by Type:"
    heads = ("Package", "Class", "Active", "Deleted")
    totals = {}   # (package, class) -> (active count, deleted count)
    with self.lock:
      for obj in self.objects.values():
        key = obj.getClassKey()
        index = (key.getPackageName(), key.getClassName())
        active, deleted = totals.get(index, (0, 0))
        if obj.isDeleted():
          deleted += 1
        else:
          active += 1
        totals[index] = (active, deleted)

    rows = [(index[0], index[1], stats[0], stats[1]) for index, stats in totals.items()]
    self.disp.table(title, heads, rows)

  def listObjects(self, tokens):
    """Print the cached objects of class tokens[0].

    Deleted objects are included unless tokens[1] is 'active'.
    """
    ckeys = self.classKeysByToken(tokens[0])
    show_deleted = not (len(tokens) > 1 and tokens[1] == 'active')
    heads = ("ID", "Created", "Destroyed", "Index")
    rows = []
    with self.lock:
      for dispId in self.objects:
        obj = self.objects[dispId]
        if obj.getClassKey() not in ckeys:
          continue
        utime, ctime, dtime = obj.getTimestamps()
        # A destroy time of 0 marks an object that has not been deleted.
        if dtime == 0:
          dtimestr = "-"
        else:
          dtimestr = self.disp.timestamp(dtime)
        if dtime == 0 or (dtime > 0 and show_deleted):
          rows.append((dispId, self.disp.timestamp(ctime), dtimestr, self.objectIndex(obj)))
    self.disp.table("Object Summary:", heads, rows)

  def showObjectsByKey(self, key):
    """Show all objects of the given class keys -- not implemented yet."""
    pass

  def showObjectById(self, dispId):
    """Print every property and statistic of the object with display ID dispId."""
    heads = ("Attribute", str(dispId))
    rows = []
    with self.lock:
      if dispId not in self.objects:
        print("No object found with ID %d" % dispId)
        return
      obj = self.objects[dispId]
      caption = "Object of type: %r" % obj.getClassKey()
      for prop in obj.getProperties():
        rows.append((prop[0].name, self.valueByType(prop[0].type, prop[1])))
      for stat in obj.getStatistics():
        rows.append((stat[0].name, self.valueByType(stat[0].type, stat[1])))
    self.disp.table(caption, heads, rows)

  def classKeysByToken(self, token):
    """
    Given a token, return a list of matching class keys (if found):
    token formats:  <class-name>
                    <package-name>:<class-name>
    Raises ValueError if the token contains more than one ':'.
    """
    parts = token.split(':')
    if len(parts) == 1:
      pname, cname = None, parts[0]
    elif len(parts) == 2:
      pname, cname = parts
    else:
      raise ValueError("Invalid Class Name: %s" % token)

    keys = []
    for p in self.session.getPackages():
      if pname is None or pname == p:
        for key in self.session.getClasses(p):
          if key.getClassName() == cname:
            keys.append(key)
    return keys

  def typeName (self, typecode):
    """ Convert type-codes to printable strings; ValueError if unknown """
    try:
      return self._TYPE_NAMES[typecode]
    except (KeyError, TypeError):
      raise ValueError ("Invalid type code: %s" % str(typecode))

  def _deltaTimeString(self, val):
    """Format a nanosecond interval as '[<d>d ][<h>h ][<m>m ]<s>s'.

    Negative intervals are shown as zero.
    """
    if val < 0:
      val = 0
    sec = val // 1000000000
    minutes = sec // 60
    hours = minutes // 60
    days = hours // 24
    result = ""
    if days > 0:
      result = "%dd " % days
    # Once a larger unit has been printed, the smaller ones follow even if zero.
    if hours > 0 or result != "":
      result += "%dh " % (hours % 24)
    if minutes > 0 or result != "":
      result += "%dm " % (minutes % 60)
    result += "%ds" % (sec % 60)
    return result

  def valueByType(self, typecode, val):
    """Render ``val`` for display according to its QMF type code.

    Returns "absent" for None; raises ValueError for an unknown type code.
    """
    if val is None:
      return "absent"
    if typecode in (1, 2, 3, 4, 16, 17, 18, 19):    # unsigned / signed integers
      return "%d" % val
    if typecode in (6, 7):                          # short / long string
      return val
    if typecode == 8:                               # abs-time, in nanoseconds
      return strftime("%c", gmtime(val // 1000000000))
    if typecode == 9:                               # delta-time, in nanoseconds
      return self._deltaTimeString(val)
    if typecode == 10:                              # reference -> display ID
      return str(self.idRegistry.displayId(val))
    # NOTE(review): code 5 is named "bool" by typeName(); it is rendered like
    # the boolean code 11 so the two tables agree -- confirm 5 is ever sent.
    if typecode in (5, 11):
      if val:
        return "True"
      return "False"
    if typecode in (12, 13):                        # float / double
      return "%f" % val
    if typecode in (14, 15, 20, 21, 22):            # uuid, map, object, list, array
      return repr(val)
    raise ValueError ("Invalid type code: %s" % str(typecode))

  def accessName (self, code):
    """ Convert element access codes to printable strings; ValueError if unknown """
    try:
      return self._ACCESS_NAMES[code]
    except (KeyError, TypeError):
      raise ValueError ("Invalid access code: %s" % str(code))

  def notNone (self, text):
    """Return text, or an empty string when text is None."""
    if text is None:
      return ""
    return text

  def objectIndex(self, obj):
    """Return a printable index for obj.

    For V2 object IDs this is the ID's object name; otherwise the values of
    the index properties joined with '.'.
    """
    if obj._objectId.isV2:
      return obj._objectId.getObject()
    parts = [self.valueByType(prop.type, value)
             for prop, value in obj.getProperties() if prop.index]
    return ".".join(parts)


  #=====================
  # Methods from Console
  #=====================
  def brokerConnectionFailed(self, broker):
    """ Invoked when a connection to a broker fails """
    # Only the failure of the very first attempt is reported.
    if self.first_connect:
      self.first_connect = False
      print("Failed to connect:  %s" % (broker.error,))

  def brokerConnected(self, broker):
    """ Invoked when a connection is established to a broker """
    with self.lock:
      self.connected = True
    # The initial connection is silent; only reconnections are announced.
    if not self.first_connect:
      print("Broker connected: %s" % (broker,))
    self.first_connect = False

  def brokerDisconnected(self, broker):
    """ Invoked when the connection to a broker is lost """
    with self.lock:
      self.connected = False
    if not self.closing:
      print("Broker disconnected: %s" % (broker,))

  def objectProps(self, broker, record):
    """ Invoked when an object's properties are reported; caches new objects. """
    dispId = self.idRegistry.displayId(record.getObjectId())
    with self.lock:
      if dispId in self.objects:
        self.objects[dispId].mergeUpdate(record)
      else:
        self.objects[dispId] = record

  def objectStats(self, broker, record):
    """ Invoked when an object's statistics are reported.

    Statistics for an object that is not cached yet are dropped.
    """
    dispId = self.idRegistry.displayId(record.getObjectId())
    with self.lock:
      if dispId in self.objects:
        self.objects[dispId].mergeUpdate(record)

  def event(self, broker, event):
    """ Invoked when an event is raised. """
    pass

  def methodResponse(self, broker, seq, response):
    """ Invoked when the response to a method call (see do_call) arrives. """
    print(response)


#======================================================================================================
# IdRegistry
#======================================================================================================
class IdRegistry(object):
  """Two-way mapping between object IDs and short numeric display IDs.

  Display IDs are handed out sequentially, starting at 101, the first time an
  object ID is seen.  All access to the two dictionaries is serialised by
  ``self.lock``.
  """
  def __init__(self):
    self.next_display_id = 101
    self.oid_to_display = {}
    self.display_to_oid = {}
    self.lock = Lock()

  def displayId(self, oid):
    """Return the display ID for oid, allocating a new one on first use."""
    with self.lock:
      if oid in self.oid_to_display:
        return self.oid_to_display[oid]
      newId = self.next_display_id
      self.next_display_id += 1
      self.oid_to_display[oid] = newId
      self.display_to_oid[newId] = oid
      return newId

  def objectId(self, displayId):
    """Return the object ID registered for displayId, or None if unknown."""
    with self.lock:
      return self.display_to_oid.get(displayId)

  def getDisplayIds(self):
    """Return all allocated display IDs as strings, in ascending numeric order.

    The IDs are snapshotted under the lock so that a concurrent displayId()
    call cannot modify the dictionary while it is being iterated.
    """
    with self.lock:
      ids = sorted(self.display_to_oid)
    return [str(displayId) for displayId in ids]

  def getIdInfo(self, displayId):
    """
    Given a display ID, return a tuple of (displayID, bootSequence/Durable, AgentBank/Name, ObjectName)
    """
    oid = self.objectId(displayId)
    if oid is None:
      return (displayId, "?", "unknown", "unknown")
    bootSeq = oid.getSequence()
    if bootSeq == 0:
      bootSeq = '<durable>'
    agent = oid.getAgentBank()
    if agent == '0':
      agent = 'Broker'
    return (displayId, bootSeq, agent, oid.getObject())

#=========================================================
# Option Parsing
#=========================================================

def parse_options( argv ):
    """Parse the qpid-tool command line.

    argv -- the full argument vector, including the program name at index 0.

    Returns a tuple (broker_option, conn_options, args):
      broker_option -- value of -b/--broker, or None if not given
      conn_options  -- dict of connection keyword options built from the
                       --ssl-* and --sasl-* options
      args          -- the positional arguments with the program name removed

    Exits through parser.error() if --ssl-key is given without
    --ssl-certificate.
    """
    _usage = """qpid-tool [OPTIONS] [[<username>/<password>@]<target-host>[:<tcp-port>]]"""

    parser = optparse.OptionParser(usage=_usage)
    parser.add_option("-b", "--broker", action="store", type="string", metavar="<address>", help="Address of qpidd broker with syntax: [username/password@] hostname | ip-address [:<port>]")
    parser.add_option("--sasl-mechanism", action="store", type="string", metavar="<mech>", help="SASL mechanism for authentication (e.g. EXTERNAL, ANONYMOUS, PLAIN, CRAM-MD5, DIGEST-MD5, GSSAPI). SASL automatically picks the most secure available mechanism - use this option to override.")
    parser.add_option("--sasl-service-name", action="store", type="string", help="SASL service name to use")
    parser.add_option("--ssl-certificate",
                      action="store", type="string", metavar="<path>",
                      help="SSL certificate for client authentication")
    parser.add_option("--ssl-key",
                      action="store", type="string", metavar="<path>",
                      help="Private key (if not contained in certificate)")

    opts, encArgs = parser.parse_args(args=argv)
    # Decode byte-string arguments using the locale's encoding.  If that is not
    # possible (undecodable bytes, unknown encoding, or arguments that are
    # already text and have no decode()), use them as they are.
    try:
        encoding = locale.getpreferredencoding()
        args = [a.decode(encoding) for a in encArgs]
    except (AttributeError, LookupError, UnicodeError):
        args = encArgs

    conn_options = {}
    broker_option = opts.broker or None
    if opts.ssl_certificate:
        conn_options['ssl_certfile'] = opts.ssl_certificate
    if opts.ssl_key:
        if not opts.ssl_certificate:
            parser.error("missing '--ssl-certificate' (required by '--ssl-key')")
        conn_options['ssl_keyfile'] = opts.ssl_key
    if opts.sasl_mechanism:
        conn_options['mechanisms'] = opts.sasl_mechanism
    if opts.sasl_service_name:
        conn_options['service'] = opts.sasl_service_name
    # args[0] is the program name (argv was parsed in full); drop it.
    return broker_option, conn_options, args[1:]

#=========================================================
# Main Program
#=========================================================


# Get options specified on the command line
broker_option, conn_options, cargs = parse_options(sys.argv)

# Broker address: the --broker option wins over the first positional argument;
# without either, connect to the local host.
_host = "localhost"
if broker_option is not None:
    _host = broker_option
elif len(cargs) > 0:
    _host = cargs[0]

# note: prior to supporting options, qpid-tool assumed positional parameters.
# the first argument was assumed to be the broker address.  The second argument
# was optional, and, if supplied, was assumed to be the path to the
# certificate.  To preserve backward compatibility, accept the certificate if
# supplied via the second parameter.
#
if 'ssl_certfile' not in conn_options and len(cargs) > 1:
    conn_options['ssl_certfile'] = cargs[1]

disp = Display()

# Attempt to make a connection to the target broker
try:
  data = QmfData(disp, _host, conn_options)
except Exception as e:
  if str(e).find("Exchange not found") != -1:
    print("Management not enabled on broker:  Use '-m yes' option on broker startup.")
  else:
    print("Failed: %s - %s" % (e.__class__.__name__, e))
  sys.exit(1)

# Instantiate the CLI interpreter and launch it.
cli = Mcli(data, disp)
print("Management Tool for QPID")
try:
  cli.cmdloop()
except KeyboardInterrupt:
  print("")
  print("Exiting...")
except Exception as e:
  print("Failed: %s - %s" % (e.__class__.__name__, e))
finally:
  # Always attempt to clean up broker resources, whichever way the command
  # loop ended (including exits that are not Exception subclasses).
  data.close()
