#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import sys
from argparse import ArgumentParser
from datetime import datetime

# Version of this tool software, reported by the -v/--version option
VERSION = "1.0"

# AMQP 0-10 commands - these increment the command counter.
# Entries are frame body class names with the trailing "Body" removed;
# analyze_trace() numbers each frame whose name appears here, keeping separate
# counters for the received and the sent direction of each session.
EXEC_COMMANDS = ["ExecutionSync", "ExecutionResult", "ExecutionException", "MessageTransfer", "MessageAccept",
                 "MessageReject", "MessageRelease", "MessageAcquire", "MessageResume", "MessageSubscribe",
                 "MessageCancel", "MessageSetFlowMode", "MessageFlow", "MessageFlush", "MessageStop", "TxSelect",
                 "TxCommit", "TxRollback", "DtxSelect", "DtxStart", "DtxEnd", "DtxCommit", "DtxForget", "DtxGetTimeout",
                 "DtxPrepare", "DtxRecover", "DtxRollback", "DtxSetTimeout", "ExchangeDeclare", "ExchangeDelete",
                 "ExchangeQuery", "ExchangeBind", "ExchangeUnbind", "ExchangeBound", "QueueDeclare", "QueueDelete",
                 "QueuePurge", "QueueQuery", "FileQos", "FileQosOk", "FileConsume", "FileConsumeOk", "FileCancel",
                 "FileOpen", "FileOpenOk", "FileStage", "FilePublish", "FileReturn", "FileDeliver", "FileAck",
                 "FileReject", "StreamQos", "StreamQosOk", "StreamConsume", "StreamConsumeOk", "StreamCancel",
                 "StreamPublish", "StreamReturn", "StreamDeliver"]

# Progress indicator interval: analyze_trace() writes one '.' to stdout per
# LINES_PER_DOT input lines read.
LINES_PER_DOT = 100000

class LogLevel:
    """Broker log levels as (rank, name) tuples, ranked 1 (critical) to 7 (trace)."""
    CRITICAL = (1, "critical")
    ERROR = (2, "error")
    WARNING = (3, "warning")
    NOTICE = (4, "notice")
    INFO = (5, "info")
    DEBUG = (6, "debug")
    TRACE = (7, "trace")

    @staticmethod
    def get_level(level):
        """Map a level name such as "trace" to its (rank, name) tuple.

        Raises Exception if *level* matches none of the known names.
        """
        known = (LogLevel.CRITICAL, LogLevel.ERROR, LogLevel.WARNING, LogLevel.NOTICE,
                 LogLevel.INFO, LogLevel.DEBUG, LogLevel.TRACE)
        for candidate in known:
            if candidate[1] == level:
                return candidate
        raise Exception("Unknown log level: %s" % level)

class LogLine:
    """One parsed line of a broker log file.

    The slicing below expects a 19-character "YYYY-MM-DD HH:MM:SS" timestamp,
    one separator character, the level word, one separator, then the message
    text. Unparseable lines raise (ValueError from strptime, or Exception
    from LogLevel.get_level).
    """
    def __init__(self, line_no, line):
        self.line_no = line_no
        self.timestamp = datetime.strptime(line[:19], "%Y-%m-%d %H:%M:%S")
        level_word = line[20:].split(" ")[0]
        self.level = LogLevel.get_level(level_word)
        # Message text starts just past the level word and its separator.
        self.line = line[21 + len(self.level[1]):].strip()
        self.cmd_cnt = None  # command number, assigned later by the analyzer
        self.txn_cnt = None  # transaction number, assigned later by the analyzer

    def __str__(self):
        """Render one fixed-width row for the per-connection listing."""
        direction = "R" if self.contains("RECV") else "    S"
        if self.cmd_cnt is not None:
            direction = direction + str(self.cmd_cnt)
        txn = "" if self.txn_cnt is None else "T%d" % self.txn_cnt
        brace = self.find("{")
        header = self.find("header")
        content = self.find("content")
        if header != -1 and header < brace:
            operation = " + " + self.line[header:self.line.rfind("]")]
        elif content != -1 and brace == -1:
            operation = " + " + self.line[content:self.line.rfind("]")]
        else:
            operation = self.line[brace + 1:self.line.rfind("}")]
        row = (self.line_no, self.timestamp.isoformat(" "), self.get_identifier_remote_addr(),
               self.get_channel(), direction, txn, operation)
        return " %7d  %19s %22s %3d  %-10s %-5s %s" % row

    def contains(self, string):
        """True when *string* occurs anywhere in the message text."""
        return string in self.line

    def find(self, string):
        """Index of the first occurrence of *string* in the message, or -1."""
        return self.line.find(string)

    def get_channel(self):
        """Integer value of "channel=...;" (ValueError if absent or non-numeric)."""
        return int(self.get_named_value("channel"))

    def get_identifier(self):
        """Text between the first '[' and the next ']' of the message."""
        after_bracket = self.line.partition("[")[2]
        return after_bracket.partition("]")[0]

    def get_identifier_remote_addr(self):
        """Part of the identifier after its first '-'."""
        identifier = self.get_identifier()
        return identifier.partition("-")[2]

    def get_named_value(self, name):
        """Text after the first "<name>=" up to the next ';' ("" when absent)."""
        tail = self.line.partition("%s=" % name)[2]
        return tail.partition(";")[0]

    def is_log_level(self, level):
        """True when this line has the same rank as the (rank, name) *level*."""
        if self.level is None:
            return None
        return self.level[0] == level[0]

    def is_frame(self):
        """True when the message contains a "Frame[" marker."""
        return "Frame[" in self.line

