#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys, os, xml

from qpid.spec import load, pythonize
from textwrap import TextWrapper
from xml.sax.handler import ContentHandler

class Block:
  """
  An ordered group of emittable nodes that is rendered as one suite.
  """

  def __init__(self, children):
    self.children = children

  def emit(self, out):
    """
    Emit each child to out in order.

    A child without an "emit" attribute raises ValueError carrying that
    child.  When there are no children at all, a single "pass" line is
    emitted so that the generated suite is never empty.
    """
    for node in self.children:
      if hasattr(node, "emit"):
        node.emit(out)
      else:
        raise ValueError(node)

    # Checked after the loop so an empty suite still gets a body.
    if not self.children:
      out.line("pass")

class If:
  """
  An if statement node: a condition expression, a consequent body and an
  optional alternative body.
  """

  def __init__(self, expr, cons, alt = None):
    self.expr, self.cons, self.alt = expr, cons, alt

  def emit(self, out):
    """
    Emit "if <expr>:" followed by the indented consequent, then an
    "else:" branch with the indented alternative when one was given.
    """
    # The header is started as a new line and completed with raw writes so
    # that the expression and the colon stay on that same line.
    out.line("if ")
    self.expr.emit(out)
    out.write(":")
    self._indented(self.cons, out)
    if self.alt:
      out.line("else:")
      self._indented(self.alt, out)

  def _indented(self, body, out):
    """Emit body one indentation level deeper than the current one."""
    out.level += 1
    body.emit(out)
    out.level -= 1

class Stmt:
  """
  A single statement, held as ready-made source text.
  """

  def __init__(self, code):
    self.code = code

  def emit(self, out):
    """Emit the statement text as a new line of output."""
    statement = self.code
    out.line(statement)

class Expr:
  """
  An expression fragment, held as ready-made source text.
  """

  def __init__(self, code):
    self.code = code

  def emit(self, out):
    """Append the expression text to the output without starting a new line."""
    fragment = self.code
    out.write(fragment)

class Abort:
  """
  A node that renders as an unconditionally failing assertion whose message
  is the given expression.
  """

  def __init__(self, expr):
    self.expr = expr

  def emit(self, out):
    """Emit "assert False, " on a new line and let the expression finish it."""
    message = self.expr
    out.line("assert False, ")
    message.emit(out)

# Shared wrapper instance using the TextWrapper defaults.
WRAPPER = TextWrapper()

def wrap(text):
  """
  Collapse every run of whitespace in text to a single space and return the
  result as a list of wrapped lines.
  """
  words = text.split()
  normalized = " ".join(words)
  return WRAPPER.wrap(normalized)

class Doc:
  """
  A documentation string node: free text that is re-wrapped and emitted
  between two triple-quote lines.
  """

  def __init__(self, text):
    self.text = text

  def emit(self, out):
    """Emit the opening delimiter, the wrapped text lines and the closing delimiter."""
    delimiter = '"""'
    out.line(delimiter)
    for row in wrap(self.text):
      out.line(row)
    out.line(delimiter)

class Frame:
  """
  Parse state for one XML element: its attributes, the nodes produced by its
  child elements and the character data seen so far.

  Element attributes can be read as instance attributes, e.g. frame.timeout
  looks up attrs["timeout"].
  """

  def __init__(self, attrs):
    # Mapping of element attributes; None for a frame with no element.
    self.attrs = attrs
    # Nodes returned by the handlers of nested elements.
    self.children = []
    # Accumulated character data, or None when there was none.
    self.text = None

  def __getattr__(self, attr):
    """
    Fall back to the element attributes for names that are not real instance
    attributes.

    Raises AttributeError when the name is not an element attribute (or when
    the frame has no attribute mapping), so hasattr() and getattr() with a
    default behave normally.
    """
    # Looking up "attrs" here means it was never set on the instance; going
    # through self.attrs again would recurse without end.
    if attr == "attrs":
      raise AttributeError(attr)
    try:
      return self.attrs[attr]
    except (KeyError, TypeError):
      # KeyError: no such element attribute.  TypeError: attrs is None.
      raise AttributeError(attr)

