import logging
from typing import Any, Dict
from psychsim.probability import Distribution
from psychsim.action import Action
from psychsim.pwl.keys import CONSTANT,makeFuture,makePresent,isFuture
from psychsim.pwl.vector import KeyedVector
from psychsim.pwl.matrix import KeyedMatrix,setToConstantMatrix
from psychsim.pwl.plane import KeyedPlane,equalRow
class KeyedTree:
    """
    Decision tree node using symbolic PWL structures

    Every node is in exactly one of three states:
      - leaf: C{children} is C{{None: value}} and C{branch} is C{None}
      - deterministic branch: C{branch} is a L{KeyedPlane} and C{children} maps each
        possible result of C{branch.evaluate} to a subtree
      - probabilistic branch: C{branch} is C{None} and C{children} is a L{Distribution}
        over subtrees

    The string form and the key sets are computed lazily and cached on each node.
    Methods that restructure a node in place clear that node's caches (but not the
    caches of ancestors that share it as a subtree).

    @ivar leaf: C{True} iff this node is a leaf
    @type leaf: bool
    @ivar children: table of children of this node
    @type children: dict or L{Distribution}
    @ivar branch: the hyperplane branch at this node (if applicable)
    @type branch: L{KeyedPlane}
    """
    def __init__(self, leaf=None, plane=None, children=None):
        """
        @param leaf: value stored at this node when it is built as a leaf
        @param plane: if not C{None}, this node becomes a deterministic branch on this plane
        @param children: subtrees of a branch; a L{Distribution} (with no C{plane}) makes
            this node a probabilistic branch
        """
        self._string = None
        self._keysIn = None
        self._keysOut = None
        if plane is not None:
            self.makeBranch(plane, children)
        elif isinstance(children, Distribution):
            self.makeProbabilistic(children)
        else:
            self.makeLeaf(leaf)

    def _clearCache(self):
        """Forget the cached string and key sets of this node (call after any in-place restructuring)"""
        self._string = None
        self._keysIn = None
        self._keysOut = None

    def isLeaf(self):
        """@return: C{True} iff this node is a leaf"""
        return self.leaf

    def getLeaf(self):
        """@return: the value stored at this leaf node"""
        return self.children[None]

    def makeLeaf(self, leaf: Any):
        """Turn this node into a leaf holding the given value"""
        self.children = {None: leaf}
        self.leaf = True
        self.branch = None
        self._clearCache()

    def makeBranch(self, plane: 'KeyedPlane', children: Dict[Any, Any]):
        """
        Turn this node into a deterministic branch
        @param plane: the hyperplane tested at this node
        @param children: table from branch outcome to subtree (stored as-is, not copied)
        """
        self.children = children
        self.branch = plane
        self.leaf = False
        self._clearCache()

    def makeProbabilistic(self, distribution: 'Distribution'):
        """Turn this node into a probabilistic branch over the given distribution of subtrees (stored as-is)"""
        self.children = distribution
        self.branch = None
        self.leaf = False
        self._clearCache()

    def isProbabilistic(self):
        """
        @return: C{True} if there is a probabilistic branch at this node
        @rtype: bool
        """
        return self.branch is None and not self.isLeaf()

    def getKeysIn(self):
        """
        @return: a set of all keys that affect the output of this PWL function
        @note: the result is cached and the cached set itself is returned
        """
        if self._keysIn is None:
            keysIn = set()
            keysOut = set()
            if self.isProbabilistic():
                # Keys are taken from each child
                children = self.children.domain()
            else:
                children = self.children.values()
                if not self.isLeaf():
                    # Keys also include those in the branch
                    keysIn |= set(self.branch.keys())
            # Accumulate keys across children
            for child in children:
                if isinstance(child, KeyedVector):
                    keysIn |= set(child.keys())
                elif hasattr(child, 'getKeysIn'):
                    # Subtrees and PWL leaves (e.g., matrices); leaves such as None, bool,
                    # str or frozenset carry no keys and are skipped
                    keysIn |= child.getKeysIn()
                    keysOut |= child.getKeysOut()
            # Assign only once fully computed, so a failure never leaves a partial cache behind
            self._keysIn = keysIn
            self._keysOut = keysOut
        return self._keysIn

    def getKeysOut(self):
        """
        @return: a set of all keys that are affected by this PWL function
        """
        if self._keysOut is None:
            self.getKeysIn()
        return self._keysOut

    def keys(self):
        """@return: the union of the input and output keys of this PWL function"""
        return self.getKeysIn() | self.getKeysOut()

    def collapseProbabilistic(self):
        """
        Utility method that combines any consecutive probabilistic branches at this node into a
        single distribution (modifies this node in place)
        """
        if not self.isProbabilistic():
            return
        collapse = False
        distribution = Distribution(self.children)
        for child in self.children.domain():
            if child.isProbabilistic():
                # Probabilistic branch to merge
                collapse = True
                # Read the probability and remove the child *before* collapsing it, because
                # collapsing mutates the child (and therefore its hash) in place
                prob = self.children[child]
                del distribution[child]
                child.collapseProbabilistic()
                for grandchild in child.children.domain():
                    distribution.addProb(grandchild, prob*child.children[grandchild])
        if collapse:
            # Products of probabilities rarely sum to exactly 1. in floating point
            assert abs(sum(distribution.values()) - 1.) < 1e-8
            self.makeProbabilistic(distribution)

    def __getitem__(self, index):
        """
        Evaluate this tree on the given state
        @param index: the vector used to evaluate deterministic branches
        @return: the leaf reached, or a L{Distribution} over leaves when probabilistic branches are crossed
        @raise ValueError: if a deterministic branch has no child for the evaluated outcome
        """
        if self.isLeaf():
            return self.children[None]
        if self.branch is None:
            # Probabilistic branch: weight each child's result by the child's probability
            result = {}
            for element in self.children.domain():
                prob = self.children[element]
                subtree = element[index]
                if isinstance(subtree, Distribution):
                    weighted = [(sub, prob*subtree[sub]) for sub in subtree.domain()]
                else:
                    weighted = [(subtree, prob)]
                for outcome, weight in weighted:
                    if outcome in result:
                        result[outcome] += weight
                    else:
                        result[outcome] = weight
            return Distribution(result)
        # Deterministic branch
        subindex = self.branch.evaluate(index)
        try:
            child = self.children[subindex]
        except KeyError:
            logging.error('Missing child for case %s in tree:\n%s', subindex, self)
            raise ValueError('Missing child for case %s in tree' % (subindex))
        return child[index]

    def desymbolize(self, table, debug=False):
        """
        @return: a new tree with any symbolic references replaced with numeric values according to the table of element lists
        @rtype: L{KeyedTree}
        """
        tree = self.__class__()
        if self.isLeaf():
            leaf = self.children[None]
            if isinstance(leaf, (KeyedVector, KeyedMatrix)):
                tree.makeLeaf(leaf.desymbolize(table, debug))
            else:
                tree.makeLeaf(leaf)
        elif self.isProbabilistic():
            tree.makeProbabilistic(Distribution([(child.desymbolize(table, debug), prob)
                                                 for child, prob in self.children.items()]))
        else:
            tree.makeBranch(self.branch.desymbolize(table),
                            {value: child.desymbolize(table, debug) for value, child in self.children.items()})
        return tree

    def floor(self, key, lo):
        """
        Modify this tree to make sure the new computed value never goes lower than the given floor
        @warning: may introduce redundant checks
        @return: this tree (modified in place)
        """
        if self.isLeaf():
            tMatrix = self.getLeaf()
            assert len(tMatrix) == 1, 'Unable to handle dynamics of more than one feature'
            assert makeFuture(key) in tMatrix, "Are you sure you should be flooring me on a key I don't have?"
            fMatrix = setToConstantMatrix(key, lo)
            # If the computed value exceeds lo keep it, otherwise set it to lo
            branch = KeyedPlane(KeyedVector(tMatrix[makeFuture(key)]), lo)
            self.makeBranch(branch, {True: KeyedTree(tMatrix), False: KeyedTree(fMatrix)})
        elif self.isProbabilistic():
            for child in list(self.children.domain()):
                prob = self.children[child]
                # Remove before modifying, since the modification changes the child's hash
                del self.children[child]
                self.children.addProb(child.floor(key, lo), prob)
        else:
            for child in self.children.values():
                child.floor(key, lo)
        self._clearCache()
        return self

    def ceil(self, key, hi):
        """
        Modify this tree to make sure the new computed value never goes higher than the given ceiling
        @warning: may introduce redundant checks
        @return: this tree (modified in place)
        """
        if self.isLeaf():
            fMatrix = self.getLeaf()
            assert len(fMatrix) == 1, 'Unable to handle dynamics of more than one feature'
            assert makeFuture(key) in fMatrix, "Are you sure you should be ceiling me on a key I don't have?"
            tMatrix = setToConstantMatrix(key, hi)
            # If the computed value exceeds hi set it to hi, otherwise keep it
            branch = KeyedPlane(KeyedVector(fMatrix[makeFuture(key)]), hi)
            self.makeBranch(branch, {True: KeyedTree(tMatrix), False: KeyedTree(fMatrix)})
        elif self.isProbabilistic():
            for child in list(self.children.domain()):
                prob = self.children[child]
                # Remove before modifying, since the modification changes the child's hash
                del self.children[child]
                self.children.addProb(child.ceil(key, hi), prob)
        else:
            for child in self.children.values():
                child.ceil(key, hi)
        self._clearCache()
        return self

    def makeFuture(self, keyList=None):
        """Make references to the given keys (default: all keys) refer to their future versions (in place)"""
        self.changeTense(True, keyList)

    def makePresent(self, keyList=None):
        """Make references to the given keys (default: all keys) refer to their present versions (in place)"""
        self.changeTense(False, keyList)

    def changeTense(self, future=True, keyList=None):
        """
        Transforms this tree (in place) to refer to only future (or present) versions of its keys
        @param future: if C{True}, make references future; otherwise make them present
        @param keyList: If present, only references to these keys are changed
        """
        if keyList is None:
            keyList = self.keys()
        if self.isProbabilistic():
            for child in list(self.children.domain()):
                prob = self.children[child]
                # Remove before modifying, since the modification changes the child's hash
                del self.children[child]
                child.changeTense(future, keyList)
                # addProb merges children that became equal instead of overwriting one of them
                self.children.addProb(child, prob)
        else:
            if not self.isLeaf():
                self.branch.changeTense(future, keyList)
            for child in self.children.values():
                child.changeTense(future, keyList)
        # All three caches (including the output keys) are now stale
        self._clearCache()

    def scale(self, table):
        """@return: a new tree with C{scale(table)} applied to every branch plane and leaf"""
        tree = self.__class__()
        if self.isLeaf():
            tree.makeLeaf(self.children[None].scale(table))
        elif self.isProbabilistic():
            new = {}
            for child in self.children.domain():
                scaled = child.scale(table)
                # Accumulate, so that children that become equal do not lose probability
                new[scaled] = new.get(scaled, 0.) + self.children[child]
            tree.makeProbabilistic(Distribution(new))
        else:
            tree.makeBranch(self.branch.scale(table),
                            {value: child.scale(table) for value, child in self.children.items()})
        return tree

    def __eq__(self, other):
        if not isinstance(other, KeyedTree):
            return NotImplemented
        if self.isLeaf():
            if not other.isLeaf():
                return False
            return self.children[None] == other.children[None]
        if self.isProbabilistic():
            if not other.isProbabilistic():
                return False
            return self.children == other.children
        if self.branch == other.branch:
            return self.children == other.children
        return False

    def __add__(self, other):
        if not isinstance(other, KeyedTree):
            other = KeyedTree(other)
        return self.compose(other, lambda x, y: x+y)

    def __mul__(self, other):
        if isinstance(other, KeyedTree):
            return self.compose(other, lambda x, y: x*y, lambda x, y: x*y)
        if isinstance(other, KeyedVector):
            return self[other]*other
        if isinstance(other, (float, int)):
            return other*self
        raise NotImplementedError

    def __rmul__(self, other):
        if not isinstance(other, (float, int)):
            raise NotImplementedError
        tree = self.__class__()
        if self.isLeaf():
            tree.makeLeaf(other*self.children[None])
        elif self.isProbabilistic():
            dist = {}
            for child in self.children.domain():
                prod = other*child
                dist[prod] = dist.get(prod, 0.)+self.children[child]
            tree.makeProbabilistic(Distribution(dist))
        else:
            tree.makeBranch(self.branch,
                            {value: other*child for value, child in self.children.items()})
        return tree

    def max(self, other):
        """@return: a tree computing the maximum of this tree and the other"""
        return self.compose(other, self.__max)

    def __max(self, leaf1, leaf2):
        """
        Helper method for computing max
        @return: a tree returning the maximum of the two vectors
        @rtype: L{KeyedTree}
        """
        result = self.__class__()
        if leaf1 is False:
            result.graft(leaf2)
        elif leaf2 is False:
            result.graft(leaf1)
        else:
            if isinstance(leaf1, tuple):
                # (ER weights,action)
                weights = leaf1[0] - leaf2[0]
            else:
                # Assume vectors
                weights = leaf1 - leaf2
            weights.prune()
            if CONSTANT in weights:
                threshold = -weights[CONSTANT]
                del weights[CONSTANT]
            else:
                threshold = 0.
            if len(weights) == 0:
                if 0. > threshold:
                    # Must be true
                    result.graft(KeyedTree(leaf1))
                else:
                    # Must be false
                    result.graft(KeyedTree(leaf2))
            else:
                alpha = weights.normalize()
                result.makeBranch(KeyedPlane(weights, threshold*alpha),
                                  {True: KeyedTree(leaf1), False: KeyedTree(leaf2)})
        return result

    def compose(self, other, leafOp=None, planeOp=None):
        """
        Compose two trees into a single tree
        :param other: the other tree to be composed with
        :type other: L{KeyedTree}
        :param leafOp: the binary operator to apply to leaves of each tree to generate a new leaf
        :param planeOp: the binary operator to apply to the plane (only used when C{other} is a
            L{KeyedMatrix} leaf)
        :rtype: L{KeyedTree}
        """
        if other.isLeaf():
            if self.isLeaf():
                result = self.__class__()
                result.graft(leafOp(self.getLeaf(), other.getLeaf()))
                return result
            if self.isProbabilistic():
                return self._composeDistribution(self.children,
                                                 lambda old: old.compose(other, leafOp, planeOp))
            # Deterministic branch in this tree
            trees = {value: child.compose(other, leafOp, planeOp) for value, child in self.children.items()}
            return self._composeBranch(trees, lambda: self._composePlane(other.getLeaf(), planeOp))
        if other.isProbabilistic():
            return self._composeDistribution(other.children,
                                             lambda old: self.compose(old, leafOp, planeOp))
        # Deterministic branch in the other tree
        trees = {value: self.compose(child, leafOp, planeOp) for value, child in other.children.items()}
        return self._composeBranch(trees, lambda: other.branch)

    def _composeDistribution(self, dist, composeChild):
        """
        Helper for L{compose}: compose every element of a distribution and merge the results
        @param dist: distribution over subtrees
        @param composeChild: function mapping one subtree to its composed tree
        @return: a probabilistic tree, or the single resulting tree if all outcomes coincide
        """
        distribution = dist.__class__()
        for old in dist.domain():
            new = composeChild(old)
            if isinstance(new, Distribution):
                for tree in new.domain():
                    distribution.addProb(tree, dist[old]*new[tree])
            else:
                distribution.addProb(new, dist[old])
        result = self.__class__()
        if len(distribution) > 1:
            result.makeProbabilistic(distribution)
            result.collapseProbabilistic()
        else:
            result.graft(distribution.first())
        return result

    def _composeBranch(self, trees, makePlane):
        """
        Helper for L{compose} and L{map}: branch over the given subtrees, unless they are all equal
        @param trees: table from branch outcome to subtree
        @param makePlane: zero-argument function producing the plane (called only if a branch is needed)
        """
        result = self.__class__()
        protoTree = None
        for tree in trees.values():
            if protoTree is None:
                protoTree = tree
            elif tree != protoTree:
                # Not all subtrees are identical, so the branch is needed
                result.makeBranch(makePlane(), trees)
                break
        else:
            # All subtrees are the same, so the branch is unnecessary
            result.graft(protoTree)
        return result

    def _composePlane(self, leaf, planeOp):
        """
        Helper for L{compose}: the plane of this node after applying planeOp with the given leaf
        @return: this node's own branch if planeOp is missing or the leaf is not a L{KeyedMatrix}
        """
        if planeOp is None or not isinstance(leaf, KeyedMatrix):
            return self.branch
        plane = KeyedPlane([(planeOp(p, leaf), t, c) for p, t, c in self.branch.planes])
        plane.minimize()
        plane.isConjunction = self.branch.isConjunction
        return plane

    def replace(self, old, new):
        """
        @return: a new tree with the given substitution applied to all leaf nodes
        """
        return self.map(lambda leaf: new if leaf == old else leaf)

    def expectation(self):
        """
        @return: a new tree representing an expectation over any probabilistic branches
        """
        return self.map(distOp=lambda branch: branch.expectation())

    def map(self, leafOp=None, planeOp=None, distOp=None):
        """
        Generates a new tree applying a function to all planes and leaves
        @param leafOp: functional transformation of leaf nodes
        @type leafOp: lambda XS{->}X
        @param planeOp: functional transformation of hyperplanes
        @type planeOp: lambda XS{->}X
        @param distOp: functional transformation of probabilistic branches
        @type distOp: lambda L{Distribution}S{->}X
        @rtype: L{KeyedTree}
        """
        if self.isLeaf():
            result = self.__class__()
            leaf = self.children[None]
            result.graft(leafOp(leaf) if leafOp else leaf)
            return result
        if self.isProbabilistic():
            result = self.__class__()
            if distOp:
                result.graft(distOp(self.children))
                return result
            distribution = self.children.__class__()
            for old in self.children.domain():
                distribution.addProb(old.map(leafOp, planeOp, distOp), self.children[old])
            if len(distribution) > 1:
                result.makeProbabilistic(distribution)
            else:
                result.graft(distribution.first())
            return result
        # Deterministic branch
        branch = planeOp(self.branch) if planeOp else self.branch
        children = {value: child.map(leafOp, planeOp, distOp) for value, child in self.children.items()}
        return self._composeBranch(children, lambda: branch)

    def graft(self, root):
        """
        Grafts a tree at the current node
        @warning: clobbers anything currently at (or rooted at) this node; subtrees and
            distributions of C{root} are shared, not copied
        """
        if isinstance(root, Distribution):
            self.makeProbabilistic(root)
        elif isinstance(root, KeyedTree):
            if root.isLeaf():
                self.makeLeaf(root.children[None])
            elif root.isProbabilistic():
                self.makeProbabilistic(root.children)
            else:
                self.makeBranch(root.branch, dict(root.children))
        else:
            # Leaf node (not a very smart use of graft, but who are we to judge)
            self.makeLeaf(root)

    def sampleLeaf(self, vector, mostlikely=False):
        """
        :param mostlikely: if True, then only the most likely branches are chosen at each probabilistic branch
        :type mostlikely: bool
        :return: a leaf node sampled from the distribution over leaf nodes for the given vector,
            together with the probability of the probabilistic choices made to reach it
        """
        if self.isLeaf():
            return self.getLeaf(), 1
        if self.branch is None:
            # Probabilistic branch
            if mostlikely:
                subtree = self.children.max()
                subprob = self.children[subtree]
            else:
                subtree, subprob = self.children.sample()
            final_tree, final_prob = subtree.sampleLeaf(vector, mostlikely)
            return final_tree, subprob*final_prob
        # Deterministic branch
        return self.children[self.branch.evaluate(vector)].sampleLeaf(vector, mostlikely)

    def sample(self, mostlikely=False, vector=None):
        """
        :param mostlikely: if True, then only the most likely branches are chosen at each probabilistic branch
        :type mostlikely: bool
        :param vector: the possible world whose leaf node is returned (required)
        :type vector: KeyedVector
        :return: the result of L{sampleLeaf} for the given vector
        :raise ValueError: if no vector is provided
        """
        if vector is None:
            raise ValueError('Unable to sample from tree without a state vector, as the sample likelihood is too hard to compute')
        return self.sampleLeaf(vector, mostlikely)

    def prune(self, path=None, variables=None):
        """
        Removes redundant branches
        :param path: list of (branch, outcome) pairs already known to hold on the way to this node
        :param variables: variable table passed to C{KeyedPlane.possible}
        :warning: correct, but not necessarily complete
        :return: a new, pruned tree
        """
        if path is None:
            path = []
        if variables is None:
            variables = {}
        result = self.__class__()
        if self.isLeaf():
            # Leaves are unchanged
            result.makeLeaf(self.children[None])
            return result
        if self.isProbabilistic():
            # Distributions are passed through, flattening nested probabilistic children
            distribution = self.children.__class__()
            for tree in self.children.domain():
                prob = self.children[tree]
                child = tree.prune(path, variables)
                if child.isProbabilistic():
                    for grandchild in child.children.domain():
                        distribution.addProb(grandchild, prob*child.children[grandchild])
                else:
                    distribution.addProb(child, prob)
            if len(distribution) == 1:
                result.graft(distribution.first())
            else:
                result.makeProbabilistic(distribution)
            return result
        # Deterministic branch
        if variables:
            poss = self.branch.possible(variables)
            if poss is not None and len(poss) == 1:
                # Only one outcome is possible, so the branch is redundant
                result.graft(self.children[poss[0]].prune(path, variables))
                return result
            # With several possible outcomes every child is kept (sound, if not minimal)
        for branch, value in path:
            conflict = self.branch.compare(branch, value)
            if conflict is not None:
                # Outcome is already determined by a branch higher up
                result.graft(self.children[conflict].prune(path, variables))
                return result
        for weights, threshold, comparison in self.branch.planes:
            if len(weights) == 1 and CONSTANT in weights:
                # Plane depends only on the constant, so its outcome is fixed
                value = self.branch.evaluate(weights[CONSTANT])
                result.graft(self.children[value].prune(path, variables))
                return result
        result.makeBranch(self.branch,
                          {value: child.prune(path+[(self.branch, value)], variables)
                           for value, child in self.children.items()})
        return result

    def minimizePlanes(self):
        """
        Modifies tree in place so that there are no constant factors in branch weights
        """
        if self.isProbabilistic():
            for child in self.children.domain():
                child.minimizePlanes()
        elif not self.isLeaf():
            self.branch.minimize()
            # Works for binary and multi-valued branches alike
            for child in self.children.values():
                child.minimizePlanes()
        self._clearCache()

    def leaves(self):
        """
        :return: list of all leaf values, in child order
        :warning: May return a list containing duplicates
        """
        if self.isLeaf():
            return [self.getLeaf()]
        children = self.children.domain() if self.isProbabilistic() else self.children.values()
        return [leaf for child in children for leaf in child.leaves()]

    def __hash__(self):
        # Order-independent, so that trees equal under __eq__ (which compares unordered
        # tables) also hash equally
        return hash(frozenset(self.children.items()))

    def __str__(self):
        if self._string is None:
            if self.isLeaf():
                self._string = str(self.children[None])
            elif self.isProbabilistic():
                # Probabilistic branch
                self._string = '\n'.join([f'{int(100*prob)}%: {value}' for value, prob in self.children.items()])
            else:
                # Deterministic branch
                plane = self.branch.planes[0]
                if len(self.branch.planes) == 1 and isinstance(plane[1], list):
                    # Multi-threshold plane: label each child by its threshold
                    thresholds = plane[1][:]
                    if plane[2] < 0:
                        thresholds.append(1.)
                    elif plane[2] > 0:
                        thresholds.insert(0, 0)
                    children = '\n'.join(['%s\t%s' % (thresholds[value] if isinstance(value, int) else 'Otherwise',
                                                      str(child).replace('\n', '\n\t'))
                                          for value, child in self.children.items()])
                else:
                    children = '\n'.join(['%s\t%s' % (value, str(child).replace('\n', '\n\t'))
                                          for value, child in self.children.items()])
                self._string = 'if %s\n%s' % (str(self.branch), children)
        return self._string

    def __xml__(self):
        """@return: a DOM document whose root C{tree} element can be read back by L{parse}"""
        from xml.dom.minidom import Document
        doc = Document()
        root = doc.createElement('tree')
        if self.branch is not None:
            root.appendChild(self.branch.__xml__().documentElement)
        if isinstance(self.children, Distribution):
            root.appendChild(self.children.__xml__().documentElement)
        else:
            for key, value in self.children.items():
                # bool must be tested before int, since bool is a subclass of int
                if isinstance(value, bool):
                    node = doc.createElement('bool')
                    node.setAttribute('value', str(value))
                elif isinstance(value, str):
                    node = doc.createElement('str')
                    node.appendChild(doc.createTextNode(value))
                elif isinstance(value, int):
                    node = doc.createElement('int')
                    node.appendChild(doc.createTextNode(str(value)))
                elif value is None:
                    node = doc.createElement('none')
                else:
                    node = value.__xml__().documentElement
                node.setAttribute('key', str(key))
                root.appendChild(node)
        doc.appendChild(root)
        return doc

    def parse(self, element):
        """
        Populate this node from a C{tree} DOM element, as written by L{__xml__}
        @warning: keys and boolean values are read with C{eval}, so only parse trusted XML
        """
        assert element.tagName == 'tree'
        plane = None
        children = {}
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                tag = node.tagName
                if tag == 'vector':
                    if node.getAttribute('key'):
                        # Vector leaf
                        children[eval(node.getAttribute('key'))] = KeyedVector(node)
                    else:
                        # Branch
                        plane = KeyedPlane(node)
                elif tag == 'plane':
                    plane = KeyedPlane(node)
                elif tag == 'matrix':
                    children[eval(node.getAttribute('key'))] = KeyedMatrix(node)
                elif tag == 'tree':
                    # The constructor does not parse DOM nodes, so parse the subtree explicitly
                    subtree = self.__class__()
                    subtree.parse(node)
                    children[eval(node.getAttribute('key'))] = subtree
                elif tag == 'distribution':
                    children = Distribution(node)
                elif tag == 'bool':
                    children[eval(node.getAttribute('key'))] = eval(node.getAttribute('value'))
                elif tag == 'action':
                    children[eval(node.getAttribute('key'))] = Action(node)
                elif tag == 'str':
                    # An empty string is serialized without any text node
                    text = node.firstChild.data if node.firstChild else ''
                    children[eval(node.getAttribute('key'))] = str(text).strip()
                elif tag == 'int':
                    # __xml__ writes int leaves as <int key="...">value</int>
                    children[eval(node.getAttribute('key'))] = int(node.firstChild.data)
                elif tag == 'none':
                    children[eval(node.getAttribute('key'))] = None
            node = node.nextSibling
        if plane is not None:
            self.makeBranch(plane, children)
        elif isinstance(children, Distribution):
            self.makeProbabilistic(children)
        else:
            self.makeLeaf(children[None])
