# -*- coding: iso-8859-1 -*-
"""
This module is an implementation of a table-controlled shift-reduce
parser (a.k.a. Tomita, though we're not quite as sophisticated).

Warning: For ambiguous grammars, we only return one derivation.

Ihr müsst das hier nicht verstehen, wenn ihr gerade erst Programmieren I
hört... (You don't need to understand this if you are just taking
Programming I...)
"""

import re, operator
from sets import Set
import grammar, symbol, word

# Extremely bad and dangerous hack: We need to retrofit hashing into
# Terminals.  Terminals are stored in Sets (FIRST/FOLLOW sets) and are
# part of the tuple keys of the action table, so they must be hashable;
# this makes terminals with equal .content hash alike.
# NOTE(review): this is only sound if Terminal equality is also based on
# .content -- confirm in the symbol module.  The patch is applied to the
# symbol.Terminal class itself, i.e. globally, at import time.
symbol.Terminal.__hash__ = lambda self: hash(self.content)


# Special pseudo-symbol for the end of input: it is the lookahead once the
# input is exhausted and is what the augmented start symbol is followed by.
# It is modelled as a NonTerminal named "$".
eOF = symbol.NonTerminal("$")


class Error(Exception):
	"""Raised by this module for internal inconsistencies, e.g. a rule
	containing something that is neither a terminal nor a nonterminal, or
	an unknown parser action."""


class MultiStack(list):
	"""A stack of (category, value) pairs kept in a plain list.

	Alternative implementation of grammar.Multistack: copying is cheap (one
	list slice), but popping has to search downwards from the top for the
	most recent entry of the requested category.
	"""
	def __str__(self):
		parts = ["%s/%s"%pair for pair in self]
		return "[%s]"%(", ".join(parts))

	def copy(self):
		"""Return a shallow copy that is again a MultiStack."""
		return MultiStack(list(self))

	def push(self, cat, val):
		"""Put val on top of the stack, tagged with the category cat."""
		self.append((cat, val))

	def pop(self, cat):
		"""Remove the topmost entry whose category equals cat and return
		its value; entries above it stay where they are.

		Returns None if no entry has the category cat.
		"""
		index = len(self)
		while index>0:
			index -= 1
			if self[index][0]==cat:
				return list.pop(self, index)[1]
		return None