def isunicode(s):
  """
  Return True when s is not a str instance and contains at least one
  character whose code point is above 127; otherwise return False.
  """
  if isinstance(s, str):
    return False
  return any(ord(ch) > 127 for ch in s)

def string_literal(s):
  """
  Return the source text of a string literal for s, or None when s is None.

  Values that isunicode() reports as needing it are rendered as they are;
  everything else is converted with str() first so the literal is a plain
  string.
  """
  # Identity test: None is a singleton and "==" can be overridden.
  if s is None:
    return None
  # repr() instead of "%r" % s, which would misbehave if s were a tuple.
  if isunicode(s):
    return repr(s)
  return repr(str(s))

# Accepted spellings of boolean values, looked up by lower-cased text.
TRUTH = {"1": True, "0": False, "true": True, "false": False}

# Converters from the textual form of a value to what gets placed in the
# generated code, keyed by field type name.
LITERAL = {
  "shortstr": string_literal,
  "longstr": string_literal,
  "bit": lambda text: TRUTH[text.lower()],
  "longlong": lambda text: repr(long(text)),
  }

def literal(s, field):
  """
  Convert the text s with the converter registered in LITERAL for
  field.type.  Raises KeyError when the type has no converter.
  """
  convert = LITERAL[field.type]
  return convert(s)

def palexpr(s, field):
  """
  Translate an attribute value into an expression for the generated code.

  A value starting with "$" becomes a reference to the attribute of msg
  named by the rest of the value; anything else is converted with
  literal() according to field.
  """
  if not s.startswith("$"):
    return literal(s, field)
  return "msg.%s" % s[1:]