def _makeProbabilisticTree(pairs, normalize):
    """
    Helper for L{makeTree}: build a probabilistic node from (subtable, probability) pairs
    Subtables that yield equal trees have their probabilities summed rather than overwritten.
    """
    branch = {}
    for subtable, prob in pairs:
        subtree = makeTree(subtable, normalize)
        if subtree in branch:
            branch[subtree] += prob
        else:
            branch[subtree] = prob
    dist = Distribution(branch)
    if normalize:
        dist.normalize()
    tree = KeyedTree()
    tree.makeProbabilistic(dist)
    return tree


def makeTree(table, normalize=True):
    """
    Construct a L{KeyedTree} from a (possibly nested) table specification
      - C{None}, C{bool}, numbers, C{str}, C{frozenset} (e.g., ActionSet for a policy) or
        L{KeyedMatrix}: a leaf holding that value
      - L{KeyedTree}: returned unchanged
      - L{Distribution} over subtables: a probabilistic branch
      - C{{'if': plane, outcome: subtable, ...}}: a deterministic branch on the plane
      - C{{'case': key, value: subtable, ..., 'otherwise': subtable}}: a multi-valued branch
        built with L{equalRow} on the case key and the listed values
      - C{{'distribution': [(subtable, prob), ...]}}: a probabilistic branch
      - anything else: a leaf holding that value
    @param normalize: if C{True}, probabilistic branches are normalized
    @rtype: L{KeyedTree}
    @note: the caller's table and plane objects are not modified
    """
    if table is None or isinstance(table, (bool, int, float, str, frozenset)):
        # Simple leaf (numbers would otherwise crash on the "in" tests below)
        return KeyedTree(table)
    if isinstance(table, KeyedTree):
        return table
    if isinstance(table, KeyedMatrix):
        return KeyedTree(table)
    if isinstance(table, Distribution):
        return _makeProbabilisticTree(table.items(), normalize)
    if 'if' in table:
        # Binary deterministic branch
        plane = table['if']
        children = {key: makeTree(subtable, normalize) for key, subtable in table.items() if key != 'if'}
        if not plane.isConjunction and len(plane.planes) == 2:
            (w0, t0, c0), (w1, t1, c1) = plane.planes
            if w0 == w1 and t0 == t1 and c0 == -c1:
                # Not equal branch, let's just compact it into a single equal branch
                # (on a new plane, so that the caller's plane can safely be reused)
                children = {True: children[False], False: children[True]}
                plane = KeyedPlane([(w0, t0, 0)])
                plane.isConjunction = True
        tree = KeyedTree()
        tree.makeBranch(plane, children)
        return tree
    if 'case' in table:
        # Non-binary deterministic branch, rewritten as an 'if' table on an equality row
        keys = [key for key in table if key != 'case' and key != 'otherwise']
        spec = {}
        if 'otherwise' in table:
            spec[None] = table['otherwise']
        spec['if'] = equalRow(table['case'], keys)
        for index, key in enumerate(keys):
            spec[index] = table[key]
        return makeTree(spec, normalize)
    if 'distribution' in table:
        # Probabilistic branch
        return _makeProbabilisticTree(table['distribution'], normalize)
    # Leaf
    return KeyedTree(table)
