#!python2.4.4
# -*- coding: iso-8859-1 -*-

"""
Project: LRApy
Author: Max Jakob (max.jakob@web.de)
Module: lra
Module Description: This (main) module contains LRA-specific data
structures, a function to read task files, and methods that implement
the LRA algorithm for solving word analogy questions. For a detailed
description of the algorithm see http://arxiv.org/abs/cs/0412024

Version: 1.2
Last change: 2007-02-15

Copyright 2007 by Max Jakob.
This code is released under the GNU GPL. See the accompanying LICENSE file.

The embedded documentation can be viewed or exported (e.g. as HTML) with the Python pydoc module.
"""


# Tuning parameters of the LRA run.
# In this program MAX_PHRASE is always MAX_INTER + 2 (MAX_PHRASE itself is
# not defined in this module).
NUM_SIM = 10         # number of synonyms fetched from the thesaurus per word
NUM_FILTER = 3       # number of alternate pairs kept per word pair (most frequent in corpus)
MIN_INTER = 1        # minimum number of intervening words (must be >= 1 !)
MAX_INTER = 3        # maximum number of intervening words
NUM_PATTERNS = 4000  # number of patterns kept by search.getPatterns
k = 300              # number of dimensions kept by LRAMatrix.applySVD (projection size)

# Feature switches (1 = on, 0 = off); see the "LRA step" comments in
# getLRASuccessList for where each one is used.
USE_ALTERNATES = 0   # find & filter alternate pairs (LRA steps 1 & 2)
USE_ENTROPY = 1      # apply entropy & log transformations (LRA step 8)
USE_SVD = 1          # apply SVD and projection to k dimensions (LRA steps 9 & 10)
                     # SVD requires a lot of memory!


import matrix, search, synonym, sys


class LRAInputError(Exception):
	"""Signals unusable input for an LRA run.

	Raised when task data does not follow the expected syntax, when
	required information is missing or contradictory, or when corpus
	search results violate the constraints of the program.
	"""