class ConnectionProperty:
    """Base for objects identified by a log line's remote address and channel.

    Remembers the address and channel of the line it was built from and
    collects log lines attributed to it in ``ops``, starting with that line.
    """
    def __init__(self, line):
        self.addr = line.get_identifier_remote_addr()
        self.channel = line.get_channel()
        self.ops = list()
        self.ops.append(line)

    def add_op(self, line):
        """Attach another log line to this object."""
        self.ops += [line]

class Connection(ConnectionProperty):
    """A broker connection and the sessions attached to it.

    Connection-level frame lines live in ``ops``; sessions are kept both in
    attach order (``sessionList``) and keyed by channel (``sessionDict``).
    """
    def __init__(self, line):
        ConnectionProperty.__init__(self, line)
        self.sessionList = list()  # sessions in the order they were attached
        self.sessionDict = dict()  # channel number -> session

    def __str__(self):
        """Summary heading: address, op count and number of distinct channels."""
        op_count = len(self.ops)
        session_count = len(self.sessionDict)
        return "Connection %s (ops=%d; sessions=%d):" % (self.addr, op_count, session_count)

    def add_session(self, session):
        """Register *session*, indexed by its channel number."""
        self.sessionList.append(session)
        self.sessionDict[session.channel] = session

    def get_session(self, channel):
        """Session on *channel*; raises KeyError if none has been attached."""
        sessions = self.sessionDict
        return sessions[channel]

class Session(ConnectionProperty):
    """A session on one channel of a connection, with command/txn counters.

    ``send_cnt``/``recv_cnt`` count commands in each direction, ``txn_flag``
    becomes True once a select command is seen, and ``txn_cnt`` counts
    commit/rollback commands.
    """
    def __init__(self, line):
        ConnectionProperty.__init__(self, line)
        self.name = line.get_named_value("name")
        self.send_cnt, self.recv_cnt = 0, 0
        self.txn_flag, self.txn_cnt = False, 0

    def __str__(self):
        """Summary heading; the transaction count appears only for transactional sessions."""
        if not self.txn_flag:
            return " + Session %d (name=%s send-cmds=%d recv-cmds=%d non-txn):" % (
                self.channel, self.name, self.send_cnt, self.recv_cnt)
        return " + Session %d (name=%s send-cmds=%d recv-cmds=%d txns=%d):" % (
            self.channel, self.name, self.send_cnt, self.recv_cnt, self.txn_cnt)