class ShiftReduceGrammar(grammar.Grammar):
	"""A grammar.Grammar extended by the artificial start rule
	"ShiftReduceStartSymbol -> <start symbol>" and by the FIRST and FOLLOW
	set computations needed to build the parse tables.
	"""
	def __init__(self, sourceName, ruleString=None, *args, **kwargs):
		"""Read the rules (from the file sourceName if ruleString is None)
		and add the augmented start rule.
		"""
		if ruleString is None:
			# close the file explicitly rather than relying on the garbage
			# collector
			srcFile = open(sourceName)
			try:
				ruleString = srcFile.read()
			finally:
				srcFile.close()
		# The source grammar is parsed once on its own just to learn its
		# start symbol.
		rawGrammar = grammar.Grammar(sourceName, ruleString, *args, **kwargs)
		ruleString = ruleString+"\nShiftReduceStartSymbol -> %s\n"%(
			rawGrammar.getStartSymbol())
		grammar.Grammar.__init__(self, None, ruleString=ruleString, 
			*args, **kwargs)

	def getShiftReduceStartSym(self):
		"""Return the (freshly built) artificial start nonterminal."""
		return symbol.NonTerminal("ShiftReduceStartSymbol")

	def getRules(self):
		"""Return a new list of all rules, in ruleDict iteration order.

		A fresh list is built every time, so callers may modify it without
		touching the grammar (and an empty ruleDict yields []).
		"""
		rules = []
		for ruleList in self.ruleDict.itervalues():
			rules.extend(ruleList)
		return rules

	def _computeFirstSets(self):
		"""Return a dict mapping eOF and every left-hand side to its FIRST
		set.

		The sets are grown together until none changes any more, so
		(mutually) left-recursive nonterminals get complete sets.  Only the
		first symbol of each right-hand side is examined; an empty
		right-hand side contributes word.epsilon.
		"""
		rules = self.getRules()
		firstSets = {eOF: Set([eOF])}
		for rule in rules:
			firstSets.setdefault(rule.getLeft(), Set())
		changed = True
		while changed:
			changed = False
			for rule in rules:
				targetSet = firstSets[rule.getLeft()]
				oldSize = len(targetSet)
				try:
					head = rule.getRight()[0]
				except IndexError:
					targetSet.add(word.epsilon)
				else:
					if isinstance(head, symbol.Terminal):
						targetSet.add(head)
					elif firstSets.has_key(head):
						targetSet.union_update(firstSets[head])
					# else: head is no left-hand side, i.e. it has no rules and
					# contributes nothing.
				if len(targetSet)!=oldSize:
					changed = True
		return firstSets

	def first(self, sym):
		"""Return the FIRST set of sym (a Set of terminals, eOF and/or
		word.epsilon).

		The sets of all left-hand sides are computed on the first call
		(see _computeFirstSets); other symbols are computed and cached on
		demand.  Note that only the first symbol of a right-hand side is
		considered, even if it can derive the empty word.
		"""
		if not hasattr(self, "_firstCache"):
			self._firstCache = self._computeFirstSets()
		if self._firstCache.has_key(sym):
			return self._firstCache[sym]
		targetSet = self._firstCache[sym] = Set()
		if isinstance(sym, symbol.Terminal):
			targetSet.add(sym)
		else:
			# not a left-hand side; ask the grammar anyway
			for rule in self.getRulesForNonTerm(sym):
				try:
					targetSet.union_update(self.first(rule.getRight()[0]))
				except IndexError:
					targetSet.add(word.epsilon)
		return targetSet

	def _computeFollowSets(self):
		"""Return a dict mapping every symbol occurring on a right-hand side
		to its FOLLOW set.

		The sets are grown together until none changes any more, so
		FOLLOW sets that depend on each other (A -> ... B and B -> ... A)
		come out complete.  Only the symbol directly after an occurrence
		is examined; nullable successors are not looked through.
		"""
		startSym = self.getShiftReduceStartSym()
		rules = self.getRules()
		followSets = {}
		for rule in rules:
			for sym in rule.getRight():
				followSets.setdefault(sym, Set())
		changed = True
		while changed:
			changed = False
			for rule in rules:
				head = rule.getLeft()
				rightSide = rule.getRight()
				for pos, sym in enumerate(rightSide):
					targetSet = followSets[sym]
					oldSize = len(targetSet)
					if pos+1<len(rightSide):
						nextSym = rightSide[pos+1]
						if isinstance(nextSym, symbol.Terminal):
							targetSet.add(nextSym)
						else:
							targetSet.union_update(self.first(nextSym))
					elif head==startSym:
						# the augmented start rule is followed by end of input
						targetSet.add(eOF)
					elif followSets.has_key(head):
						# at the end of a rule: whatever follows the head
						targetSet.union_update(followSets[head])
					if len(targetSet)!=oldSize:
						changed = True
		return followSets

	def follow(self, nonTerm):
		"""Return the FOLLOW set of nonTerm (a Set of terminals, eOF and
		possibly word.epsilon coming from FIRST sets).

		All FOLLOW sets are computed on the first call (see
		_computeFollowSets); a symbol occurring on no right-hand side gets
		an empty set.
		"""
		if not hasattr(self, "_followCache"):
			self._followCache = self._computeFollowSets()
		return self._followCache.setdefault(nonTerm, Set())


