from collections import OrderedDict
import copy
import heapq
import itertools
import logging
import math
from psychsim.probability import Distribution
from . import keys
from psychsim.pwl.vector import KeyedVector, VectorDistribution
from psychsim.pwl.matrix import KeyedMatrix
from psychsim.pwl.tree import KeyedTree
class VectorDistributionSet:
    """
    Represents a distribution over independent state vectors, i.e., independent L{VectorDistribution}s

    The state is stored in three parallel structures:

    - C{distributions}: maps an integer substate id to a L{VectorDistribution} over keys that are
      (potentially) correlated with one another.
    - C{certain}: maps each key whose value is known with certainty to that value.
    - C{keyMap}: maps every key to the id of the substate holding it, or to C{None} when the key is
      stored in C{certain}. By the convention followed in this class, the C{keys.CONSTANT} column
      (present in every vector) is not tracked in C{keyMap}.

    The probability of a joint world is the product of the probabilities of its pieces in each substate.
    """

    def __init__(self, start=None):
        """
        :param start: optional initial contents: either a dict mapping each key to its value (a single
            value or a L{Distribution}), with each key placed in its own substate; or a single
            L{VectorDistribution}, which becomes substate 0
        :raises TypeError: if C{start} is of any other type
        """
        self.distributions = {}
        self.certain = {}
        self.keyMap = {}
        if start is None:
            return
        if isinstance(start, dict):
            for substate, (key, value) in enumerate(start.items()):
                self.join(key, value, substate)
        elif isinstance(start, VectorDistribution):
            self.distributions[0] = start
            # CONSTANT is implicit in every vector, so it is not tracked in keyMap
            self.keyMap = {k: 0 for k in start.keys() if k != keys.CONSTANT}
        else:
            raise TypeError(f'Unknown argument type for initial value: {start.__class__.__name__}')

    def _free_substate(self):
        """
        :return: the smallest non-negative integer not already used as a substate id
        :rtype: int
        """
        substate = 0
        while substate in self.distributions:
            substate += 1
        return substate

    def add_distribution(self, dist):
        """
        :param dist: The sub-distribution to add to this state
        :type dist: VectorDistribution
        :warning: Keys in new distribution must not already exist in current distribution
        :return: the index into the newly added distribution
        :raises ValueError: if any (non-constant) key of C{dist} already exists in this state
        """
        # CONSTANT appears in every distribution, so it can never be a conflict
        new_keys = [key for key in dist.keys() if key != keys.CONSTANT]
        for key in new_keys:
            if key in self:
                raise ValueError(f'Cannot add distribution containing {key}, because it already exists')
        substate = self._free_substate()
        for key in new_keys:
            self.keyMap[key] = substate
        self.distributions[substate] = dist
        return substate

    def keys(self):
        """
        :return: a view of every key tracked by this state (certain and uncertain)
        """
        return self.keyMap.keys()

    def __contains__(self, key):
        return key in self.keyMap

    def __iter__(self):
        """
        Iterate through elements of this set, with each element being a L{VectorDistributionSet} (with probability not necessarily 1)

        Each yielded element shares this state's certain values and key map, and holds a single-vector
        distribution in each substate, carrying that vector's original probability.
        """
        substates = sorted(self.distributions.keys())
        domains = {substate: self.distributions[substate].domain() for substate in substates}
        size = 1
        for substate in substates:
            size *= len(domains[substate])
        for index in range(size):
            value = self.__class__()
            value.certain.update(self.certain)
            value.keyMap.update(self.keyMap)
            # Decode index as a mixed-radix number, one digit per substate
            for substate in substates:
                domain = domains[substate]
                element = domain[index % len(domain)]
                index //= len(domain)
                prob = self.distributions[substate][element]
                value.distributions[substate] = VectorDistribution({element: prob})
            yield value

    def probability(self, vector=None):
        """
        :param vector: an optional (possibly partial) world; the C{keys.CONSTANT} column is ignored
        :type vector: KeyedVector
        :return: the probability of the given world according to this state distribution. If no world
            is given, then return the probability of the overall state (the product of the total mass
            of each subdistribution)
        :raises KeyError: if the vector mentions a key not in this state
        """
        prob = 1
        if vector is None:
            for distribution in self.distributions.values():
                prob *= distribution.probability()
            return prob
        # Group the requested values by the substate holding them, so correlated keys are
        # evaluated jointly rather than as a product of marginals
        by_substate = {}
        for key, value in vector.items():
            if key == keys.CONSTANT:
                continue
            substate = self.keyMap[key]
            if substate is None:
                if self.certain[key] != value:
                    # Contradicts a value known with certainty
                    return 0
            else:
                by_substate.setdefault(substate, {})[key] = value
        for substate, values in by_substate.items():
            prob *= sum(p for v, p in self.distributions[substate].items()
                        if all(v[k] == val for k, val in values.items()))
        return prob

    def __len__(self):
        """
        :return: the number of elements in the implied joint distribution
        :rtype: int
        """
        prod = 1
        for dist in self.distributions.values():
            prod *= len(dist)
        return prod

    def __getitem__(self, key):
        """
        :return: the marginal distribution over the given key's values
        """
        return self.marginal(key)

    def __setitem__(self, key, value):
        """
        Computes a conditional probability of this distribution given the value for this key. To do so, it removes any elements from the distribution that are inconsistent with the given value and then normalizes.
        .. warning:: If you want to overwrite any existing values for this key use L{join} (which computes a new joint probability)
        """
        self.setitem(key, value)

    def setitem(self, key, value):
        """
        Conditions (in place) this state on C{key} having the given value, removing inconsistent
        vectors (compared with C{math.isclose}) and renormalizing.

        :return: the probability mass of the vectors that were consistent with the value (1 for a
            key whose value is already certain)
        :raises ValueError: if no vector is consistent with the value; the state is left unchanged
        """
        substate = self.keyMap[key]
        if substate is None:
            if self.certain[key] != value:
                raise ValueError(f'P({key}={value}) = 0 because {key}={self.certain[key]}')
            return 1
        dist = self.distributions[substate]
        items = dist._Distribution__items
        consistent = [item for item in items if math.isclose(item[0][key], value)]
        if not consistent:
            raise ValueError(f'P({key}={value}) = 0')
        prob = sum(item[1] for item in consistent)
        # Mutate the distribution's own item list only once we know the result is non-empty
        items[:] = consistent
        dist.normalize()
        return prob

    def subDistribution(self, key):
        raise DeprecationWarning('Use "subdistribution" instead.')

    def subdistribution(self, key):
        """
        :return: the minimal joint distribution containing this key (a new single-vector
            distribution if the key is certain; otherwise the stored distribution itself)
        """
        substate = self.keyMap[key]
        if substate is None:
            return VectorDistribution({KeyedVector({key: self.certain[key]}): 1})
        return self.distributions[substate]

    def __delitem__(self, key):
        """
        Removes the given column from its corresponding vector (raises KeyError if not present in this distribution)
        """
        substate = self.keyMap[key]
        del self.keyMap[key]
        if substate is None:
            del self.certain[key]
        else:
            dist = self.distributions[substate]
            if len(dist.first()) <= 2:
                # Assume CONSTANT is the other key, so this whole distribution goes
                del self.distributions[substate]
            else:
                # Go through each vector and remove the key
                for vector in dist.domain():
                    prob = dist[vector]
                    del dist[vector]
                    del vector[key]
                    dist.addProb(vector, prob)

    def deleteKeys(self, toDelete):
        """
        Removes multiple columns at once

        Subdistributions left with only the C{keys.CONSTANT} column are deleted entirely; the others
        are rebuilt from their reduced vectors.

        :raises KeyError: if any key is not in this state
        """
        # Snapshot (vector, prob) pairs per affected substate before any mutation
        pending = {}
        for key in toDelete:
            substate = self.keyMap.pop(key)
            if substate is None:
                del self.certain[key]
                continue
            if substate not in pending:
                pending[substate] = list(self.distributions[substate].items())
            for vector, _ in pending[substate]:
                del vector[key]
        for substate, pairs in pending.items():
            if len(pairs[0][0]) == 1:
                assert next(iter(pairs[0][0].keys())) == keys.CONSTANT
                del self.distributions[substate]
            else:
                distribution = self.distributions[substate]
                distribution.clear()
                for vector, prob in pairs:
                    distribution.addProb(vector, prob)

    def split(self, key):
        """
        :return: partitions this distribution into subsets corresponding to possible values for the
            given key. Each subset is an independent copy of this state whose subdistribution for the
            key contains only the vectors with that value, at their original (unnormalized)
            probabilities. A certain key yields a single partition holding a copy of the whole state.
        :rtype: dict(str,VectorDistributionSet)
        """
        destination = self.keyMap[key]
        if destination is None:
            return {self.certain[key]: copy.deepcopy(self)}
        original = self.distributions[destination]
        result = {}
        for vector in original.domain():
            value = vector[key]
            if value not in result:
                # Copy everything from me except the distribution of the given key
                part = self.__class__()
                part.keyMap.update(self.keyMap)
                part.certain.update(self.certain)
                for substate, distribution in self.distributions.items():
                    if substate != destination:
                        part.distributions[substate] = copy.deepcopy(distribution)
                part.distributions[destination] = VectorDistribution()
                result[value] = part
            result[value].distributions[destination][vector.__class__(vector)] = original[vector]
        return result

    def collapse(self, substates, preserve_certainty=True):
        """
        Collapses (in place) the given substates into a single joint L{VectorDistribution}

        :param substates: substate ids, or keys (strings) whose substates should be collapsed
        :param preserve_certainty: if C{True}, only substates with more than one vector are merged
        :return: the id of the merged substate (C{None} if nothing was left to merge)
        :raises ValueError: if C{substates} is empty
        """
        if len(substates) == 0:
            raise ValueError('No substates to collapse')
        if isinstance(next(iter(substates)), str):
            # Why not handle keys, too?
            substates = self.substate(substates)
        if preserve_certainty:
            substates = {s for s in substates
                         if s is not None and len(self.distributions[s]) > 1}
        return self.merge(substates)

    def uncertain(self):
        """
        :return: C{True} iff this distribution has any uncertainty about the vector
        :rtype: bool
        """
        return sum(map(len, self.distributions.values())) > len(self.distributions)

    def findUncertainty(self, substates=None):
        """
        :param substates: Consider only the given substates as candidates
        :return: a substate containing an uncertain distribution if one exists; otherwise, None
        :rtype: int
        """
        if substates is None:
            substates = self.distributions.keys()
        for substate in substates:
            if substate is not None and len(self.distributions[substate]) > 1:
                return substate
        return None

    def make_certain(self):
        """
        Moves the contents of every single-vector subdistribution into C{certain} (in place) and
        deletes that subdistribution.
        """
        for substate, dist in list(self.distributions.items()):
            if len(dist) == 1:
                # 100% probability for one vector
                vector = dist.first()
                for var, value in vector.items():
                    if var != keys.CONSTANT:
                        self.keyMap[var] = None
                        self.certain[var] = value
                del self.distributions[substate]

    def vector(self):
        """
        :return: if this distribution contains only a single vector, return that vector (including
            all certain values); otherwise, throw exception
        :rtype: KeyedVector
        """
        vector = KeyedVector(dict(self.certain))
        for distribution in self.distributions.values():
            assert len(distribution) == 1, 'Cannot return vector from uncertain distribution'
            vector.update(distribution.domain()[0])
        return vector

    def worlds(self):
        """
        :return: iterator through all possible joint vectors (i.e., possible worlds) and their
            probabilities; each world includes the certain values
        :rtype: KeyedVector,float
        """
        # Convert to lists now to ensure same ordering throughout
        substates = list(self.distributions.keys())
        domains = {substate: list(self.distributions[substate].domain()) for substate in substates}
        for index in range(len(self)):
            vector = dict(self.certain)
            prob = 1
            # Decode index as a mixed-radix number, first substate varying fastest
            for substate in substates:
                domain = domains[substate]
                subvector = domain[index % len(domain)]
                vector.update(subvector)
                prob *= self.distributions[substate][subvector]
                index //= len(domain)
            yield KeyedVector(vector), prob

    def select(self, maximize=False, incremental=False):
        """
        Reduce distribution to a single element, sampled according to the given distribution
        :param incremental: if C{True}, then select each key value in series (rather than picking out a joint vector all at once, default is C{False})
        :return: the probability of the selection made
        """
        if incremental:
            prob = KeyedVector()
        else:
            prob = 1
        for distribution in self.distributions.values():
            if incremental:
                prob.update(distribution.select(maximize, incremental))
            else:
                prob *= distribution.select(maximize, incremental)
        return prob

    def substate(self, obj, ignoreCertain=False):
        """
        :param obj: a collection of keys (C{keys.CONSTANT} is skipped)
        :param ignoreCertain: if C{True}, omit keys that are certain or in single-vector subdistributions
        :return: the substate referred to by all of the keys in the given object (C{None} stands for
            keys with certain values)
        :rtype: set
        """
        if isinstance(obj, bool):
            raise DeprecationWarning('If you really need this, please inform management.')
        if ignoreCertain:
            return {self.keyMap[k] for k in obj if k != keys.CONSTANT and self.keyMap[k] is not None and len(self.distributions[self.keyMap[k]]) > 1}
        return {self.keyMap[k] for k in obj if k != keys.CONSTANT}

    def merge(self, substates):
        """
        Merges (in place) the given substates into the first non-C{None} one of them. C{None} entries
        (certain keys) and repeated ids are skipped.

        :return: the substate into which they've all been merged (C{None} if there was none)
        """
        destination = None
        for substate in substates:
            if substate is None or substate == destination:
                continue
            if destination is None:
                destination = substate
            else:
                dist = self.distributions[substate]
                self.distributions[destination].merge(dist, True)
                del self.distributions[substate]
                for key in dist.keys():
                    if key != keys.CONSTANT:
                        self.keyMap[key] = destination
        return destination

    def join(self, key, value, substate=None):
        """
        Modifies the distribution over vectors to have the given value for the given key
        :param key: the key to the column to modify
        :type key: str
        :param value: either a single value to apply to all vectors, or else a L{Distribution} over possible values
        :substate: name of substate vector distribution to join with, ignored if the key already exists in this state. By default, find a new substate
        """
        if key in self.keyMap:
            assert substate is None, f'Cannot join {key} to distribution {substate} as it already exists in distribution {self.keyMap[key]}'
            substate = self.keyMap[key]
            if substate is None:
                del self.certain[key]
            else:
                self.distributions[substate].delete_column(key)
        if substate is None:
            if isinstance(value, Distribution):
                substate = self._free_substate()
                self.keyMap[key] = substate
            else:
                # Certain value
                self.keyMap[key] = substate = None
        else:
            self.keyMap[key] = substate
        if substate is None:
            self.certain[key] = value
        else:
            if substate not in self.distributions:
                self.distributions[substate] = VectorDistribution([(KeyedVector({keys.CONSTANT: 1}), 1)])
            self.distributions[substate].join(key, value)

    def marginal(self, key):
        """
        :return: the marginal L{Distribution} over the given key's values
        """
        substate = self.keyMap[key]
        if substate is None:
            return Distribution({self.certain[key]: 1})
        return self.distributions[substate].marginal(key)

    def domain(self, key):
        """
        :param key: a single key, or a list of keys
        :return: for a single key, its possible values (a list for a certain key, otherwise a set);
            for a list of keys, a list of value combinations (one list per combination, in key order)
        :raises RuntimeError: if two uncertain keys in the list share a joint distribution
        """
        if isinstance(key, str):
            substate = self.keyMap[key]
            if substate is None:
                return [self.certain[key]]
            return {v[key] for v in self.distributions[substate].domain()}
        elif isinstance(key, list):
            # One list of (single-value) candidates per key, in the order the keys were given
            domains = []
            seen = set()
            for subkey in key:
                loc = self.keyMap[subkey]
                if loc is None:
                    # Certain keys are independent of everything else
                    domains.append([[self.certain[subkey]]])
                elif loc in seen:
                    raise RuntimeError('Currently unable to compute domains over interdependent state features')
                else:
                    seen.add(loc)
                    domains.append([[vector[subkey]] for vector in self.distributions[loc].domain()])
            return [sum(combo, []) for combo in itertools.product(*domains)]
        else:
            raise NotImplementedError

    def replace(self, substitution, key=None):
        """
        Replaces column values, either across all columns, or only for the specified column

        NOTE(review): values in C{certain} are not touched, and a certain C{key} raises KeyError — confirm intended.
        """
        if key is None:
            for dist in self.distributions.values():
                dist.replace(substitution)
        else:
            self.distributions[self.keyMap[key]].replace(substitution, key)

    def items(self):
        """
        :return: the (substate id, L{VectorDistribution}) pairs of the uncertain part of the state
        """
        return self.distributions.items()

    def clear(self):
        """
        Removes every key, certain or uncertain, from this state
        """
        self.distributions.clear()
        self.keyMap.clear()
        self.certain.clear()

    def normalize(self):
        """
        Normalizes each subdistribution in place
        """
        for dist in self.distributions.values():
            dist.normalize()

    def prune(self, threshold):
        raise DeprecationWarning('Use prune_probability or prune_size')

    def prune_probability(self, threshold: float) -> float:
        """
        Prunes each subdistribution with the given probability threshold

        :return: the product of the values returned by each subdistribution's C{prune_probability}
        """
        prob = 1
        for dist in self.distributions.values():
            prob *= dist.prune_probability(threshold)
        return prob

    def prune_size(self, k: int = 1) -> float:
        """
        Prunes (in place) the joint distribution down to (at most) its C{k} most probable worlds,
        found greedily. When pruning happens, all uncertain subdistributions are merged into one
        joint subdistribution holding the retained worlds (without renormalization).

        :param k: the maximum number of joint worlds to keep
        :return: the probability of the overall state afterward (the retained mass)
        """
        if len(self) > k:
            # List of substate/distribution pairs for uncertain subdistributions
            dist_list = [dist_tup for dist_tup in self.distributions.items() if len(dist_tup[1]) > 1]
            if not dist_list:
                return self.probability()
            # Sort each subdistribution so element 0 is its most probable. Sorting on probability
            # alone (a stable sort) avoids comparing vectors when probabilities tie.
            for dist_tup in dist_list:
                dist_tup[1]._Distribution__items.sort(key=lambda tup: tup[1], reverse=True)
            max_list = [dist_tup[1]._Distribution__items[0] for dist_tup in dist_list]
            # Probability of most probable combination
            max_prob = math.prod([tup[1] for tup in max_list])
            # Min-heap of (probability, element index per subdistribution) for the worlds kept so far
            worlds = []
            heapq.heappush(worlds, (max_prob, [0 for _ in max_list]))
            for i, dist_tup in enumerate(dist_list):
                items = dist_tup[1]._Distribution__items
                current_worlds = worlds[:]
                for element, (_, prob) in enumerate(items[1:], start=1):
                    # This is not the max item (i.e., it is a viable pruning candidate)
                    change = False
                    for other_prob, other_world in current_worlds:
                        if other_world[i] != element:
                            # What if we insert this element to this currently max-k world?
                            new_prob = other_prob*prob/items[other_world[i]][1]
                            if len(worlds) < k:
                                # Come on in, there's plenty of room
                                new_world = other_world[:]
                                new_world[i] = element
                                heapq.heappush(worlds, (new_prob, new_world))
                                change = True
                            elif new_prob > worlds[0][0]:
                                # This new world is better than the least likely world currently active
                                new_world = other_world[:]
                                new_world[i] = element
                                heapq.heapreplace(worlds, (new_prob, new_world))
                                change = True
                    if not change:
                        # Worlds only going to get less likely from here
                        break
            # Fold all uncertain subdistributions into the first one
            target = dist_list[0][0]
            for sub, _ in dist_list[1:]:
                del self.distributions[sub]
            items = []
            remapped = False
            for prob, world in worlds:
                vector = None
                for i, element in enumerate(world):
                    subvector = dist_list[i][1]._Distribution__items[element][0]
                    if i == 0:
                        vector = subvector.__class__(subvector)
                    else:
                        vector.update(subvector)
                if not remapped:
                    # Every world has the same keys, so one pass over them suffices
                    for key in vector.keys():
                        if key != keys.CONSTANT:
                            self.keyMap[key] = target
                    remapped = True
                items.append((vector, prob))
            self.distributions[target]._Distribution__items = items
        return self.probability()

    def update(self, other, keySet, scale=1):
        """
        Adds (in place) the probability mass of another state over the given keys into this state.

        The keys in C{keySet}, plus any key sharing a joint distribution with one of them (in either
        state) that is uncertain or has a different marginal in the two states, are collapsed into a
        single joint distribution in each state. Each vector of C{other}'s joint distribution is then
        added to this state's joint distribution with its probability multiplied by C{scale}.

        :param other: the state whose probability mass is added (its substates are merged in place)
        :type other: VectorDistributionSet
        :param keySet: the keys to merge
        :param scale: multiplier applied to the probabilities of C{other}'s vectors
        :return: the substate of this state that now holds the merged keys, or C{None} if C{keySet} is empty
        :raises ValueError: if none of the merged keys has a distribution in C{other}
        """
        def shares_joint(state, key, candidate):
            # Certain keys (substate None) are independent of everything, so never share a joint
            substate = state.keyMap.get(key)
            return substate is not None and state.keyMap.get(candidate) == substate

        def is_uncertain(state, key):
            substate = state.keyMap.get(key)
            return substate is not None and len(state.distributions[substate]) > 1

        # Anyone else mixed up in this?
        toMerge = set(keySet)
        for key in keySet:
            for newKey in self.keyMap:
                if shares_joint(self, key, newKey) or shares_joint(other, key, newKey):
                    if is_uncertain(self, newKey) or is_uncertain(other, newKey) or \
                            self.marginal(newKey) != other.marginal(newKey):
                        toMerge.add(newKey)
        if not toMerge:
            # No difference between self and other to begin with
            return None
        # Prepare myself to merge
        mine = {self.keyMap[k] for k in toMerge if k in self.keyMap} - {None}
        destination = self.merge(mine) if mine else None
        if destination is None:
            destination = max(self.distributions, default=-1) + 1
            self.distributions[destination] = VectorDistribution()
        target = self.distributions[destination]
        # Align the other
        yours = {other.keyMap[k] for k in toMerge} - {None}
        if not yours:
            raise ValueError('None of the keys to merge has a distribution in the other state')
        dist = other.distributions[other.merge(yours)]
        incoming = {k for k in dist.keys() if k != keys.CONSTANT}
        # Keys certain in me move into the joint distribution, so my vectors must carry their values
        for k in toMerge | incoming:
            if k in self.keyMap and self.keyMap[k] is None:
                value = self.certain.pop(k)
                if len(target) > 0:
                    target.join(k, value)
                self.keyMap[k] = destination
        # Keys certain in the other state are absent from its vectors, so fill them in
        extra = {k: other.certain[k] for k in toMerge if other.keyMap[k] is None}
        for vector in dist.domain():
            prob = dist[vector]
            if extra:
                vector = vector.__class__(vector)
                vector.update(extra)
            target.addProb(vector, prob*scale)
            for k in vector.keys():
                if k != keys.CONSTANT:
                    self.keyMap[k] = destination
        return destination

    def __add__(self, other):
        """
        NOTE(review): the result does not copy C{certain}, so keys mapped to C{None} have no value in
        it — confirm how certain values should combine before relying on this with certain keys.
        """
        if isinstance(other, self.__class__):
            assert self.keyMap == other.keyMap, 'Currently unable to add distributions with mismatched substates'
            result = self.__class__()
            result.keyMap.update(self.keyMap)
            for substate, value in self.distributions.items():
                result.distributions[substate] = value + other.distributions[substate]
            return result
        raise NotImplementedError

    def __sub__(self, other):
        """
        NOTE(review): like L{__add__}, the result does not copy C{certain} — confirm intended.
        """
        if isinstance(other, self.__class__):
            assert self.keyMap == other.keyMap, 'Currently unable to subtract distributions with mismatched substates'
            result = self.__class__()
            result.keyMap.update(self.keyMap)
            for substate, value in self.distributions.items():
                result.distributions[substate] = value - other.distributions[substate]
            return result
        raise NotImplementedError

    def __imul__(self, other, select=False):
        """
        Multiplies (in place) this state by a L{KeyedMatrix}, L{KeyedTree}, or L{KeyedVector}
        """
        if isinstance(other, KeyedMatrix):
            self.multiply_matrix(other)
        elif isinstance(other, KeyedTree):
            self.multiply_tree(other, select=select)
        elif isinstance(other, KeyedVector):
            self.multiply_vector(other)
        else:
            raise NotImplementedError
        return self

    def multiply_vector(self, other):
        """
        Computes the dot product of each possible world with the given vector, storing it (in place)
        in a new C{keys.VALUE} column. Uncertain substates touched by the vector are merged first.
        The C{keys.CONSTANT} weight is counted exactly once.

        :type other: KeyedVector
        """
        substates = self.substate(other)
        destination = self.collapse(substates) if substates else None
        if destination is None:
            destination = len(self.distributions)
            while destination in self.distributions:
                destination -= 1
        # Contributions that are the same in every world (CONSTANT and keys outside destination)
        total = 0.
        in_destination = []
        for key, weight in other.items():
            if key == keys.CONSTANT:
                total += weight
            elif self.keyMap[key] == destination:
                in_destination.append((key, weight))
            else:
                # Certain value for this key
                total += weight*next(iter(self.marginal(key).domain()))
        self.join(keys.VALUE, total, destination)
        # World-specific contributions
        for item in self.distributions[destination]._Distribution__items:
            vector = item[0]
            vector[keys.VALUE] += sum(weight*vector.get(key, 0) for key, weight in in_destination)

    def multiply_matrix(self, other):
        """
        Applies (in place) the given matrix: for each row key, sets that key in every world to the
        row's weighted sum of column values (with C{keys.CONSTANT} contributing its weight).

        :type other: KeyedMatrix
        """
        # Focus on subset that this matrix affects
        substates = self.substate(other.getKeysIn(), True)
        if substates:
            destination = self.collapse(substates)
        else:
            destination = None
        # Go through each key this matrix sets
        for rowKey, vector in other.items():
            if destination is None:
                # Every value is 100%
                total = 0
                for colKey, weight in vector.items():
                    if colKey == keys.CONSTANT:
                        # Doesn't really matter
                        total += weight
                    else:
                        substate = self.keyMap[colKey]
                        if substate is None:
                            value = self.certain[colKey]
                        else:
                            value = self.distributions[substate].first()[colKey]
                        total += weight*value
                destination = len(self.distributions)
                while destination in self.distributions:
                    destination -= 1
                self.join(rowKey, total, destination)
            else:
                # There is at least one uncertain multiplicand
                for state, prob in self.distributions[destination].items():
                    total = 0
                    for colKey, weight in vector.items():
                        if colKey == keys.CONSTANT:
                            # Doesn't really matter
                            total += weight
                        else:
                            substate = self.keyMap[colKey]
                            if substate == destination:
                                value = state[colKey]
                            elif substate is None:
                                value = self.certain[colKey]
                            else:
                                # Certainty
                                value = self.distributions[substate].first()[colKey]
                            total += weight*value
                    state[rowKey] = total
                self.keyMap[rowKey] = destination

    def multiply_tree(self, other, probability=1, select=False):
        """
        Applies (in place) the given decision tree to this state.

        Leaves are multiplied in directly. Probabilistic nodes either sample one child (when
        C{select} is truthy; C{'max'} picks the most likely) or apply every child to a copy of the
        state and mix the results weighted by the child probabilities. Branch nodes evaluate each
        plane on every world, partition the worlds by test result, recurse into the matching child
        for each partition, and recombine them with L{update}.

        :param probability: probability mass this state carries from upstream branching
        :param select: if truthy, sample rather than expand probabilistic nodes

        NOTE(review): in the branch case, the recursive calls do not pass C{select} along, and
        the check for partial subdistributions inspects C{self} rather than C{s} — confirm intended.
        """
        if other.isLeaf():
            self *= other.children[None]
        elif other.isProbabilistic():
            if select:
                oldKid, prob = other.children.sample(quantify=True, most_likely=select=='max')
                self.multiply_tree(oldKid, probability=prob, select=select)
            else:
                oldKids = list(other.children.domain())
                # Multiply out children, other than first-born
                newKids = []
                for child in oldKids[1:]:
                    prob = other.children[child]
                    assert child.getKeysOut() == oldKids[0].getKeysOut()
                    myChild = copy.deepcopy(self)
                    myChild.multiply_tree(child, probability=prob, select=select)
                    newKids.append(myChild)
                self.multiply_tree(oldKids[0], probability=other.children[oldKids[0]], select=select)
                subkeys = oldKids[0].getKeysOut()
                # Compute first-born child
                newKids.insert(0, self)
                for index in range(len(oldKids)):
                    prob = other.children[oldKids[index]]
                    substates = newKids[index].substate(subkeys)
                    if len(substates) > 1:
                        substate = newKids[index].collapse(substates)
                    else:
                        substate = next(iter(substates))
                    if index == 0:
                        # Scale my own (first-born) outcome by its probability
                        for vector in self.distributions[substate].domain():
                            self.distributions[substate][vector] *= prob
                        mySubstate = substate
                    else:
                        # Grow both joint distributions until they cover the same keys
                        toCollapse = (subkeys, set())
                        while len(toCollapse[0]) + len(toCollapse[1]) > 0:
                            mySubstates = self.substate(toCollapse[1] |
                                                        set(self.distributions[mySubstate].keys()))
                            if len(mySubstates) > 1:
                                mySubstate = self.collapse(mySubstates, False)
                            else:
                                mySubstate = next(iter(mySubstates))
                            substates = newKids[index].substate(toCollapse[0] | set(newKids[index].distributions[substate].keys()))
                            if len(substates) > 1:
                                substate = newKids[index].collapse(substates, False)
                            else:
                                substate = next(iter(substates))
                            toCollapse = ({k for k in self.distributions[mySubstate].keys()
                                           if k != keys.CONSTANT and
                                           not k in newKids[index].distributions[substate].keys()},
                                          {k for k in newKids[index].distributions[substate].keys()
                                           if k != keys.CONSTANT and
                                           not k in self.distributions[mySubstate].keys()})
                        # Mix this sibling's outcome into mine, weighted by its probability
                        distribution = newKids[index].distributions[substate]
                        for vector in distribution.domain():
                            self.distributions[mySubstate].addProb(vector, distribution[vector]*prob)
        else:
            # Apply the test to this tree
            sufficient = not other.branch.isConjunction # If any plane test gets this value, no need to test further (e.g., False for conjunctions)
            first = '__null__'
            states = {first: (self, probability, None)} # (state, probability, substate)
            for p_index, plane in enumerate(other.branch.planes):
                current_states = [(old_value, s_tuple) for old_value, s_tuple in list(states.items()) if old_value != sufficient]
                if len(current_states) == 0:
                    # No more possibility of a different result
                    break
                states = {sufficient: states[sufficient]} if sufficient in states else {}
                for old_value, s_tuple in current_states:
                    s = s_tuple[0]
                    s *= plane[0]
                    should_copy = False
                    valSub = s.keyMap[keys.VALUE]
                    if not math.isclose(s_tuple[1], 1, rel_tol=1e-8):
                        # We've already descended along one side of a branch
                        partials = [substate for substate, dist in self.distributions.items() if not dist.is_complete()]
                        if len(partials) > 1:
                            raise ValueError(f'Miraculous but incorrect appearance of multiple subdistributions with probability mass < 1 {[self.distributions[s].probability() for s in partials]}')
                        elif len(partials) == 0:
                            raise ValueError(f'Where did all the incompleteness go?')
                        if partials[0] != valSub:
                            # The test result covers a different set of variables than was tested upstream
                            valSub = s.merge([partials[0], valSub])
                    del s.keyMap[keys.VALUE]
                    # Iterate through possible test results
                    vector_list = list(s.distributions[valSub].items())
                    s.distributions[valSub].clear()
                    for vector, prob in vector_list:
                        # Test this vector against the hyperplane
                        test = other.branch.evaluate(vector[keys.VALUE], p_index)
                        del vector[keys.VALUE]
                        if test in states:
                            if len(vector) > 1:
                                # Merge in this vector's keys with an existing matching test result
                                new_sub = states[test][0].merge(states[test][0].substate(vector.keys()))
                            else:
                                # Nothing to merge, just carry over the substate from an existing matching result
                                new_sub = states[test][2]
                            old_dist = states[test][0].distributions[states[test][2]]
                            for old_key in old_dist.keys():
                                if old_key not in vector:
                                    raise RuntimeError
                                    # NOTE(review): the lines below are unreachable after the raise above
                                    sub_dist = s[old_key]
                                    if len(sub_dist) > 1:
                                        raise ValueError('Worlds are branching in a way that I am not prepared to handle')
                                    vector[old_key] = sub_dist.first()
                            if s is self and not should_copy:
                                for old_vec, old_prob in states[test][0].distributions[new_sub].items():
                                    s.distributions[valSub].addProb(old_vec, old_prob)
                                states[test] = (self, states[test][1]+prob, valSub)
                                should_copy = True
                            else:
                                states[test] = (states[test][0], states[test][1]+prob, states[test][2])
                            if len(vector) > 1:
                                states[test][0].distributions[states[test][2]].addProb(vector, prob)
                        elif should_copy:
                            # First world with this result, but s is already in use by another result
                            states[test] = (copy.deepcopy(s), prob, valSub)
                            states[test][0].distributions[valSub].clear()
                            states[test][0].distributions[valSub].addProb(vector, prob)
                        else:
                            states[test] = (s, prob, valSub)
                            should_copy = True
                            states[test][0].distributions[valSub].addProb(vector, prob)
                        if states[test][0] is self:
                            first = test
                    if len(s.distributions[valSub]) == 0:
                        del s.distributions[valSub]
            assert states, 'Empty result of multiplication'
            for test in states:
                if test not in other.children:
                    if test is None:
                        logging.error('Missing fallback branch in tree:\n%s' % (str(other)))
                    else:
                        logging.error('Missing branch for value %s in tree:\n%s' % (test, str(other)))
            self.multiply_tree(other.children[first], probability*states[first][1])
            del states[first]
            new_keys = set(other.getKeysOut())
            for test, s_plus in states.items():
                s = s_plus[0]
                branch_keys = set(s.distributions[s_plus[2]].keys()) - {keys.CONSTANT}
                s.multiply_tree(other.children[test], states[test][1])
                self.update(s, new_keys | branch_keys)

    def __rmul__(self, other):
        """
        Computes C{other * self} for a L{KeyedVector} or L{KeyedTree}: the probability-weighted sum
        of the resulting C{keys.VALUE} column. The temporary C{keys.VALUE} column is removed
        afterward, but any in-place changes made by multiplying C{other} into this state remain.

        :rtype: float
        """
        if isinstance(other, (KeyedVector, KeyedTree)):
            self *= other
            substate = self.keyMap.pop(keys.VALUE)
            distribution = self.distributions[substate]
            total = 0.
            only_constant = False
            for vector, prob in distribution.items():
                total += prob*vector[keys.VALUE]
                del vector[keys.VALUE]
                only_constant = only_constant or len(vector) <= 1
            if only_constant:
                # The distribution existed only to hold VALUE
                del self.distributions[substate]
            return total
        raise NotImplementedError

    def rollback(self, debug=False):
        """
        Removes any current state values and makes any future state values the current ones
        :param debug: if True, then run some checks on the values
        """
        # What keys have both current and future values?
        pairs = [k for k in self.keyMap if k != keys.CONSTANT and
                 not keys.isFuture(k) and keys.makeFuture(k) in self.keyMap]
        for now in pairs:
            nowSub = self.keyMap[now]
            future = keys.makeFuture(now)
            futureSub = self.keyMap[future]
            del self.keyMap[future]
            if nowSub is None:
                if futureSub is None:
                    self.certain[now] = self.certain.pop(future)
                else:
                    del self.certain[now]
                distribution = None
            else:
                distribution = self.distributions[nowSub]
                items = list(distribution.items())
                distribution.clear()
                for vector, prob in items:
                    if nowSub == futureSub:
                        # Kill two birds with one stone
                        vector[now] = vector[future]
                        del vector[future]
                    else:
                        del vector[now]
                    if len(vector) > 1:
                        distribution.add_prob(vector, prob)
                    elif len(vector) == 1 and debug:
                        assert next(iter(vector.keys())) == keys.CONSTANT
            if nowSub != futureSub:
                # Kill two birds with two stones
                if distribution is not None and len(distribution) == 0:
                    del self.distributions[nowSub]
                self.keyMap[now] = futureSub
                if futureSub is None:
                    self.certain[now] = self.certain.pop(future)
                else:
                    distribution = self.distributions[futureSub]
                    items = list(distribution.items())
                    distribution.clear()
                    for vector, prob in items:
                        vector[now] = vector[future]
                        del vector[future]
                        distribution.add_prob(vector, prob)
            if debug:
                assert now in self.keyMap
                if self.keyMap[now] is None:
                    assert now in self.certain, now
                else:
                    assert self.keyMap[now] in self.distributions, now
        if debug:
            for s in self.distributions:
                assert s in self.keyMap.values(), 'Distribution %s is missing\n%s' % (s, self.distributions[s])
                for k in self.distributions[s].keys():
                    assert not keys.isFuture(k), 'Future key %s persists after rollback' % (k)
            for k, s in self.keyMap.items():
                if k != keys.CONSTANT:
                    if s is None:
                        assert k in self.certain, '%s: %s' % (k, s)
                    else:
                        assert s in self.distributions, '%s: %s' % (k, s)
                    assert not keys.isFuture(k)

    def simpleRollback(self, futures):
        """
        Makes (in place) the future value of each given key its current value.

        NOTE(review): assumes both the current and future keys live in subdistributions (not in
        C{certain}); a certain key raises KeyError — confirm before use with certain keys.
        """
        # Make the future the present
        for key in futures:
            future = keys.makeFuture(key)
            oldstate = self.keyMap[key]
            newstate = self.keyMap[future]
            dist = self.distributions[newstate]
            if oldstate == newstate:
                for vector in dist.domain():
                    prob = dist[vector]
                    del dist[vector]
                    vector[key] = vector[future]
                    del vector[future]
                    dist.addProb(vector, prob)
            elif len(dist) > 1:
                # New value is probabilistic, not a single value, so update old value across possible worlds
                for vector in dist.domain():
                    prob = dist[vector]
                    del dist[vector]
                    vector[key] = vector[future]
                    del vector[future]
                    dist[vector] = prob
                self.keyMap[key] = newstate
                # Remove old state values
                dist = self.distributions[oldstate]
                if len(dist.first()) > 2:
                    # Other variables still remain
                    for vector in dist.domain():
                        prob = dist[vector]
                        del dist[vector]
                        del vector[key]
                        dist.addProb(vector, prob)
                else:
                    del self.distributions[oldstate]
            else:
                vector = dist.first()
                value = vector[future]
                del dist[vector]
                del vector[future]
                if len(vector) > 1:
                    dist[vector] = 1
                else:
                    del self.distributions[newstate]
                dist = self.distributions[oldstate]
                for vector in dist.domain():
                    prob = dist[vector]
                    del dist[vector]
                    vector[key] = value
                    dist.addProb(vector, prob)
            del self.keyMap[future]

    def __eq__(self, other):
        """
        :return: C{True} iff both states have the same keys, the same certain values, and equal
            subdistributions for every uncertain key
        """
        if not isinstance(other, VectorDistributionSet):
            return False
        remaining = set(self.keyMap.keys())
        if remaining != set(other.keyMap.keys()):
            # The two do not even contain the same columns
            return False
        while remaining:
            key = remaining.pop()
            if self.keyMap[key] is None:
                if other.keyMap[key] is None:
                    if self.certain[key] != other.certain[key]:
                        return False
                else:
                    return False
            elif other.keyMap[key] is None:
                return False
            else:
                distributionMe = self.distributions[self.keyMap[key]]
                distributionYou = other.distributions[other.keyMap[key]]
                if distributionMe != distributionYou:
                    return False
                remaining -= set(distributionMe.keys())
        return True

    def delete_value(self, key, value):
        """Removes the given value for the given key from the state and then renormalizes
        :param value: value (or set of values) to be removed
        """
        distribution = self.distributions[self.keyMap[key]]
        for vector in distribution.domain():
            if isinstance(value, set):
                if vector[key] in value:
                    del distribution[vector]
            elif vector[key] == value:
                del distribution[vector]
        distribution.normalize()

    def __deepcopy__(self, memo):
        result = self.__class__()
        for substate, distribution in self.distributions.items():
            result.distributions[substate] = copy.deepcopy(distribution)
        result.keyMap.update(self.keyMap)
        result.certain.update(self.certain)
        return result

    def __str__(self):
        uncertain = '\n---\n'.join([str(dist) for dist in self.distributions.values() if len(dist) > 1])
        if self.certain:
            vector = KeyedVector(self.certain)
            return f'100%\n{vector.sortedString()}\n{uncertain}'
        return uncertain

    def copySubset(self, ignore=None, include=None):
        raise DeprecationWarning('Use copy_subset instead')

    def copy_subset(self, ignore=None, include=None):
        """
        :param ignore: keys to leave out (any iterable)
        :param include: keys to copy (any iterable; by default, every key)
        :return: a new state containing only the selected keys; each joint distribution is projected
            onto the selected keys (duplicate projected vectors removed via C{remove_duplicates})
        :rtype: VectorDistributionSet
        """
        if ignore is None and include is None:
            # Ignoring nothing, including everything, so this is just a copy
            return self.__deepcopy__({})
        result = self.__class__()
        keySubset = set(self.keys()) if include is None else set(include)
        if ignore is not None:
            keySubset -= set(ignore)
        for key in keySubset:
            if key not in result and key in self:
                if self.keyMap[key] is None:
                    result.certain[key] = self.certain[key]
                    result.keyMap[key] = None
                else:
                    distribution = self.distributions[self.keyMap[key]]
                    substate = len(result.distributions)
                    intersection = set(distribution.keys()) & keySubset
                    for subkey in intersection:
                        if subkey != keys.CONSTANT:
                            result.keyMap[subkey] = substate
                    new_dist = []
                    for vector, prob in distribution.items():
                        new_dict = {subkey: vector[subkey] for subkey in intersection}
                        new_dict[keys.CONSTANT] = 1
                        new_dist.append((vector.__class__(new_dict), prob))
                    result.distributions[substate] = VectorDistribution(new_dist)
                    result.distributions[substate].remove_duplicates()
        return result

    def verifyIntegrity(self, sumToOne=False):
        """
        Asserts that the internal structures are mutually consistent: every certain key has a value,
        every uncertain key appears in each vector of its subdistribution, every key in a vector maps
        back to that subdistribution, and each subdistribution's mass is at most 1 (exactly 1 if
        C{sumToOne})

        :raises AssertionError: on the first inconsistency found
        """
        for key in self.keys():
            substate = self.keyMap[key]
            if substate is None:
                assert key in self.certain, f'Certain value missing for key {key}'
                continue
            assert substate in self.distributions, 'Distribution %s missing for key %s' % \
                (substate, key)
            distribution = self.distributions[substate]
            for vector in distribution.domain():
                assert key in vector, 'Key %s is missing from vector\n%s\nProb: %d%%' % \
                    (key, vector, distribution[vector]*100)
                for other in vector:
                    assert other == keys.CONSTANT or self.keyMap.get(other) == substate, \
                        f'Unmapped key {other} is in vector\n{vector}'
            if sumToOne:
                assert abs(distribution.probability()-1) < .000001, 'Distribution sums to %4.2f' % (distribution.probability())
            else:
                assert distribution.probability() < 1.000001, f'Distribution sums to {distribution.probability()}'

    def copy_value(self, old_key, new_key):
        """
        Modifies the state so that the distribution over the new key's values is identical to that of the old key
        """
        substate = self.keyMap[old_key]
        self.keyMap[new_key] = substate
        if substate is None:
            self.certain[new_key] = self.certain[old_key]
        else:
            dist = self.distributions[substate]
            for vector in dist.domain():
                prob = dist[vector]
                del dist[vector]
                vector[new_key] = vector[old_key]
                dist[vector] = prob

    def is_minimal(self):
        """
        :return: False iff some uncertain (multi-vector) distribution contains a key whose marginal
            is nonetheless certain (i.e., the key could be moved out of that distribution)
        """
        for dist in self.distributions.values():
            if len(dist) > 1:
                for key in dist.keys():
                    if key != keys.CONSTANT and len(self.marginal(key)) == 1:
                        # This feature is 100% certain, yet exists in a distribution that is uncertain
                        return False
        return True

    def diff(self, other):
        """
        :return: a dictionary of differences between me and the given state: keys only in me, keys
            only in the other, shared keys with different marginals, and, per shared key, the keys
            (other than C{keys.CONSTANT}) that are jointly distributed with it in exactly one state
        """
        def joint_keys(state, key):
            # Keys sharing a joint distribution with this key (just the key itself if certain)
            substate = state.keyMap[key]
            if substate is None:
                return {key}
            return set(state.distributions[substate].keys()) - {keys.CONSTANT}

        result = {'only_me': self.keys() - other.keys(),
                  'only_you': other.keys() - self.keys(),
                  'dependency mismatch': {},
                  'probability mismatch': set()}
        for key in self.keys() & other.keys():
            if self.marginal(key) != other.marginal(key):
                result['probability mismatch'].add(key)
            mismatch = joint_keys(self, key) ^ joint_keys(other, key)
            if mismatch:
                result['dependency mismatch'][key] = mismatch
        return result