# Copyright 2009 Ben Escoto
#
# This file is part of Explicans.

# Explicans is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# Explicans is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with Explicans.  If not, see <http://www.gnu.org/licenses/>.

"""relref.py - Support relative references

Relative references show up in two related contexts:

1. A cell formula may refer to the previous or next cell, for instance using
"prev" or "next". This requires support when evaluating the formula itself.

2. Relative references may also be made through already-created objects, for
instance a.next(2) if a refers to another cell. This requires attaching
relative-reference information to the object a.

"""

import types
import objects, lazyarray

class RelRefInfo:
	"""Hold information necessary to do relative referencing

	Two tables are kept: one for methods (as in a.prev(2)) and one for bare
	names visible inside a formula (as in prev).  Resolve them with
	get_method() or get_name_binding().  set_standard_values() relies on the
	subclass providing next(), up(), get_index() and get_name().
	"""
	def __init__(self):
		# Each table maps a Python string to a (scalar_list, py_func) pair;
		# lookup_val() explains how such a pair becomes a value.
		self.methods = {}
		self.name_bindings = {}

	def get_method(self, method_name):
		"""Return ExObject corresponding to the relref method, or None

		For instance, me.prev(2) or a.prev(2) both go to the cell two previous
		from me or a.  If self is associated with those objects,
		self.get_method('prev') returns an ExFunc from ExNum to values.
		"""
		return self.lookup_val(method_name, self.methods)

	def get_name_binding(self, name):
		"""Return ExObject associated with the local name, or None

		For instance, get_name_binding('prev') might return the previous cell.
		"""
		return self.lookup_val(name, self.name_bindings)

	def lookup_val(self, key, d):
		"""Resolve key in table d, returning None when it is not present

		An entry whose scalar_list is empty is treated as a thunk: it is
		called and its result returned.  Any other entry is wrapped in an
		objects.ExFunc together with its scalar_list.
		"""
		try:
			entry = d[key]
		except KeyError:
			return None
		scalar_list, func = entry
		if scalar_list:
			return objects.ExFunc(func, scalar_list)
		return func()

	def _store(self, table, key, scalar_list, py_func):
		"""Record (scalar_list, py_func) under key in the given table"""
		# Keys must be plain Python values, never wrapped ExObjects.
		assert not isinstance(key, objects.ExObject), key
		table[key] = (scalar_list, py_func)

	def set_method(self, method_name, scalar_list, py_func):
		"""Associate method_name with the given function"""
		self._store(self.methods, method_name, scalar_list, py_func)

	def set_name_binding(self, name, scalar_list, py_func):
		"""Bind name to the given function or object thunk"""
		self._store(self.name_bindings, name, scalar_list, py_func)

	def makeme(self):
		"""Return a fresh blank 'me' object carrying this relref info"""
		blank = objects.ExBlank()
		blank.set_relref(self)
		return blank

	def set_standard_values(self):
		"""Set the typical methods and names ('next', 'prev', 'index', etc.)"""
		# 'prev' is simply 'next' with the step count negated.
		self.set_method('next', (True,),
				lambda exnum: self.next(int(exnum.obj)))
		self.set_method('prev', (True,),
				lambda exnum: self.next(-int(exnum.obj)))
		self.set_method('up', (), self.up)
		self.set_method('index', (), self.get_index)
		self.set_method('name', (), self.get_name)

		# Bare names inside a formula: a single step in either direction.
		self.set_name_binding('me', (), self.makeme)
		self.set_name_binding('next', (), lambda: self.next(1))
		self.set_name_binding('prev', (), lambda: self.next(-1))


class RelRefCell(RelRefInfo):
	"""Relative references for an ordinary cell inside a list

	Every lookup is resolved through root_la, using absolute references
	derived from this cell's own absolute reference.
	"""
	def __init__(self, root_la, absref):
		"""Store the root lazyarray and this cell's absolute reference"""
		RelRefInfo.__init__(self)
		self.root_la, self.my_absref = root_la, absref
		self.set_standard_values()

	def next(self, pynum):
		"""Return the cell pynum positions after this one

		set_standard_values() passes a negated count here to implement 'prev'.
		"""
		target = self.my_absref.next(pynum)
		return self.root_la.deref_absref(target)

	def up(self):
		"""Return the parent node, found via its absolute reference"""
		parent_ref = self.my_absref.parent()
		return self.root_la.deref_absref(parent_ref)

	def get_name(self):
		"""Return the key under which this cell is stored"""
		return self.root_la.deref_absref_key(self.my_absref)

	def get_index(self):
		"""Return this cell's position as an ExNum, counting from 1"""
		zero_based = self.my_absref.last()
		return objects.ExNum(zero_based + 1)


