import operator
from xml.dom.minidom import Node,Document
from psychsim.pwl.keys import CONSTANT
from psychsim.pwl.vector import KeyedVector
from psychsim.probability import Distribution
from psychsim.action import ActionSet
class KeyedPlane:
    """
    String indexable hyperplane class

    An instance holds a list of tests in ``planes``; each test is a triple
    ``(vector, threshold, comparison)``:

      - vector: the weights for the hyperplane (L{KeyedVector})
      - threshold: the threshold for the hyperplane (a number, or a list/set of them)
      - comparison: if 1, value must be above hyperplane; if -1, below; if 0, equal (default is 1)

    :ivar planes: the list of ``(vector, threshold, comparison)`` triples
    :type planes: list
    :ivar isConjunction: if C{True}, multiple planes are ANDed together; otherwise they are ORed
    :type isConjunction: bool
    """
    DEFAULT_THRESHOLD = 0.
    DEFAULT_COMPARISON = 1
    # Indexed directly by the comparison code: 0 -> '==', 1 -> '>', -1 -> '<' (last element)
    COMPARISON_MAP = ['==', '>', '<']

    def __init__(self, planes, threshold=None, comparison=None):
        """
        :param planes: an XML ``plane`` element, a single weight vector, or an iterable of
            1-, 2- or 3-element sequences ``(vector[, threshold[, comparison]])``
        :param threshold: threshold for the single-vector form (default is C{DEFAULT_THRESHOLD})
        :param comparison: comparison code for the single-vector form (default is C{DEFAULT_COMPARISON})
        :raises ValueError: if an element of S{planes} does not have 1, 2 or 3 entries
        :warning: if S{planes} is a list, then S{threshold} and S{comparison} are ignored
        """
        if isinstance(planes, Node):
            self.parse(planes)
        elif isinstance(planes, KeyedVector):
            # Only a single branch passed
            if threshold is None:
                threshold = self.DEFAULT_THRESHOLD
            if comparison is None:
                comparison = self.DEFAULT_COMPARISON
            self.planes = [(planes, threshold, comparison)]
        else:
            self.planes = []
            for plane in planes:
                size = len(plane)
                if size == 3:
                    self.planes.append(plane)
                elif size == 2:
                    self.planes.append((plane[0], plane[1], self.DEFAULT_COMPARISON))
                elif size == 1:
                    self.planes.append((plane[0], self.DEFAULT_THRESHOLD, self.DEFAULT_COMPARISON))
                else:
                    raise ValueError('Empty plane passed into constructor')
        # Lazily computed caches (see __str__ and keys)
        self._string = None
        self._keys = None
        self.isConjunction = True

    def __add__(self, other):
        """
        Addition is not supported; always raises.
        """
        raise DeprecationWarning('Addition is ambiguous between con/disjunction. Replace with & or |, as desired.')

    def __and__(self, other):
        """
        :return: a new plane that is the conjunction of the tests in both operands
        :warning: Does not check for duplicate planes
        """
        if len(self.planes) > 1:
            assert self.isConjunction, f'Cannot conjoin disjunctive planes: {self}'
        if len(other.planes) > 1:
            assert other.isConjunction, f'Cannot conjoin disjunctive planes: {other}'
        result = self.__class__(self.planes + other.planes)
        result.isConjunction = True
        return result

    def __or__(self, other):
        """
        :return: a new plane that is the disjunction of the tests in both operands
        :warning: Does not check for duplicate planes
        """
        if len(self.planes) > 1:
            assert not self.isConjunction, f'Cannot disjoin conjunctive planes: {self}'
        if len(other.planes) > 1:
            assert not other.isConjunction, f'Cannot disjoin conjunctive planes: {other}'
        result = self.__class__(self.planes + other.planes)
        result.isConjunction = False
        return result

    def keys(self):
        """
        :return: the union of the keys of all vectors in this plane (cached until invalidated)
        :rtype: set
        """
        if self._keys is None:
            self._keys = set()
            for vector, _threshold, _comparison in self.planes:
                self._keys |= set(vector.keys())
        return self._keys

    def possible(self, variables):
        """
        Uses the lo/hi bounds of the variables to determine which test outcomes can occur.

        :param variables: table mapping each key to a dictionary with ``'domain'``, ``'lo'``
            and ``'hi'`` entries
        :return: C{[True]}, C{[False]} or C{[True, False]} for the first plane with a scalar
            threshold; C{None} if a key is missing from S{variables}, a variable is symbolic
            (domain is C{set} or C{list}), or no plane could be analyzed
        :raises NotImplementedError: for comparisons other than ``>``
        :warning: list-valued (multi-way) thresholds are not analyzed; such planes are skipped
        """
        for plane, threshold, comparison in self.planes:
            lo = 0.
            hi = 0.
            for key in plane.keys():
                if key not in variables:
                    # Ambiguous
                    return None
                entry = variables[key]
                if entry['domain'] is set or entry['domain'] is list:
                    # Symbolic
                    return None
                weight = plane[key]
                if weight > 0.:
                    lo += weight * entry['lo']
                    hi += weight * entry['hi']
                else:
                    # A non-positive weight flips the interval: the smallest contribution
                    # comes from the variable's upper bound and vice versa
                    lo += weight * entry['hi']
                    hi += weight * entry['lo']
            if self.COMPARISON_MAP[comparison] != '>':
                raise NotImplementedError('Not yet handling comparisons other than >, but it is really easy to do')
            if isinstance(threshold, list):
                # Multi-way branches are not resolved here; try the next plane
                continue
            if lo > threshold:
                # Always true
                return [True]
            if hi <= threshold:
                # Always false
                return [False]
            # Bounds straddle the threshold, so either outcome can occur
            return [True, False]
        return None

    def evaluate(self, vector, p_index=None):
        """
        Tests whether the given vector passes or fails this test.
        Also accepts a numeric value, in lieu of doing a dot product.

        :param vector: the vector to test, or a C{float} to be compared directly against the threshold
        :param p_index: evaluate against only the given plane (not all of them in case this is a con/disjunctive branch)
        :type p_index: int
        :return: C{True}/C{False} for scalar and set thresholds. For a list threshold, the index
            of the first matching entry (C{len(threshold)} if no inequality entry matches,
            C{None} if no equality entry matches)
        :rtype: bool
        :raises ValueError: if a comparison code is not positive, negative or zero
        :warning: Multiple planes are combined according to C{isConjunction} (AND if true, OR otherwise)
        """
        planes = self.planes if p_index is None else [self.planes[p_index]]
        for plane, threshold, comparison in planes:
            if isinstance(vector, float):
                total = vector
            else:
                total = plane * vector
                if isinstance(total, Distribution):
                    assert len(total) == 1, 'Unable to handle uncertain test results'
                    total = total.first()
            if comparison > 0:
                if isinstance(threshold, list):
                    for index, bound in enumerate(threshold):
                        if total <= bound:
                            return index
                    return len(threshold)
                passed = total > threshold
            elif comparison < 0:
                if isinstance(threshold, list):
                    for index, bound in enumerate(threshold):
                        if total < bound:
                            return index
                    return len(threshold)
                passed = total < threshold
            elif comparison == 0:
                if isinstance(threshold, list):
                    for index, target in enumerate(threshold):
                        if abs(total - target) < plane.epsilon:
                            return index
                    return None
                if isinstance(threshold, set):
                    # Membership test: passes if the value matches any element
                    passed = any(abs(total - target) < plane.epsilon for target in threshold)
                else:
                    passed = abs(total - threshold) < plane.epsilon
            else:
                raise ValueError('Invalid comparison %s' % (comparison))
            if passed and not self.isConjunction:
                # Disjunction, so any positive result is sufficient
                return True
            if not passed and self.isConjunction:
                # Conjunction, so any negative result is sufficient
                return False
        # Every plane passed (conjunction) or none did (disjunction); an empty test is False
        return len(planes) > 0 and bool(self.isConjunction)

    def desymbolize(self, table, debug=False):
        """
        :param table: lookup table used to replace symbolic references in vectors and thresholds
        :param debug: currently unused
        :return: a new plane with vectors and thresholds desymbolized; C{isConjunction} is preserved
        """
        planes = [(vector.desymbolize(table), self.desymbolizeThreshold(threshold, table), comparison)
                  for vector, threshold, comparison in self.planes]
        result = self.__class__(planes)
        result.isConjunction = self.isConjunction
        return result

    def desymbolizeThreshold(self, threshold, table):
        """
        Replaces symbolic content in a threshold using the given table.

        Strings are evaluated as expressions with S{table} as the local namespace; lists and
        sets are processed element by element; L{ActionSet} thresholds are looked up in
        S{table}; anything else is returned unchanged.

        :return: the desymbolized threshold (the original string if it refers to an undefined name)
        """
        if isinstance(threshold, str):
            try:
                # SECURITY NOTE(review): eval executes arbitrary expressions; threshold
                # strings must come from trusted scenario definitions only
                return eval(threshold, globals(), table)
            except NameError:
                # Undefined reference: assume it'll get sorted out later
                return threshold
        elif isinstance(threshold, list):
            return [self.desymbolizeThreshold(t, table) for t in threshold]
        elif isinstance(threshold, set):
            return {self.desymbolizeThreshold(t, table) for t in threshold}
        elif isinstance(threshold, ActionSet):
            return table[threshold]
        else:
            return threshold

    def makeFuture(self, keyList=None):
        """
        Transforms this plane to refer to only future versions of its columns

        :param keyList: If present, only references to these keys are made future
        """
        self.changeTense(True, keyList)

    def makePresent(self, keyList=None):
        """
        Transforms this plane to refer to only current versions of its columns

        :param keyList: If present, only references to these keys are made present
        """
        self.changeTense(False, keyList)

    def changeTense(self, future=True, keyList=None):
        """
        Changes the tense of the vectors of this plane in place and clears the cached
        key set and string.

        :param future: passed through to each vector's C{changeTense}
        :param keyList: the keys to change (default is all keys of this plane)
        """
        if keyList is None:
            keyList = self.keys()
        for vector, _threshold, _comparison in self.planes:
            vector.changeTense(future, keyList)
        self._keys = None
        self._string = None

    def scale(self, table):
        """
        Builds a copy of this (single-test) plane whose threshold is rescaled by the range of
        its variables.

        :param table: maps keys to C{(lo, hi)} spans; keys absent from the table are treated as symbolic
        :return: a new plane with a copy of the vector, the adjusted threshold, and the same comparison
        :warning: asserts that there is exactly one test, no constant factor, no mix of numeric
            and symbolic variables, and that all numeric variables share the same span
        """
        assert len(self.planes) == 1, 'Unable to scale branches with multiple tests'
        original, threshold, comparison = self.planes[0]
        vector = original.__class__(original)
        symbolic = False
        span = None
        assert CONSTANT not in vector, 'Unable to scale hyperplanes with constant factors. Move constant factor into threshold.'
        for key in vector.keys():
            if key in table:
                assert not symbolic, 'Unable to scale hyperplanes with both numeric and symbolic variables'
                if span is None:
                    span = table[key]
                    threshold /= float(span[1] - span[0])
                else:
                    assert table[key] == span, 'Unable to scale hyperplanes when the variables have different ranges (%s != %s)' % (span, table[key])
                threshold -= vector[key] * span[0] / (span[1] - span[0])
            else:
                assert span is None, 'Unable to scale hyperplanes with both numeric and symbolic variables'
                symbolic = True
        return self.__class__(vector, threshold, comparison)

    def __eq__(self, other):
        """
        Two planes are equal if they have the same number of tests and every test of
        S{other} also appears in this plane.
        """
        # NOTE(review): isConjunction is not part of the comparison
        if not isinstance(other, KeyedPlane):
            return False
        if len(self.planes) != len(other.planes):
            return False
        return all(plane in self.planes for plane in other.planes)

    @staticmethod
    def _cmp(first, second):
        """
        Three-way comparison (replacement for the Python 2 builtin C{cmp}).

        :return: 1 if first > second, -1 if first < second, 0 otherwise
        """
        return (first > second) - (first < second)

    def compare(self, other, value):
        """
        Identifies any potential conflicts between two hyperplanes

        :param other: the plane to compare against (both planes must contain exactly one test)
        :param value: the truth value used when relating the two tests
        :return: C{None} if no conflict was detected, C{True} if the tests are redundant, C{False} if the tests are conflicting
            (for list-valued equality thresholds, the matching index)
        :raises NotImplementedError: if this threshold is a set and the other is a list
        :warning: correct, but not complete
        """
        assert len(self.planes) == 1, 'Unable to compare branches with multiple tests'
        assert len(other.planes) == 1, 'Unable to compare branches with multiple tests'
        myVec, myThresh, myComp = self.planes[0]
        yrVec, yrThresh, yrComp = other.planes[0]
        if not (myVec == yrVec):
            # Tests over different vectors: nothing can be concluded
            return None
        if myComp == 0:
            if yrComp == 0:
                # Both are equality tests
                if isinstance(myThresh, set):
                    if isinstance(yrThresh, set):
                        if myThresh == yrThresh and value is True:
                            return True
                        # TODO: There are more cases here
                        return None
                    if isinstance(yrThresh, list):
                        raise NotImplementedError
                    if yrThresh in myThresh and value is True:
                        # This equality test is one of my acceptable values
                        return True
                    return None
                if isinstance(yrThresh, set):
                    if myThresh in yrThresh and value is False:
                        # Not in a set that includes my acceptable value
                        return False
                    return None
                if isinstance(yrThresh, list):
                    try:
                        return yrThresh.index(myThresh)
                    except ValueError:
                        return None
                if isinstance(myThresh, list):
                    try:
                        return myThresh.index(yrThresh)
                    except ValueError:
                        return None
                if abs(myThresh - yrThresh) < myVec.epsilon:
                    # Values are the same, so test results must be the same
                    return value
                # Values are different: if we are equal to the other value, we conflict;
                # if not equal to other, there is no information
                return False if value else None
            if self._cmp(myThresh, yrThresh) == yrComp:
                # Our value satisfies other's inequality: no information if it holds,
                # but we know we are invalid if it does not
                return None if value else False
            # Our value does not satisfy other's inequality: invalid if it holds,
            # otherwise no information
            return False if value else None
        if yrComp == 0:
            # Other specifies equality, we are inequality
            if value:
                # Determine whether equality condition satisfies our inequality
                return self._cmp(yrThresh, myThresh) == myComp
            # No information about inequality
            return None
        if myThresh == yrThresh and myComp == yrComp:
            # Identical planes
            return value
        # Both inequalities, we should do something here
        return None

    def minimize(self):
        """
        Folds any constant factor of each vector into the corresponding threshold (in place),
        leaving the constant weight at 0.
        """
        for index, (vector, threshold, comparison) in enumerate(self.planes):
            if CONSTANT not in vector:
                continue
            offset = vector[CONSTANT]
            if isinstance(threshold, list):
                threshold = [value - offset for value in threshold]
            elif isinstance(threshold, set):
                threshold = {value - offset for value in threshold}
            else:
                threshold -= offset
            vector[CONSTANT] = 0.
            self.planes[index] = (vector, threshold, comparison)
            # The cached string is stale now
            self._string = None

    def __str__(self):
        if self._string is None:
            joiner = '\nAND ' if self.isConjunction else '\nOR '
            self._string = joiner.join(['%s %s %s' % (vector.hyperString(), self.COMPARISON_MAP[comparison], threshold)
                                        for vector, threshold, comparison in self.planes])
        return self._string

    def __xml__(self):
        """
        :return: a document with a ``plane`` root containing one vector element per test, each
            annotated with ``threshold`` and ``comparison`` attributes
        :rtype: Document
        """
        # NOTE(review): isConjunction is not serialized, so a disjunction is read back
        # as a conjunction by the constructor
        doc = Document()
        root = doc.createElement('plane')
        for vector, threshold, comparison in self.planes:
            node = vector.__xml__().documentElement
            node.setAttribute('threshold', str(threshold))
            node.setAttribute('comparison', str(comparison))
            root.appendChild(node)
        doc.appendChild(root)
        return doc

    @staticmethod
    def _parseThreshold(text):
        """
        Converts the text of a ``threshold`` attribute back into a number, list or set.

        If the text contains a ``.``, every number is read as a float; otherwise each number
        is read as an int, falling back to float for forms such as ``1e-05`` or ``inf``.

        :raises ValueError: if the text is empty or not numeric
        """
        text = text.strip()
        if not text:
            raise ValueError('Missing threshold attribute')
        forceFloat = '.' in text

        def number(token):
            if forceFloat:
                return float(token)
            try:
                return int(token)
            except ValueError:
                return float(token)

        if text[0] in '[{':
            inner = text[1:-1].strip()
            values = [number(token) for token in inner.split(',')] if inner else []
            return values if text[0] == '[' else set(values)
        return number(text)

    def parse(self, element):
        """
        Replaces the tests of this plane with those read from an XML ``plane`` element.

        :param element: the ``plane`` element whose ``vector`` children are to be read
        """
        assert element.tagName == 'plane'
        self.planes = []
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                assert node.tagName == 'vector'
                vector = KeyedVector(node)
                threshold = self._parseThreshold(node.getAttribute('threshold'))
                try:
                    comparison = int(node.getAttribute('comparison'))
                except ValueError:
                    # Non-numeric comparison codes are kept as strings
                    comparison = str(node.getAttribute('comparison'))
                self.planes.append((vector, threshold, comparison))
            node = node.nextSibling
def thresholdRow(key, threshold):
    """
    :param key: the key whose value is tested
    :param threshold: the threshold the value is compared against
    :return: a plane testing whether the given keyed value exceeds the given threshold
    :rtype: L{KeyedPlane}
    """
    return KeyedPlane(KeyedVector({key: 1}), threshold)
def differenceRow(key1, key2, threshold):
    """
    :param key1: the key given weight 1
    :param key2: the key given weight -1
    :param threshold: the threshold the difference is compared against
    :return: a plane testing whether the difference between the first and second keyed values exceeds the given threshold
    :rtype: L{KeyedPlane}
    """
    return KeyedPlane(KeyedVector({key1: 1, key2: -1}), threshold)
def greaterThanRow(key1, key2):
    """
    :return: a plane testing whether the first keyed value is greater than the second
        (a difference test with threshold 0)
    :rtype: L{KeyedPlane}
    """
    return differenceRow(key1, key2, 0)
def trueRow(key):
    """
    :return: a plane testing whether a boolean keyed value is True (i.e., exceeds 0.5)
    :rtype: L{KeyedPlane}
    """
    return thresholdRow(key, 0.5)
def falseRow(key):
    """
    :return: a plane testing whether a boolean keyed value is False (i.e., is below 0.5)
    :rtype: L{KeyedPlane}
    """
    return KeyedPlane(KeyedVector({key: 1}), 0.5, -1)