class WordPair:
	"""A word pair of an analogy task together with its alternate pairs.

	<original> is the normalized pair (stripped, lower-cased words) as a
	tuple (word1, word2). <allPairs> holds the original at index 0,
	followed by any alternates added by setAlternates. The original is
	never removed (see deletePair). All pairs are available via
	getAllPairs.
	The <relationalSimDict> dictionary contains tuples of words as keys and
	relational similarities of the WordPair to those word tuples as values.
	"""
	def __init__(self, pair):
		self.original = self._getPairTuple(pair)
		# The original is always the first entry. Tuples are immutable, so
		# sharing the object with self.original is safe.
		self.allPairs = [self.original]
		self.relationalSimDict = {}

	def __cmp__(self, other):
		# WordPairs compare by their original pair only (Python 2 protocol).
		return cmp(self.original, other.original)

	def __str__(self):
		return ":".join(self.original)

	def printAll(self):
		"""Prints the original pair followed by all of its alternates.
		"""
		alternates = [":".join(wordTuple) for wordTuple in self.allPairs
			if wordTuple != self.original]
		print(":".join(self.original) + " " + ", ".join(alternates))

	def _getPairTuple(self, pair):
		"""Returns <pair> as a normalized tuple (word1, word2).

		<pair> may be a string of the form "Word1:Word2" or a tuple
		(Word1, Word2). Both words are stripped and lower-cased.
		Raises LRAInputError if <pair> has another type, does not consist of
		exactly two parts, or if a word is empty or contains whitespace.
		"""
		if isinstance(pair, tuple):
			wordTuple = pair
		elif isinstance(pair, str):
			wordTuple = tuple(pair.split(":"))
		else:
			msg = "pair must be tuple or string, got %s" % type(pair)
			raise LRAInputError(msg)
		if len(wordTuple) != 2:
			msg = "WordPair length must be 2, got '%s'" % str(pair)
			raise LRAInputError(msg)
		word1 = wordTuple[0].strip()
		word2 = wordTuple[1].strip()
		if not word1 or not word2:
			msg = "empty word in pair, got '%s'" % str(pair)
			raise LRAInputError(msg)
		# split() without an argument also catches tabs and other whitespace
		if len(word1.split()) != 1 or len(word2.split()) != 1:
			msg = "only single words allowed, got '%s'" % str(pair)
			raise LRAInputError(msg)
		return (word1.lower(), word2.lower())

	def getAllPairs(self):
		"""Returns the list of pair tuples: the original first, then the
		alternates, if they have been set before.
		Note: this is the internal list itself, not a copy.
		"""
		return self.allPairs

	def deletePair(self, pairTuple):
		"""Removes <pairTuple> from the allPairs list unless it is the
		original. The original can not be deleted, because it is essential
		when calculating the relational similarity; in that case only a
		message is printed.
		"""
		pair = self._getPairTuple(pairTuple)
		if pair == self.original:
			print("  found no phrases for original '%s'" % ":".join(pair))
		elif pair in self.allPairs:
			self.allPairs.remove(pair)

	def _getAlternateCandidates(self, nrToPull):
		"""Returns a list of candidate alternate pairs as tuples.

		Each candidate replaces exactly one word of the original by one of
		its <nrToPull> top synonyms from the synonym module: (syn1, word2)
		for the synonyms of word1, then (word1, syn2) for the synonyms of
		word2.
		"""
		word1, word2 = self.original
		# Pair a synonym with the *other* word so that the relation between
		# the two words is preserved.
		candidates = [(syn, word2)
			for syn in synonym.getSynonymList(word1, nrToPull)]
		candidates.extend([(word1, syn)
			for syn in synonym.getSynonymList(word2, nrToPull)])
		return candidates

	def setAlternates(self, nrToPull, nrToKeep, minInter, maxInter):
		"""Appends alternates of the original word pair to allPairs.

		Candidates are built from <nrToPull> synonyms per word (see
		_getAlternateCandidates). Each candidate is counted in the corpus
		with se.countWordPairInScope, allowing between <minInter> and
		<maxInter> intervening words. The <nrToKeep> most frequent
		candidates are appended. Candidates that are not valid single-word
		pairs, or that are already in allPairs, are skipped.
		"""
		sys.stdout.write("  %s ...             \r" % ":".join(self.original))
		candidateFreqDict = {}
		for candidate in self._getAlternateCandidates(nrToPull):
			try:
				candidatePair = self._getPairTuple(candidate)
			except LRAInputError:
				# Not usable as a word pair (e.g. if the thesaurus returned
				# a multi-word synonym); skip it instead of aborting the run.
				continue
			if candidatePair in self.allPairs or \
					candidatePair in candidateFreqDict:
				# never add the original or an existing alternate twice
				continue
			candidateFreqDict[candidatePair] = se.countWordPairInScope(
				candidatePair[0], candidatePair[1], minInter, maxInter)
		synList = [(frq, pair) for pair, frq in candidateFreqDict.items()]
		synList.sort(reverse=True)
		self.allPairs.extend([pair for frq, pair in synList[:nrToKeep]])

	def setSimilarity(self, stemTuple, relationalSimilarity):
		"""Sets the relational similarity of the WordPair instance to the
		<stemTuple> word tuple to <relationalSimilarity>.
		"""
		self.relationalSimDict[stemTuple] = relationalSimilarity

	def getSimilarity(self, stemTuple):
		"""Returns the relational similarity of the WordPair instance to the
		<stemTuple> word tuple, or 0 if none was set. If no similarity has
		been set at all, a message is printed as well.
		"""
		if self.relationalSimDict:
			return self.relationalSimDict.get(stemTuple, 0)
		else:
			print("  no similarity assigned to '%s'" % ":".join(self.original))
			return 0
		