class RelRefColumn(RelRefInfo):
	"""Relative references for column formulas

	Here 'next', 'prev', 'name' and 'index' yield whole arrays, with one entry
	per position of the column, rather than single values.
	"""
	def __init__(self, root_la, my_absref, column_len, me_pointer):
		"""Initializer

		root_la is the root lazyarray used to dereference absolute references
		and my_absref is this formula's absolute reference.  column_len is the
		length of the array the column formula is applied to.  me_pointer is a
		single-element list holding the final array; it is only read when a
		returned thunk is evaluated.
		"""
		RelRefInfo.__init__(self)
		self.root_la = root_la
		self.my_absref = my_absref
		self.column_length = column_len
		self.me_pointer = me_pointer
		self.set_standard_values()

	def next(self, pynum):
		"""Return this column's array shifted by pynum positions

		Entry i of the result lazily reads entry i+pynum of the final array;
		positions that fall outside the column become objects.Blank.
		"""
		def make_lookup(n):
			# Bind n now; me_pointer[0] is consulted only at evaluation time.
			return lambda: self.me_pointer[0].obj.get_thunk(n)()

		size = self.column_length
		shifted = lazyarray.LazyArray(size)
		for pos in range(size):
			source = pos + pynum
			if source < 0 or source >= size:
				shifted.set_value(pos, lambda: objects.Blank)
			else:
				shifted.set_value(pos, make_lookup(source))
		return objects.ExArray(shifted)

	def up(self):
		"""Return the parent node, found via its absolute reference"""
		parent_ref = self.my_absref.parent()
		return self.root_la.deref_absref(parent_ref)

	def get_name(self):
		"""Return an array whose entry i lazily yields the final array's key i"""
		def key_at(n):
			return lambda: self.me_pointer[0].obj.get_key(n)

		names = lazyarray.LazyArray(self.column_length)
		for pos in range(self.column_length):
			names.set_value(pos, key_at(pos))
		return objects.ExArray(names)

	def get_index(self):
		"""Return an array of ExNum indices 1..column_length (1-based)"""
		indices = lazyarray.LazyArray(self.column_length)
		indices.set_values_precomputed(
			[objects.ExNum(k + 1) for k in range(self.column_length)])
		return objects.ExArray(indices)


class TempTableAxisValue(objects.ExBlank):
	"""Stand-in for the axis value while a Table is being evaluated

	Adds no behavior beyond objects.ExBlank.  Presumably it is a distinct
	class so the placeholder can be told apart from an ordinary blank; confirm
	against the callers.
	"""
	
class RelRefTable(RelRefInfo):
	"""Shared relative-reference setup for tables

	Unlike the list case, stepping with 'next'/'prev' takes an extra leading
	axis argument, so subclasses implement next(axis, pynum).
	"""
	def set_standard_values(self):
		"""Set the typical methods and names, with axis-aware 'next'/'prev'"""
		# Methods: obj.next(axis, n) / obj.prev(axis, n); 'prev' negates n.
		self.set_method('next', (False, True),
				lambda axis, exnum: self.next(axis, int(exnum.obj)))
		self.set_method('prev', (False, True),
				lambda axis, exnum: self.next(axis, -int(exnum.obj)))
		self.set_method('up', (), self.up)
		self.set_method('index', (), self.get_index)
		self.set_method('name', (), self.get_name)

		# Bare names still take the axis, but always move exactly one step.
		self.set_name_binding('me', (), self.makeme)
		self.set_name_binding('next', (False,),
				lambda axis: self.next(axis, 1))
		self.set_name_binding('prev', (False,),
				lambda axis: self.next(axis, -1))