def andRow(trueKeys=(), falseKeys=()):
    """
    :param trueKeys: list of keys which must be C{True} (default is empty)
    :type trueKeys: str[]
    :param falseKeys: list of keys which must be C{False} (default is empty)
    :type falseKeys: str[]
    :return: a plane testing whether all boolean keyed values are set as desired
    :rtype: L{KeyedPlane}
    """
    # True keys weigh +1 and false keys -1, so the weighted sum exceeds
    # len(trueKeys)-0.5 only when every key has its desired value
    weights = {key: 1 for key in trueKeys}
    for key in falseKeys:
        weights[key] = -1
    return KeyedPlane(KeyedVector(weights), len(trueKeys) - 0.5)
def equalRow(key, value):
    """
    :param key: a single key, a list of keys (each given weight 1), or a table of key weights
    :type key: str or str[] or str->float/int
    :param value: the target value of the equality test
    :return: a plane testing whether the given keyed value (or sum of keyed values) equals the given target value
    :rtype: L{KeyedPlane}
    """
    if isinstance(key, list):
        weights = {k: 1 for k in key}
    elif isinstance(key, dict):
        weights = key
    else:
        weights = {key: 1}
    return KeyedPlane(KeyedVector(weights), value, 0)
def equalFeatureRow(key1, key2):
    """
    :return: a plane testing whether the values of the two given features are equal
        (i.e., their difference equals 0)
    :rtype: L{KeyedPlane}
    """
    return KeyedPlane(KeyedVector({key1: 1, key2: -1}), 0, 0)