Commit d3d801d8 by Jan Pöppel

### reworked how gibbs transition handles stepbystep case for visualisation

parent 33b0a02c
 ... ... @@ -54,7 +54,7 @@ class MCMC(object): self.numSamples = numSamples self.sampler = MarkovChainSampler(transitionModel, burnIn, fullChange) def marginals(self, variables, evidence=None, sampleChain=None, stepByStep=False): def marginals(self, variables, evidence=None, sampleChain=None, stepByStep=False, sampledVars=None): """ Function that approximates the joint prior or posterior marginals for the given variables, potentially given some evidence. ... ... @@ -89,8 +89,8 @@ class MCMC(object): res_dict[v] = np.zeros([len(self.bn.get_node(v).values), self.numSamples, 4]) for i, s in enumerate(sampleChain): for v in res_dict: res_dict[v][self.bn.get_node(v).values.index(s[0][v])][i][0] = 1 if v == s[1]: res_dict[v][self.bn.get_node(v).values.index(s[v])][i][0] = 1 if v == sampledVars[i]: res_dict[v][0][i][3] = 1 for p in range(len(res_dict[v])): curSum = np.sum(res_dict[v][p][:,0]) ... ... @@ -132,7 +132,7 @@ class MarkovChainSampler(object): transitionModel = MetropolisHastingsTransition() self.transitionModel = transitionModel def generate_markov_chain(self, bn, numSamples, initialState, evidence=None, stepByStep=False): def generate_markov_chain(self, bn, numSamples, initialState, evidence=None, stepByStep=False, sampledVars=None): """ Generator actually yielding the given number of samples drawn from the given network starting from the initialState. ... ... @@ -163,15 +163,21 @@ class MarkovChainSampler(object): evidence = {} variablesToChange = [node for node in bn.get_all_nodes() if node not in evidence] curSamples = 0 if stepByStep: state = [dict(initialState), ''] else: state = dict(initialState) # if stepByStep: # state = [dict(initialState), ''] # else: state = initialState while curSamples < self.burnIn: state = self.transitionModel.step(state, variablesToChange, bn, self.fullChange, stepByStep) state = self.transitionModel.step(dict(state), variablesToChange, bn, self.fullChange, stepByStep) if stepByStep: state = state[0] curSamples += 1 for i in range(numSamples): state = self.transitionModel.step(state, variablesToChange, bn, self.fullChange, stepByStep) state = self.transitionModel.step(dict(state), variablesToChange, bn, self.fullChange, stepByStep) if stepByStep: if sampledVars != None: sampledVars.append(state[1]) state = state[0] yield state class TransitionModel(object): ... ... @@ -227,13 +233,9 @@ class GibbsTransition(TransitionModel): if remaining_vars: varToChange = random.choice(remaining_vars) self.checked_vars.add(varToChange) currentState[varToChange.name] = varToChange.sample_value(currentState, bn.get_children(varToChange.name)) if stepByStep: currentState[0] = dict(currentState[0]) currentState[0][varToChange.name] = varToChange.sample_value(currentState[0], bn.get_children(varToChange.name)) return [currentState[0], varToChange] else: currentState = dict(currentState) currentState[varToChange.name] = varToChange.sample_value(currentState, bn.get_children(varToChange.name)) return currentState, varToChange return currentState ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!