class TableAxisNames(RelRefTable):
	"""This class is used to evaluate the row/column formulas for names

	Because the values of the table depend on the names in the rows and columns,
	these formulas must be evaluated before the table is created.
	"""
	def __init__(self, root_la, my_absref, column_len,
				 table_axis_name, temp_axis_value, me_pointer):
		"""Initializer

		root_la is the root lazyarray used to dereference absolute references
		and my_absref is this formula's absolute reference.  column_len is the
		length of the resulting array.  me_pointer is a single-element list
		holding the final array; it is only read lazily.  table_axis_name is
		the formula-visible name of the axis, bound to temp_axis_value, which
		is also the only axis next()/prev() accept.
		"""
		RelRefTable.__init__(self)
		self.root_la, self.my_absref = root_la, my_absref
		self.column_length = column_len
		self.me_pointer = me_pointer
		self.axis_name, self.axis_value = table_axis_name, temp_axis_value

		self.set_standard_values()
		def helper(): return self.axis_value
		self.set_name_binding(self.axis_name, (), helper)

	def next(self, axis, pynum):
		"""Return the current list, but shifted by pynum elements

		axis must be the temporary axis value given at construction.  Entry i
		of the result lazily reads entry i+pynum of the final array;
		out-of-range positions become objects.Blank.
		"""
		def helper(n):
			"""Return thunk that looks up parent at spot n"""
			def subhelper(): return self.me_pointer[0].obj.get_thunk(n)()
			return subhelper

		assert axis is self.axis_value, (axis, self.axis_value)
		result = lazyarray.LazyArray(self.column_length)
		for i in range(self.column_length):
			if 0 <= i+pynum < self.column_length:
				result.set_value(i, helper(i+pynum))
			else: result.set_value(i, lambda: objects.Blank)
		return objects.ExArray(result)

	def up(self):
		"""Return parent node using absolute references"""
		return self.root_la.deref_absref(self.my_absref.parent())

	def get_name(self):
		"""Always an error: the names of an axis have no names themselves"""
		# Explicit raise (not `assert False`) so it still fires under -O.
		raise AssertionError("Names have no names!")

	def get_index(self):
		"""Not yet supported for axis-name formulas"""
		# Previously the placeholder `XXXX`, which failed with a NameError.
		raise NotImplementedError(
			"index is not yet supported in table row/column name formulas")


class TableColumns(RelRefTable):
	"""Relative referencing info for table column formulas"""
	def __init__(self, root_la, table_ar, table, colnum):
		"""Initializer

		root_la - lazy array with root names for looking up references
		table_ar - absolute reference of the parent table (stored only)
		table - the parent table itself
		colnum - the python column number the formula computes
		"""
		RelRefTable.__init__(self)
		self.root_la, self.table_ar = root_la, table_ar
		self.table = table
		self.colnum = colnum
		self.set_standard_values()
		self.bind_col_names()
		self.bind_axis_names()

	def next(self, axis_obj, pynum):
		"""Return column pynum steps along the axis given by axis_obj

		Along the row axis this is the current column shifted by pynum rows.
		Along the column axis it is the whole column pynum columns away.
		Raises AssertionError for an unrecognized axis.
		"""
		if axis_obj.axis_flag is self.table.row_axis_flag:
			return self.shift_col(pynum)
		elif axis_obj.axis_flag is self.table.col_axis_flag:
			# NOTE(review): colnum + pynum is not range-checked here; an
			# out-of-range column would only surface, if at all, when an
			# entry is evaluated.  Confirm this is intended.
			return self.get_delayed_column(self.colnum + pynum)
		# Explicit raise (not `assert False`) so it still fires under -O.
		raise AssertionError(("Unknown axis", axis_obj))

	def get_delayed_column(self, j):
		"""Return an ExArray constructed from the future state of column j"""
		def thunk_maker(i):
			def thunk(): return self.table.get_value_thunk(i,j)()
			return thunk

		length = self.table.get_num_rows()
		la = lazyarray.LazyArray(length)
		for i in range(length):
			la.set_value(i, thunk_maker(i))
		return objects.ExArray(la)

	def shift_col(self, pynum):
		"""Return the current column shifted pynum steps up

		Entry i of the result lazily reads row i+pynum of this column; rows
		outside the table become a fresh objects.ExBlank().
		"""
		def helper(n):
			"""Return thunk that returns value in row n of current column"""
			def subhelper(): return self.table.get_value_thunk(n, self.colnum)()
			return subhelper
		def get_blank(): return objects.ExBlank()

		length = self.table.get_num_rows()
		result = lazyarray.LazyArray(length)
		for i in range(length):
			if 0 <= i+pynum < length:
				result.set_value(i, helper(i+pynum))
			else: result.set_value(i, get_blank)
		return objects.ExArray(result)

	# Not yet supported for column formulas.  These were the placeholder
	# `XXXX`, which failed with a NameError.
	def up(self):
		raise NotImplementedError("up is not yet supported in table columns")

	def get_name(self):
		raise NotImplementedError("name is not yet supported in table columns")

	def get_index(self):
		raise NotImplementedError("index is not yet supported in table columns")

	def bind_col_names(self):
		"""Bind each non-blank column name to that column as a delayed ExArray"""
		def thunk_maker(j):
			"""Return column j as an ExArray"""
			def thunk(): return self.get_delayed_column(j)
			return thunk

		for j in range(self.table.get_num_cols()):
			name = self.table.colnames[j]
			# get_type must be called: comparing the bound method itself to
			# 'blank' is always False and would never skip blank names.
			if name.get_type() == 'blank': continue
			self.set_name_binding(name.obj, (), thunk_maker(j))

	def bind_axis_names(self):
		"""Bind names of the two axes to current row/col name"""
		def row_axis_thunk():
			# The row axis name yields every row name, tagged with the row flag.
			la = lazyarray.LazyArray(self.table.get_num_rows())
			la.set_values_precomputed(self.table.rownames)
			exa = objects.ExArray(la)
			exa.axis_flag = self.table.row_axis_flag
			return exa
		self.set_name_binding(self.table.row_axis_name, (), row_axis_thunk)

		def col_axis_thunk():
			# NOTE: tags the shared colnames entry in place with the col flag.
			val = self.table.colnames[self.colnum]
			val.axis_flag = self.table.col_axis_flag
			return val
		self.set_name_binding(self.table.col_axis_name, (), col_axis_thunk)