class Task:
	"""A Task object represents one analogy question that LRA has to solve.
	It has the following attributes:
	- stem: the WordPair whose relation is to be found among the choices.
	- choices: a list of WordPairs, from which the pair whose relation
	  corresponds to that of the stem is to be chosen.
	- rightChoice: a WordPair equal to one of the choices; the correct
	  answer.
	- lraChoice: the choice selected by LRA; a "none:none" placeholder
	  until setLRAResult selects a choice.
	"""
	def __init__(self, stem, choices, rightChoice):
		self.setStem(stem)
		self.setChoices(choices)
		self.setRightChoice(rightChoice)
		# placeholder meaning "LRA has not selected a choice (yet)"
		self.lraChoice = WordPair("none:none")

	def __str__(self):
		return "%s\n- %s\nRight choice: %s\nLRA choice: %s" % (
			str(self.stem).upper(),
			"\n- ".join([str(choice) for choice in self.choices]),
			str(self.rightChoice),
			str(self.lraChoice))

	def setStem(self, newStem):
		"""Sets the stem to a WordPair built from <newStem>.
		"""
		self.stem = WordPair(newStem)

	def setChoices(self, newChoices):
		"""Sets the choices to WordPairs built from <newChoices>.
		"""
		self.choices = [WordPair(c) for c in newChoices]

	def setRightChoice(self, newRightChoice):
		"""Sets the right choice to a WordPair built from <newRightChoice>.
		Raises LRAInputError if it is not equal to one of the choices.
		"""
		self.rightChoice = WordPair(newRightChoice)
		if self.rightChoice not in self.choices:
			raise LRAInputError("right choice is not in choices")

	def getAllPairTuples(self):
		"""Returns a list of all word pairs (originals and alternates) of
		the stem and of all choices, as tuples.
		"""
		allTuplesList = []
		for wPair in [self.stem] + self.choices:
			allTuplesList.extend(wPair.getAllPairs())
		return allTuplesList

	def setAllAlternates(self, numSim, numFilter, minInter, maxInter):
		"""Sets the alternates of the stem and of all choices.
		"""
		for wPair in [self.stem] + self.choices:
			wPair.setAlternates(numSim, numFilter, minInter, maxInter)

	def setPhraseCache(self, minInter, maxInter):
		"""Searches all word pairs involved in the task and stores their
		intervening phrases in the cache of the search engine, unless they
		have been stored there already. Phrases must have a length of at
		least <minInter> and at most <maxInter>.
		"""
		for wPair in [self.stem] + self.choices:
			for word1, word2 in wPair.getAllPairs():
				if not se.getCachedPhrasesForWordPair(word1, word2):
					se.cacheInterPhrases(word1, word2, minInter, maxInter)

	def getAllCachedPhrases(self):
		"""Returns se.getCachedPhrases(), i.e. the phrases held in the search
		engine's cache after setPhraseCache. NOTE(review): since the cache
		is shared, this presumably includes phrases of every pair cached so
		far, not only this task's; verify against search.SearchEngine.
		"""
		return se.getCachedPhrases()

	def deleteFutureZeroVectors(self):
		"""Deletes all alternate pairs that have no phrases in the corpus
		(originals can not be deleted, see WordPair.deletePair). Raises
		LRAInputError if no pair of the stem has any phrases.
		"""
		foundStemPhrase = False
		# Iterate over copies: deletePair removes entries from the very list
		# returned by getAllPairs, which would otherwise skip elements.
		for word1, word2 in list(self.stem.getAllPairs()):
			if not se.getCachedPhrasesForWordPair(word1, word2):
				self.stem.deletePair((word1, word2))
			else:
				foundStemPhrase = True
		if not foundStemPhrase:
			msg = "no phrases found for stem '%s'" % self.stem
			raise LRAInputError(msg)
		for wPair in self.choices:
			for word1, word2 in list(wPair.getAllPairs()):
				if not se.getCachedPhrasesForWordPair(word1, word2):
					wPair.deletePair((word1, word2))

	def compareAllChoicesWithStem(self, aMatrix):
		"""Compares every pair of every choice with every pair of the stem
		by the cosine of their rows in <aMatrix>. Each choice's relational
		similarity to the stem is then stored via setSimilarity. It is the
		average of the original-vs-original cosine and all other cosines
		that are at least as large.
		"""
		for choicePair in self.choices:
			originalsSim = None
			simsNotBelowOriginal = []
			for choiceTuple in choicePair.getAllPairs():
				for stemTuple in self.stem.getAllPairs():
					sim = float(matrix.cosinus(
						aMatrix.getRowVector(stemTuple),
						aMatrix.getRowVector(choiceTuple)))
					details = (":".join(choiceTuple), ":".join(stemTuple), sim)
					if originalsSim is None:
						# The originals are first in both pair lists, so the
						# first comparison is original vs. original. A None
						# sentinel keeps a cosine of exactly 0.0 from being
						# mistaken for "not compared yet".
						originalsSim = sim
						simsNotBelowOriginal.append(sim)
						print("%25s :: %s - %f (original pairs)" % details)
					elif sim >= originalsSim:
						simsNotBelowOriginal.append(sim)
						print("%25s :: %s - %f (cos >= original pairs)" % details)
					else:
						print("%25s :: %s - %f" % details)
			relationalSim = sum(simsNotBelowOriginal) / float(
				len(simsNotBelowOriginal))
			choicePair.setSimilarity(self.stem.original, relationalSim)

	def setLRAResult(self):
		"""Sets the lraChoice attribute to the choice WordPair with the
		highest relational similarity to the stem (the first one on ties).
		NOTE: only strictly positive similarities can win. If every choice
		scores <= 0, lraChoice keeps its "none:none" placeholder. This can
		happen if cosines can be negative, e.g. after SVD (TODO confirm
		against matrix.cosinus).
		"""
		biggestSim = 0
		for choicePair in self.choices:
			sim = choicePair.getSimilarity(self.stem.original)
			print("%25s - %f" % (choicePair, sim))
			if sim > biggestSim:
				biggestSim = sim
				self.lraChoice = choicePair

	def getLRAResult(self):
		"""Returns True if the selection that LRA made matches the actual
		right choice, and False otherwise. Calls setLRAResult first if no
		choice has been selected yet.
		"""
		if self.lraChoice.original == ("none", "none"):
			self.setLRAResult()
		print("Right choice - %s :: %s" % (self.stem, self.rightChoice))
		print("LRA choice - %s :: %s" % (self.stem, self.lraChoice))
		return self.rightChoice == self.lraChoice