class TraceAnalysis:
    """Reads a Qpid broker trace log and groups its frames by connection and session.

    Construct it (parses command-line options), call analyze_trace() to read
    the log file, then print_analysis() to report.

    All output uses print(...) with exactly one argument, which produces the
    same output under Python 2 (print statement) and Python 3 (print function).
    """
    def __init__(self, argv=None):
        """Parse command-line options.

        argv: optional argument list to parse instead of sys.argv[1:]; the
        default None keeps the normal command-line behavior.
        """
        self.connectionList = [] # Keeps connection creation order
        self.connectionDict = {} # For looking up by connection address
        parser = ArgumentParser(description="Analyze trace level logs from a Qpid broker log file.")
        parser.add_argument("--connection-summary", action="store_true", default=False,
                            help="Hide connection details, provide one-line summary")
        parser.add_argument("--session-summary", action="store_true", default=False,
                            help="Hide session details, provide one-line summary")
        parser.add_argument("-s", "--summary", action="store_true", default=False,
                            help="Hide both connection and session details. Equivalent to --connection-summary and "
                            "--session-summary")
        parser.add_argument("log_file", action="store",
                            help="Log file")
        parser.add_argument("-v", "--version", action='version', version="%%(prog)s %s" % VERSION)
        self.args = parser.parse_args(argv)

    def analyze_trace(self):
        """Read the log file and attach each trace-level frame to its connection/session.

        Lines that cannot be parsed or attributed are skipped. One '.' is
        written to stdout per LINES_PER_DOT lines read, as a progress indicator.
        """
        print("Reading trace file %s:" % self.args.log_file)
        lcnt = 0  # stays 0 for an empty file
        with open(self.args.log_file, "r") as log_file:
            for lcnt, fline in enumerate(log_file, 1):
                try:
                    self._process_line(LogLine(lcnt, fline))
                except Exception:
                    # Best effort: malformed lines and frames for unknown
                    # connections/sessions are skipped. KeyboardInterrupt and
                    # SystemExit are not Exception subclasses, so Ctrl-C still
                    # aborts the run.
                    pass
                # lcnt is 1-based, so the first dot marks exactly LINES_PER_DOT lines.
                if lcnt % LINES_PER_DOT == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()
        if lcnt >= LINES_PER_DOT:
            print("")  # terminate the line of progress dots
        print("Read and analyzed %d lines." % lcnt)

    def _process_line(self, lline):
        """Record one parsed log line if it is a trace-level frame.

        Raises KeyError when the frame names a connection or session that was
        never opened; the caller treats that as a skippable line.
        """
        if not (lline.is_log_level(LogLevel.TRACE) and lline.is_frame()):
            return
        if lline.contains("{ConnectionStartBody"):
            conn = Connection(lline)
            self.connectionList.append(conn)
            self.connectionDict[conn.addr] = conn
        elif lline.contains("{Connection"):
            self.connectionDict[lline.get_identifier_remote_addr()].add_op(lline)
        elif lline.contains("{SessionAttachBody"):
            ssn = Session(lline)
            self.connectionDict[ssn.addr].add_session(ssn)
        else:
            conn = self.connectionDict[lline.get_identifier_remote_addr()]
            ssn = conn.get_session(lline.get_channel())
            ssn.add_op(lline)
            self._count_command(ssn, lline)

    @staticmethod
    def _count_command(ssn, lline):
        """Number execution commands per direction and track transaction boundaries."""
        # Command name is the text between '{' and "Body", e.g. "MessageTransfer".
        if lline.line[lline.find("{") + 1 : lline.find("Body")] not in EXEC_COMMANDS:
            return
        if lline.contains("RECV"):
            lline.cmd_cnt = ssn.recv_cnt
            if ssn.txn_flag and lline.contains("MessageTransferBody"):
                lline.txn_cnt = ssn.txn_cnt
            ssn.recv_cnt += 1
        elif lline.contains("SEND") or lline.contains("SENT"):
            lline.cmd_cnt = ssn.send_cnt
            if ssn.txn_flag and lline.contains("MessageTransferBody"):
                lline.txn_cnt = ssn.txn_cnt
            ssn.send_cnt += 1
        # "xSelectBody" etc. match both the Tx* and Dtx* command variants.
        if lline.contains("xSelectBody"):
            ssn.txn_flag = True
        elif lline.contains("xCommitBody") or lline.contains("xRollbackBody"):
            lline.txn_cnt = ssn.txn_cnt
            ssn.txn_cnt += 1

    def print_analysis(self):
        """Print every connection, its sessions and, unless summarized, their frames."""
        show_conn_ops = not (self.args.connection_summary or self.args.summary)
        show_ssn_ops = not (self.args.session_summary or self.args.summary)
        for c in self.connectionList:
            print("")
            # The column header only appears when frame rows are listed at both levels.
            if show_conn_ops and show_ssn_ops:
                print(" ---line  ----------timestamp  -----------connection ssn recv send- txn--  operation---------->")
            print(c)
            if show_conn_ops:
                for o in c.ops:
                    print(o)
            for s in c.sessionList:
                print(s)
                if show_ssn_ops:
                    for o in s.ops:
                        print(o)

def check_python_version(major, minor, micro):
    """Exit unless the running interpreter is at least version major.minor.micro.

    On failure an error message is written to stderr and the process exits
    via sys.exit(-1).
    """
    if sys.version_info < (major, minor, micro):
        # stderr keeps the diagnostic out of any redirected report output;
        # sys.stderr.write works the same under Python 2 and Python 3.
        sys.stderr.write("Incorrect Python version: %s found; >= %d.%d.%d needed.\n"
                         % (sys.version.split()[0], major, minor, micro))
        sys.exit(-1)

# === Main program ===

if __name__ == '__main__':
    # Refuse to run on interpreters older than 2.7.
    check_python_version(2, 7, 0)
    analysis = TraceAnalysis()
    analysis.analyze_trace()
    analysis.print_analysis()
 