class ShiftReduceState:
	"""A state of the parsing automaton: a set of dotted rules ("dot
	pairs", i.e. (rule, dotPos) tuples where dotPos is the number of
	right-hand side symbols before the dot), closed under prediction.

	Don't construct directly, use the create class method: it returns one
	shared instance for all dotted rules with the same left-hand side and
	the same symbols before the dot.
	"""
	# Registry of all states created so far, keyed by the string built in
	# create.  It is a class attribute and hence shared by every grammar and
	# parser using this class.
	createdDict = {}

	def __init__(self, stateIndex, pertainingRule, dotPos, gramm):
		self.stateIndex = stateIndex
		# Dot position of the pertaining rules.  Rules merged in via create
		# share it, since create's key contains the symbols before the dot
		# (as long as dotPos does not exceed the right-hand side length).
		self.dotPos = dotPos
		self.symsBefore = pertainingRule.getRight()[:dotPos]
		self.dpSet = Set([(pertainingRule, dotPos)])
		self.grammar = gramm
		self._collectAdditionalDps(pertainingRule, dotPos)

	def __str__(self):
		return ('State %d\n'
			'  %s')%(self.stateIndex,
			"\n  ".join(map(str, self.dpSet)))
	
	__repr__ = __str__

	def _collectAdditionalDps(self, rule, pos):
		"""Add the closure of the dotted rule (rule, pos): if a nonterminal
		follows the dot, add all of its rules with the dot at 0, recursively.
		"""
		try:
			curSym = rule.getRight()[pos]
		except IndexError:
			# dot at the end of the rule: nothing to predict
			return
		if isinstance(curSym, symbol.Terminal):
			return
		for depRule in self.grammar.getRulesForNonTerm(curSym):
			if (depRule, 0) in self.dpSet:
				# already predicted; skip it, but still look at the remaining
				# rules of curSym
				continue
			self.dpSet.add((depRule, 0))
			self._collectAdditionalDps(depRule, 0)

	def _addToPertainingRules(self, rule, dotPos):
		"""Merge another dotted rule (and its closure) into this state."""
		self.dpSet.add((rule, dotPos))
		self._collectAdditionalDps(rule, dotPos)

	def iterRulesAndDotPos(self):
		"""Iterate over the (rule, dotPos) pairs of this state."""
		for rule, dotPos in self.dpSet:
			yield rule, dotPos

	def getIndex(self):
		return self.stateIndex

	def getDotPos(self):
		"""Return the dot position of the rules this state was created for
		(dotted rules added by closure have their dot at 0)."""
		return self.dotPos

	def getDotPairs(self):
		return self.dpSet
	
	def create(klass, pertainingRule, dotPos, gramm):
		"""(class method) Return the state for the dotted rule
		(pertainingRule, dotPos), creating it or merging the dotted rule
		into an existing state with the same left-hand side and symbols
		before the dot.  New states are numbered consecutively from the
		size of createdDict.
		"""
		key = str([pertainingRule.getLeft()
			]+pertainingRule.getRight()[:dotPos])
		if klass.createdDict.has_key(key):
			newState = klass.createdDict[key]
			newState._addToPertainingRules(pertainingRule, dotPos)
		else:
			klass.createdDict[key] = newState = klass(len(klass.createdDict), 
				pertainingRule, dotPos, gramm)
		return newState
	create = classmethod(create)

	def getDotPairIndex(klass):
		"""(class method) Return a dict mapping each (rule, dotPos) pair to
		the list of indices of all registered states containing it.
		"""
		dpIndex = {}
		for state in klass.createdDict.values():
			for dps in state.getDotPairs():
				dpIndex.setdefault(dps, []).append(state.getIndex())
		return dpIndex
	getDotPairIndex = classmethod(getDotPairIndex)
			

class _ParserState:
	"""All a branch of the backtracking parser needs to resume: the current
	automaton state, the input and the position in it, the stack of states,
	and the list of rules applied so far.  depth counts how often the state
	was copied for branching; it only serves to indent debug output.
	"""
	def __init__(self, curState, toParse, curPos, stateStack, rulesApplied,
			depth=0):
		self.curState, self.toParse, self.curPos = curState, toParse, curPos
		self.stateStack, self.rulesApplied = stateStack, rulesApplied
		self.depth = depth

	def __str__(self):
		template = "state%s pos%d -- %s"
		return template%(self.curState, self.curPos, self.stateStack)

	def copy(self):
		"""Return an independent branch one level deeper: the stack and the
		rule list are copied, the input is shared."""
		return _ParserState(
			curState=self.curState,
			toParse=self.toParse,
			curPos=self.curPos,
			stateStack=self.stateStack.copy(),
			rulesApplied=self.rulesApplied[:],
			depth=self.depth+1)

	def getIndent(self):
		"""Return two spaces per branching level."""
		return self.depth*"  "