def getTaskListFromFile(fName):
	"""Returns a list of Task objects read from the text file <fName>.

	Tasks are separated by one or more blank lines. The first line of a
	task is the stem pair and every following line is a choice pair; the
	right choice is marked with a "*". For a full description of the task
	file format see the accompanying README. Carriage returns are removed,
	lines are stripped, and whitespace-only lines count as blank, so
	leading or trailing blank lines do not produce empty tasks.

	Raises LRAInputError if a task marks more than one right choice or
	none at all (Task/WordPair raise it for malformed pairs).
	"""
	f = open(fName)
	try:
		filecontent = f.read().replace("\r", "")
	finally:
		f.close()

	# Group the non-blank lines into tasks; a blank line ends a task.
	taskLinesList = []
	currentLines = []
	for line in filecontent.split("\n"):
		line = line.strip()
		if line:
			currentLines.append(line)
		elif currentLines:
			taskLinesList.append(currentLines)
			currentLines = []
	if currentLines:
		taskLinesList.append(currentLines)

	taskList = []
	for taskLines in taskLinesList:
		stem = taskLines[0]
		choices = taskLines[1:]
		rightChoice = ""
		for pos, c in enumerate(choices):
			if "*" in c:
				if rightChoice:
					msg = "more than one right choice (stem: '%s')" % stem
					raise LRAInputError(msg)
				rightChoice = c.replace("*", "")
				choices[pos] = rightChoice
		if not rightChoice:
			raise LRAInputError("no right choice (stem: '%s')" % stem)
		taskList.append(Task(stem, choices, rightChoice))
	return taskList


def getLRASuccessList(taskFile):
	"""Solves all analogy tasks in <taskFile> and returns a list of
	successes: for every task that was not skipped, True is appended if
	LRA chose the right answer, and False otherwise. Tasks whose stem has
	no phrases in the corpus are skipped.
	The "LRA step" comments refer to the numbered steps of the algorithm
	described in http://arxiv.org/abs/cs/0412024
	"""
	taskList = getTaskListFromFile(taskFile)
	print("Latent Relational Analysis for %i analogy tasks" % len(taskList))
	# LRA steps 1, 2 & 3
	taskList, phraseList = _findPhrases(taskList)
	# LRA step 4
	print("finding patterns...")
	patterns = search.getPatterns(phraseList, NUM_PATTERNS)
	print("  nr of patterns: %d" % len(patterns))
	# LRA step 5, 6 & 7
	print("generating matrix...")
	lraMatrix = _buildLRAMatrix(taskList, patterns)
	if USE_ENTROPY:
		# LRA step 8
		print("applying entropy & log transformations...")
		lraMatrix.applyEntropy()
	if USE_SVD:
		# LRA step 9 & 10
		print("applying svd and projection...")
		lraMatrix.applySVD(k)
	successList = []
	for analogyTask in taskList:
		# LRA step 11
		print("evaluating (alternates)...")
		analogyTask.compareAllChoicesWithStem(lraMatrix)
		# LRA step 12
		print("calculating relational similarity...")
		successList.append(analogyTask.getLRAResult())
	return successList