class TableCellValue(RelRefTable):
	"""Relative reference info for table values"""
	def __init__(self, root_la, table_ar, table, rownum, colnum):
		"""Initializer

		root_la - lazy array with root names for looking up references
		table_ar - AbsoluteReference of the parent table
		table - the parent table itself
		rownum - the python row number of the cell
		colnum - the python column number of the cell
		"""
		RelRefTable.__init__(self)
		self.root_la, self.table_ar = root_la, table_ar
		self.table = table
		self.rownum, self.colnum = rownum, colnum
		self.set_standard_values()
		self.bind_rowcol_names()
		self.bind_axis_names()

	def next(self, axis_obj, pynum):
		"""Return the value pynum spaces in the table along the given axis

		The target position must lie inside the table (checked by assert).
		Raises AssertionError for an unrecognized axis.
		"""
		if axis_obj.axis_flag is self.table.row_axis_flag:
			new_rownum = self.rownum + pynum
			assert 0 <= new_rownum < self.table.get_num_rows()
			return self.table.get_value_thunk(new_rownum, self.colnum)()
		elif axis_obj.axis_flag is self.table.col_axis_flag:
			new_colnum = self.colnum + pynum
			assert 0 <= new_colnum < self.table.get_num_cols()
			return self.table.get_value_thunk(self.rownum, new_colnum)()
		# Explicit raise (not `assert False`) so it still fires under -O.
		raise AssertionError(("Unknown axis", axis_obj))

	# Not yet supported for table cells.  These were the placeholder `XXXX`,
	# which failed with a NameError.
	def up(self):
		raise NotImplementedError("up is not yet supported in table cells")

	def get_name(self):
		raise NotImplementedError("name is not yet supported in table cells")

	def get_index(self):
		raise NotImplementedError("index is not yet supported in table cells")

	def bind_rowcol_names(self):
		"""Bind row/column names to the cells sharing this cell's row/column

		A column name refers to that column within this cell's row; a row name
		refers to that row within this cell's column.  Row names are bound
		second, so they win over a column name with the same text.
		"""
		# This whole function is horribly inefficient: every instance rebinds
		# every row and column name.
		def thunk_lookup(i,j):
			"""Return thunk at row i, col j.  Delays lookup until eval"""
			def thunk(): return self.table.get_value_thunk(i,j)()
			return thunk

		# get_type must be called below: comparing the bound method itself to
		# 'blank' is always False and would never skip blank names.
		for j in range(self.table.get_num_cols()):
			name = self.table.colnames[j]
			if name.get_type() == 'blank': continue
			self.set_name_binding(name.obj, (), thunk_lookup(self.rownum, j))
		for i in range(self.table.get_num_rows()):
			name = self.table.rownames[i]
			if name.get_type() == 'blank': continue
			self.set_name_binding(name.obj, (), thunk_lookup(i, self.colnum))

	def bind_axis_names(self):
		"""Bind names of the two axes to current row/col name"""
		def row_axis_thunk():
			# NOTE: tags the shared rownames entry in place with the row flag.
			val = self.table.rownames[self.rownum]
			val.axis_flag = self.table.row_axis_flag
			return val
		self.set_name_binding(self.table.row_axis_name, (), row_axis_thunk)

		def col_axis_thunk():
			# NOTE: tags the shared colnames entry in place with the col flag.
			val = self.table.colnames[self.colnum]
			val.axis_flag = self.table.col_axis_flag
			return val
		self.set_name_binding(self.table.col_axis_name, (), col_axis_thunk)