class ShiftReduceParser:
	"""As I said: You're not supposed to understand all this if you're
	a newbie.  Try again after a parsing course.
	"""
	def __init__(self, gramm, debug=0):
		self.debug = debug
		self.grammar = gramm
		self.realStartSym = self.grammar.getShiftReduceStartSym()
		self.startSymbol = self.grammar.getStartSymbol()
		self._buildStateList()
		self._setStartState()
		self._buildTables()
	
	def _buildStateList(self):
		self.stateDict = {}
		for rule in self.grammar.getRules():
			if rule.getLeft()==self.realStartSym:
				dotRange = range(len(rule.getRight())+1)
			else:
				dotRange = range(1, len(rule.getRight())+1)
			for dotPos in dotRange:
				newState = ShiftReduceState.create(rule, dotPos, self.grammar)
				self.stateDict[newState.getIndex()] = newState
		if self.debug:
			print "\n".join(map(str, self.stateDict.values()))

	def _setStartState(self):
		for state in self.stateDict.itervalues():
			for rule, dotPos in state.iterRulesAndDotPos():
				if dotPos==0 and rule.getLeft()==self.realStartSym:
					self.startState = state.getIndex()
					break

	def _enterShiftInstruction(self, dotPos, rule, state, dpIndex):
		stateInd = state.getIndex()
		try:
			for targetState in dpIndex[rule, dotPos+1]:
				self.actionTable.setdefault(
					(stateInd, rule.getRight()[dotPos]), Set()).add(
						("shift", targetState))
		except KeyError:
			pass

	def _enterGotoInstruction(self, dotPos, rule, state, dpIndex):
		stateInd = state.getIndex()
		try:
			for targetState in dpIndex[rule, dotPos+1]:
				self.gotoTable.setdefault(
					(stateInd, rule.getRight()[dotPos]), Set()).add(
						targetState)
		except KeyError:
			pass

	def _enterReduceInstruction(self, rule, state):
		stateInd = state.getIndex()
		head = rule.getLeft()
		for folSym in self.grammar.follow(rule.getLeft()):
			for sym in self.grammar.first(folSym):
				self.actionTable.setdefault((stateInd, sym), Set()).add(
					("reduce", rule))
		if head==self.realStartSym:
			self.actionTable.setdefault((stateInd, eOF), Set()).add(
				("accept", None))

	def _buildTables(self):
		self.actionTable = {}
		self.gotoTable = {}
		dpIndex = ShiftReduceState.getDotPairIndex()
		for state in self.stateDict.itervalues():
			for rule, dotPos in state.iterRulesAndDotPos():
				try:
					sym = rule.getRight()[dotPos]
				except IndexError:
					self._enterReduceInstruction(rule, state)
				else:
					if isinstance(sym, symbol.Terminal):
						self._enterShiftInstruction(dotPos, rule, state, dpIndex)
					elif isinstance(sym, symbol.NonTerminal):
						self._enterGotoInstruction(dotPos, rule, state, dpIndex)
					else:
						raise Error, "Invalid sym: %s"%sym
		self._cleanupTable(self.actionTable)
		self._cleanupTable(self.gotoTable)

	def _cleanupTable(setDict):
		for key, value in setDict.iteritems():
			if isinstance(value, Set):
				if len(value)==1:
					setDict[key] = value.pop()
	_cleanupTable = staticmethod(_cleanupTable)
	
	def asHtml(self):
		def buildTableFromTupleDict(tupleDict):
			keys = tupleDict.keys()
			inds = list(Set([key[0] for key in keys]))
			names = list(Set([key[1] for key in keys]))
			names.sort()
			inds.sort()
			preHeader = "</th><th>".join(map(str, names))
			lines = []
			for ind in inds:
				accu = ['<tr><th>%s</th>'%str(ind)]
				for name in names:
					accu.append("<td>%s</td>"%str(tupleDict.get((ind, name), "&nbsp;")))
				accu.append("</tr>")
				lines.append("".join(accu))
			return "<table border=1><tr><th>state#</th><th>%s</th>\n%s\n</table>"%(
				preHeader, "\n".join(lines))
		return ('<head><title>junk</title></head>'
		'<body><h1>Action Table</h1>\n'
		'%(actionTable)s\n'
		'<h1>Goto Table</h1>\n'
		'%(gotoTable)s\n'
		'</body>\n')%{
			"actionTable": buildTableFromTupleDict(self.actionTable),
			"gotoTable": buildTableFromTupleDict(self.gotoTable)}

	def _getNextAction(self, ps):
		try:
			curSym = ps.toParse[ps.curPos]
		except IndexError:
			curSym = eOF
		if self.debug:
			print ps.getIndent()+"->current Symbol here: %s"%curSym
		try:
			return self.actionTable[ps.curState, curSym]
		except KeyError:
			if self.debug:
				print ps.getIndent()+"No action, aborting"
			return None

	def _runShiftAction(self, action, ps):
		ps.stateStack.push(ps.toParse[ps.curPos], ps.curState)
		ps.curState = action[1]
		ps.curPos += 1
		return ("continue", None)

	def _runReduceAction(self, action, ps):
		rule = action[1]
		ps.rulesApplied.append(rule)
		toPop = rule.getRight()[:]
		toPop.reverse()
		for sym in toPop:
			tmpState = ps.stateStack.pop(sym)
		try:
			nextStates = self.gotoTable[tmpState, rule.getLeft()]
		except KeyError:
			return ("abort", None)
		ps.stateStack.push(rule.getLeft(), tmpState)
		if isinstance(nextStates, Set):
			if self.debug:
				print ps.getIndent()+"Branching on Goto:", nextStates
			nextStates = nextStates.copy()
			while len(nextStates)>1:
				newPs = ps.copy()
				newPs.curState = nextStates.pop()
				res = self._runParser(newPs)
				if res is not None:
					return ("accept", res)
				if self.debug:
					print ps.getIndent()+"Trying next branch"
			if self.debug:
				print ps.getIndent()+"resuming", id(ps), ps
			ps.curState = nextStates.pop()
		else:
			ps.curState = nextStates
		return ("continue", None)

	def _runAcceptAction(self, action, ps):
		return ("accept", ps.rulesApplied)

	actionHandlerTable = {
		'accept': _runAcceptAction,
		'shift': _runShiftAction,
		'reduce': _runReduceAction,
	}

	def _runNextAction(self, nextAction, ps):
		if isinstance(nextAction, Set):
			if self.debug:
				print ps.getIndent()+"Branching on Actions:", nextAction
			nextAction = nextAction.copy()
			while len(nextAction)>1:
				res = self._runParser(ps.copy(), nextAction.pop())
				if res is not None:
					ps.rulesApplied = res
					return ('accept', ps.rulesApplied)
				if self.debug:
					print ps.getIndent()+"Trying next branch"
			if self.debug:
				print ps.getIndent()+"resuming", id(ps), ps
			nextAction = nextAction.pop()
		if self.debug:
			print ps.getIndent()+"Now executing", nextAction
		return self.actionHandlerTable[nextAction[0]](self, nextAction, ps)

	def _runParser(self, ps, nextAction=None):
		if self.debug:
			print ps.getIndent()+"New Branch:", id(ps), ps, nextAction
		try:
			while 1:
				if nextAction is None:
					nextAction = self._getNextAction(ps)
				if nextAction is None:
					return None
				res = self._runNextAction(nextAction, ps)
				if res[0]=="accept":
					return res[1]
				elif res[0]=="abort":
					if self.debug:
						print ps.getIndent()+"Aborting on action request"
					return None
				elif res[0]=="continue":
					pass
				else:
					raise Error("Unknown parser action: %s"%str(res))
				if self.debug:
					print ps.getIndent()+"State after action:", ps
				nextAction = None
		finally:
			if self.debug:
				print ps.getIndent()+"End Branch:", id(ps)

	def parse(self, toParse):
		ps = _ParserState(
			curState = self.startState,
			curPos = 0,
			toParse = toParse,
			stateStack = MultiStack([(self.realStartSym, self.startState)]),
			rulesApplied = []
		)
		rulesApplied = self._runParser(ps)
		if rulesApplied is not None:
			rulesApplied.reverse()
		return rulesApplied

			