def _findPhrases(taskList):
	"""Finds alternates (if enabled) and phrases for every task in
	<taskList>; returns (tasks that can be solved, list of all phrases).
	"""
	skippedTasks = []
	for analogyTask in taskList:
		if USE_ALTERNATES:
			# LRA step 1 & 2
			print("finding & filtering alternates...")
			analogyTask.setAllAlternates(NUM_SIM, NUM_FILTER,
				MIN_INTER, MAX_INTER)
		# LRA step 3
		print("finding phrases...")
		analogyTask.setPhraseCache(MIN_INTER, MAX_INTER)
		try:
			analogyTask.deleteFutureZeroVectors()
		except LRAInputError:
			skippedTasks.append(analogyTask)
	# The search engine's phrase cache persists across tasks
	# (Task.setPhraseCache only searches pairs not cached yet), so it is
	# read once here. Reading it after every task would add the phrases
	# of earlier tasks again and again and bias the pattern selection.
	phraseList = []
	if taskList:
		phraseList.extend(se.getCachedPhrases())
	for skippedTask in skippedTasks:
		print("no phrases for stem '%s' -> skipping task" % skippedTask.stem)
	keptTasks = [t for t in taskList if t not in skippedTasks]
	return keptTasks, phraseList


def _buildLRAMatrix(taskList, patterns):
	"""Returns an LRAMatrix with a row for every word pair of <taskList>
	and one for its reversed pair, built from pattern counts.
	"""
	lraMatrix = matrix.LRAMatrix()
	nrPatterns = len(patterns)
	for analogyTask in taskList:
		for word1, word2 in analogyTask.getAllPairTuples():
			# pattern counts for "word1 ... word2" and "word2 ... word1"
			forward = matrix.getEmptyVector(nrPatterns)
			backward = matrix.getEmptyVector(nrPatterns)
			for idx, pat in enumerate(patterns):
				forward[idx] = se.countPatternForWordPair(pat, word1, word2)
				backward[idx] = se.countPatternForWordPair(pat, word2, word1)
			lraMatrix.setWordPairVector((word1, word2),
				matrix.mergeVectors(forward, backward))
			lraMatrix.setWordPairVector((word2, word1),
				matrix.mergeVectors(backward, forward))
	return lraMatrix


def printLRASuccessRate(lraSuccessList):
	"""Prints how many tasks LRA solved correctly, in percent.

	<lraSuccessList> holds one boolean per evaluated task (True = solved).
	An empty list means that every task was skipped.
	"""
	total = len(lraSuccessList)
	if not total:
		print("\nLRA has skipped every task")
		return
	solved = lraSuccessList.count(True)
	percentage = (float(solved) / total) * 100
	print("\nLRA overall success rate (%d/%d tasks): %f%%"
		% (solved, total, percentage))

if __name__ == "__main__":
	import time
	# Exactly two arguments are expected: the task file and the corpus.
	if len(sys.argv) != 3:
		print("usage: python lra.py <task_file> <corpus_directory>")
		print("see README for further information")
		sys.exit(1)
	taskFile, corpusDir = sys.argv[1], sys.argv[2]
	# module-level search engine used by the WordPair and Task methods
	se = search.SearchEngine(corpusDir)

	startTime = time.time()
	# Use Psyco if it is installed; the program also runs without it.
	try:
		import psyco
		psyco.full()
	except ImportError:
		pass
	printLRASuccessRate(getLRASuccessList(taskFile))
	print("LRA time: %s sec" % (time.time() - startTime))