def collapseDynamics(tree, effects, variables=None):
    """
    Fold a sequence of dynamics stages into the given tree.
    The stages are applied in reverse order. For each stage, the dynamics of every key that
    the (future-tensed) tree reads are summed and multiplied into the tree.
    @param tree: the tree to collapse dynamics into (left unmodified)
    @type tree: L{KeyedTree}
    @param effects: list of stages; each stage is assumed (from its use here) to map a state
        key to a list holding at most one dynamics tree
    @param variables: variable table passed on to L{KeyedTree.prune}
    @return: a new, pruned tree referring to present-tense keys
    @rtype: L{KeyedTree}
    """
    import copy
    if variables is None:
        variables = {}
    # Changing tense works in place (on the tree and its leaves), so work on a private copy
    tree = copy.deepcopy(tree)
    tree.makeFuture(tree.getKeysIn())
    # Iterate backwards without reversing the caller's list in place
    for stage in reversed(effects):
        subtree = None
        for key, dynamics in stage.items():
            if dynamics and makeFuture(key) in tree.getKeysIn():
                assert len(dynamics) == 1
                subtree = dynamics[0] if subtree is None else subtree + dynamics[0]
        if subtree is not None:
            for key in tree.getKeysIn():
                if key not in subtree.getKeysOut():
                    # Add a row mapping key to itself (coefficient 1.) to every leaf matrix
                    subtree = subtree.map(lambda m, key=key: KeyedMatrix(list(m.items())+[(key, KeyedVector({key: 1.}))]))
            tree = tree*subtree
    future = [key for key in tree.getKeysIn() if isFuture(key)]
    if future:
        tree.makePresent(future)
    return tree.prune(variables=variables)