class Translator(ContentHandler):
  """
  SAX content handler that turns a script document into a tree of emittable
  nodes (Block, If, Stmt, ...).

  Every element gets a Frame while it is open.  When the element closes, its
  frame is handed to the attribute of this object named after the element,
  or to handle() when there is no such attribute, and whatever node comes
  back is attached to the parent frame.
  """

  def __init__(self, spec):
    """
    spec is the loaded protocol spec; handle() reads its classes and their
    methods to recognise protocol method elements.
    """
    ContentHandler.__init__(self)
    self.spec = spec
    self.stack = []
    # Text of the most recently closed basic_content element, if any.
    self.content = None
    # Bottom-of-stack frame that collects the top-level nodes.
    self.root = Frame(None)
    self.push(self.root)

  def emit(self, out):
    """Emit the nodes collected under the root frame, then a final newline."""
    blk = Block(self.root.children)
    blk.emit(out)
    out.write("\n")

  def peek(self):
    """Return the frame of the innermost open element."""
    return self.stack[-1]

  def pop(self):
    """Remove and return the frame of the innermost open element."""
    return self.stack.pop()

  def push(self, frame):
    """Make frame the innermost open frame."""
    self.stack.append(frame)

  def startElement(self, name, attrs):
    """SAX callback: open a new frame holding the element's attributes."""
    self.push(Frame(attrs))

  def endElement(self, name):
    """
    SAX callback: close the current frame, translate it and attach the
    resulting node (if any) to the enclosing frame.
    """
    frame = self.pop()
    # Elements with a same-named attribute on this object are dispatched to
    # it; everything else is treated as a protocol method call.
    if hasattr(self, name):
      child = getattr(self, name)(frame)
    else:
      child = self.handle(name, frame)

    if child:
      self.peek().children.append(child)

  def characters(self, text):
    """SAX callback: append character data to the current frame's text."""
    frame = self.peek()
    if frame.text:
      frame.text += text
    else:
      frame.text = text

  def handle(self, name, frame):
    """
    Translate an element naming a protocol method into a call statement.

    The element name is matched against the spec: the first class whose
    pythonized name is a prefix of it selects the class, and the remainder
    (after one separator character) must equal the pythonized name of one of
    that class's methods.  The element's attributes become keyword arguments
    of a "ssn.<class>_<method>(...)" call.

    Raises ValueError when no class or no method matches.
    """
    for klass in self.spec.classes:
      pyklass = pythonize(klass.name)
      if name.startswith(pyklass):
        # Drop the class prefix and the separator character that follows it.
        name = name[len(pyklass) + 1:]
        break
    else:
      raise ValueError("unknown class: %s" % name)

    for method in klass.methods:
      pymethod = pythonize(method.name)
      if name == pymethod:
        break
    else:
      raise ValueError("unknown method: %s" % name)

    args = ["%s = %s" % (key, palexpr(val, method.fields.bypyname[key]))
            for key, val in frame.attrs.items()]
    if method.content and self.content:
      # string_literal() already returns quoted source text, so it is
      # inserted with %s; formatting it with %r would quote it a second time
      # and put the quote characters into the content itself.
      args.append("content = %s" % string_literal(self.content))
    code = "ssn.%s_%s(%s)" % (pyklass, pymethod, ", ".join(args))
    if pymethod == "consume":
      # Keep the consumer tag of the reply; the statement generated by
      # wait() reads it.
      code = "consumer_tag = %s.consumer_tag" % code
    return Stmt(code)

  def pal(self, frame):
    """Translate a pal element: its text as a Doc node followed by its children."""
    return Block([Doc(frame.text)] + frame.children)

  def include(self, frame):
    """Translate an include element into a star import of the file's base name."""
    base, ext = os.path.splitext(frame.filename)
    return Stmt("from %s import *" % base)

  def session(self, frame):
    """Translate a session element: fixed setup statements, then its children."""
    return Block([Stmt("cli = open()"), Stmt("ssn = cli.channel(0)"),
                  Stmt("ssn.channel_open()")] + frame.children)

  def empty(self, frame):
    """Translate an empty element: run its children when msg is None."""
    return If(Expr("msg == None"), Block(frame.children))

  def abort(self, frame):
    """Translate an abort element into a failing assertion with its text."""
    return Abort(Expr(string_literal(frame.text)))

  def wait(self, frame):
    """
    Translate a wait element into a statement fetching the next message for
    the current consumer tag.

    The timeout attribute is divided by 1000 (presumably milliseconds to
    seconds -- confirm against the queue API).  The division is done in
    floating point so that a value below 1000 is not truncated to 0.
    """
    return Stmt("msg = ssn.queue(consumer_tag).get(timeout=%r)" %
                (int(frame.timeout)/1000.0))

  def basic_arrived(self, frame):
    """
    Translate a basic_arrived element: run its children when msg is not None.
    Returns None (no node) when the element has no children.
    """
    if frame.children:
      return If(Expr("msg != None"), Block(frame.children))

  def basic_content(self, frame):
    """Remember the element's text as the content for later method calls."""
    self.content = frame.text

class Emitter:
  """
  Indentation-aware writer on top of a file-like object.

  level is the current nesting depth; line() indents by two spaces per
  level.
  """

  def __init__(self, out):
    self.out = out
    self.level = 0

  def write(self, code):
    """Write code to the underlying stream exactly as given."""
    self.out.write(code)

  def line(self, code):
    """Start a new line, indent it to the current level and write code."""
    indent = "  " * self.level
    self.write("\n%s%s" % (indent, code))

  def flush(self):
    """Flush the underlying stream."""
    self.out.flush()

  def close(self):
    """Close the underlying stream."""
    self.out.close()


# Command line: the first argument names the spec, every further argument is
# a script file to translate.  The generated code for each file is written
# to standard output.
files = sys.argv[2:]
if files:
  # The spec does not depend on the file being translated, so load it once
  # (and only when there is something to translate).
  spec = load(sys.argv[1])
  for f in files:
    t = Translator(spec)
    xml.sax.parse(f, t)
    out = Emitter(sys.stdout)
    t.emit(out)
    # Flush rather than close: closing sys.stdout after the first file made
    # the output of every later file fail on a closed stream.
    out.flush()