def _test():
	gramm = ShiftReduceGrammar(None, ruleString='S -> NP VP\n'
		'VP -> "vi"\n'
		'VP -> "vt" NP\n'
		'VP -> "vt" NP PP\n'
		'NP -> "n"\n'
		'NP -> "det" "n"\n'
		'NP -> "det" "adj" "n"\n'
		'PP -> "prep" NP\n'
		'=S\n')
	tp = ShiftReduceParser(gramm)
	open("zw.html", "w").write(tp.asHtml())
	print tp.parse(word.Word('"n" "vt" "n" "prep" "n"'))

def _test2():
	gramm = ShiftReduceGrammar(None, ruleString='S -> NP VP\n'
		'VP -> "vi"\n'
		'VP -> "vt" NP\n'
		'VP -> VP PP\n'
		'NP -> "n"\n'
		'NP -> "det" "n"\n'
		'NP -> "det" "adj" "n"\n'
		'NP -> NP PP\n'
		'PP -> "prep" NP\n'
		'=S\n')
	print "\n".join(map(str, gramm.getRules()))
	tp = ShiftReduceParser(gramm)
	open("zw.html", "w").write(tp.asHtml())
	print tp.parse(word.Word('"n" "vt" "n" "prep" "n"'))


if __name__=="__main__":
	_test2()


