Commit d3d801d8 authored by Jan Pöppel's avatar Jan Pöppel
Browse files

reworked how gibbs transition handles stepbystep case for visualisation

parent 33b0a02c
......@@ -54,7 +54,7 @@ class MCMC(object):
self.numSamples = numSamples
self.sampler = MarkovChainSampler(transitionModel, burnIn, fullChange)
def marginals(self, variables, evidence=None, sampleChain=None, stepByStep=False):
def marginals(self, variables, evidence=None, sampleChain=None, stepByStep=False, sampledVars=None):
"""
Function that approximates the joint prior or posterior marginals for the
given variables, potentially given some evidence.
......@@ -89,8 +89,8 @@ class MCMC(object):
res_dict[v] = np.zeros([len(self.bn.get_node(v).values), self.numSamples, 4])
for i, s in enumerate(sampleChain):
for v in res_dict:
res_dict[v][self.bn.get_node(v).values.index(s[0][v])][i][0] = 1
if v == s[1]:
res_dict[v][self.bn.get_node(v).values.index(s[v])][i][0] = 1
if v == sampledVars[i]:
res_dict[v][0][i][3] = 1
for p in range(len(res_dict[v])):
curSum = np.sum(res_dict[v][p][:,0])
......@@ -132,7 +132,7 @@ class MarkovChainSampler(object):
transitionModel = MetropolisHastingsTransition()
self.transitionModel = transitionModel
def generate_markov_chain(self, bn, numSamples, initialState, evidence=None, stepByStep=False):
def generate_markov_chain(self, bn, numSamples, initialState, evidence=None, stepByStep=False, sampledVars=None):
"""
Generator actually yielding the given number of samples drawn from
the given network starting from the initialState.
......@@ -163,15 +163,21 @@ class MarkovChainSampler(object):
evidence = {}
variablesToChange = [node for node in bn.get_all_nodes() if node not in evidence]
curSamples = 0
if stepByStep:
state = [dict(initialState), '']
else:
state = dict(initialState)
# if stepByStep:
# state = [dict(initialState), '']
# else:
state = initialState
while curSamples < self.burnIn:
state = self.transitionModel.step(state, variablesToChange, bn, self.fullChange, stepByStep)
state = self.transitionModel.step(dict(state), variablesToChange, bn, self.fullChange, stepByStep)
if stepByStep:
state = state[0]
curSamples += 1
for i in range(numSamples):
state = self.transitionModel.step(state, variablesToChange, bn, self.fullChange, stepByStep)
state = self.transitionModel.step(dict(state), variablesToChange, bn, self.fullChange, stepByStep)
if stepByStep:
if sampledVars != None:
sampledVars.append(state[1])
state = state[0]
yield state
class TransitionModel(object):
......@@ -227,13 +233,9 @@ class GibbsTransition(TransitionModel):
if remaining_vars:
varToChange = random.choice(remaining_vars)
self.checked_vars.add(varToChange)
currentState[varToChange.name] = varToChange.sample_value(currentState, bn.get_children(varToChange.name))
if stepByStep:
currentState[0] = dict(currentState[0])
currentState[0][varToChange.name] = varToChange.sample_value(currentState[0], bn.get_children(varToChange.name))
return [currentState[0], varToChange]
else:
currentState = dict(currentState)
currentState[varToChange.name] = varToChange.sample_value(currentState, bn.get_children(varToChange.name))
return currentState, varToChange
return currentState
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment