package com.mockturtlesolutions.snifflib.invprobs;

import com.mockturtlesolutions.snifflib.datatypes.DblMatrix;
import com.mockturtlesolutions.snifflib.datatypes.DblParamSet;
import com.mockturtlesolutions.snifflib.graphics.DefaultReportInstance;
import com.mockturtlesolutions.snifflib.graphics.DefaultReportable;
import com.mockturtlesolutions.snifflib.graphics.DefaultReporter;
import com.mockturtlesolutions.snifflib.graphics.SnifflibGraphicsException;
import com.mockturtlesolutions.snifflib.stats.UniformDistribution;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Set;
import java.util.Vector;
import org.codehaus.groovy.syntax.Types;

/* loaded from: input_file:com/mockturtlesolutions/snifflib/invprobs/MCMC.class */
public class MCMC implements DefaultReportable, Runnable, ObjectSampler, Optimization {
    private LinkedList objectChain;
    private Vector blocks;
    private MCMCSampler sampler;
    private int burn_in;
    private int chain_length;
    private UniformDistribution unif;
    private HashMap defaultReporters;
    private int reject_count;
    private int accept_count;
    private boolean haltSolution;
    private Vector optimizationListeners = new Vector();
    private boolean stillRunning = false;
    private int n_samples = 10;
    private int max_iterations = Types.KEYWORD_PRIVATE;
    private LinkedList chain = new LinkedList();

    public MCMC(MCMCSampler mCMCSampler) {
        this.sampler = mCMCSampler;
        this.chain.add(this.sampler.getNewParams());
        this.objectChain = new LinkedList();
        this.blocks = new Vector();
        this.blocks.add(new MCMCBlock(((DblParamSet) this.chain.get(0)).parameterSet(), 1));
        this.chain_length = 100;
        this.unif = new UniformDistribution();
        this.defaultReporters = new HashMap();
        this.defaultReporters.put("MCMC", new DefaultReporter());
        this.reject_count = 0;
        this.accept_count = 0;
        this.haltSolution = false;
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.Optimization
    public void addOptimizationListener(OptimizationListener optimizationListener) {
        this.optimizationListeners.add(optimizationListener);
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.Optimization
    public void removeOptimizationListener(OptimizationListener optimizationListener) {
        this.optimizationListeners.remove(optimizationListener);
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.ObjectSampler
    public void addObjectToSample(String str) {
        this.sampler.addObjectToSample(str);
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.ObjectSampler
    public void removeObjectToSample(String str) {
        this.sampler.removeObjectToSample(str);
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.ObjectSampler
    public String[] getObjectsToSample() {
        return this.sampler.getObjectsToSample();
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.ObjectSampler
    public boolean isSamplingObject(String str) {
        return this.sampler.isSamplingObject(str);
    }

    @Override // com.mockturtlesolutions.snifflib.invprobs.ObjectSampler
    public void setObjectsToSample(String[] strArr) {
        this.sampler.setObjectsToSample(strArr);
    }

    public void haltSolution(boolean z) {
        this.haltSolution = z;
    }

    @Override // com.mockturtlesolutions.snifflib.graphics.DefaultReportable
    public DefaultReporter getDefaultReporter(String str) {
        return (DefaultReporter) this.defaultReporters.get(str);
    }

    @Override // com.mockturtlesolutions.snifflib.graphics.DefaultReportable
    public Set getDefaultReporters() {
        return this.defaultReporters.keySet();
    }

    public void setInitialGuess(DblParamSet dblParamSet) {
        this.chain = new LinkedList();
        this.chain.add(dblParamSet);
    }

    public DblParamSet getInitialGuess() {
        return (DblParamSet) this.chain.get(0);
    }

    public LinkedList getChain() {
        return this.chain;
    }

    public LinkedList getObjectChain() {
        return this.objectChain;
    }

    public void setChainLength(int i) {
        this.chain_length = i;
    }

    public DblParamSet getEstimate() {
        DblParamSet copy = ((DblParamSet) this.chain.get(0)).copy();
        for (int i = 1; i < this.chain.size(); i++) {
            copy = copy.plus((DblParamSet) this.chain.get(i));
        }
        return copy.divideBy(this.chain.size());
    }

    public DblParamSet getEstimate(int i) {
        int size = this.chain.size() - i;
        if (size < 0) {
            throw new IllegalArgumentException("Not enough elements in chaing to get desired estimate.");
        }
        DblParamSet copy = ((DblParamSet) this.chain.get(size)).copy();
        for (int i2 = 0; i2 < i; i2++) {
            copy = copy.plus((DblParamSet) this.chain.get(size + i2));
        }
        return copy.divideBy(i);
    }

    public DblParamSet getVariance() {
        DblParamSet copy = ((DblParamSet) this.chain.get(0)).copy();
        for (int i = 1; i < this.chain.size(); i++) {
            copy = copy.plus((DblParamSet) this.chain.get(i));
        }
        return copy.divideBy(this.chain.size());
    }

    public void addBlock(MCMCBlock mCMCBlock) {
        this.blocks.add(mCMCBlock);
    }

    public void setBurnIn(int i) {
        this.burn_in = i;
    }

    public void setMaxIterations(int i) {
        this.max_iterations = i;
    }

    public int getBurnIn(int i) {
        return this.burn_in;
    }

    public int getMaxIterations(int i) {
        return this.max_iterations;
    }

    public void removeBlock(MCMCBlock mCMCBlock) {
        this.blocks.remove(mCMCBlock);
    }

    public void removeAllBlocks() {
        this.blocks.removeAllElements();
    }

    public int getRejectCount() {
        return this.reject_count;
    }

    public int getAcceptCount() {
        return this.accept_count;
    }

    @Override // java.lang.Runnable
    public void run() {
        for (int i = 0; i < this.optimizationListeners.size(); i++) {
            try {
                ((OptimizationListener) this.optimizationListeners.get(i)).actionPerformed(new OptimizationEvent(1, this));
            } catch (Exception e) {
                haltSolution(true);
                this.stillRunning = false;
                for (int i2 = 0; i2 < this.optimizationListeners.size(); i2++) {
                    ((OptimizationListener) this.optimizationListeners.get(i2)).actionPerformed(new OptimizationEvent(2, this));
                }
                throw new RuntimeException("Problem optimizing via MCMC.", e);
            }
        }
        this.stillRunning = true;
        int i3 = 0;
        this.reject_count = 0;
        this.accept_count = 0;
        DblParamSet dblParamSet = (DblParamSet) this.chain.getLast();
        this.sampler.setPreviousParams(dblParamSet);
        this.sampler.setNewParams(dblParamSet);
        new DblMatrix(0.0d);
        DblMatrix alpha = this.sampler.getAlpha();
        this.objectChain.add(this.sampler.getNewObjectSamples());
        while (i3 < this.max_iterations) {
            for (int i4 = 0; i4 < this.optimizationListeners.size(); i4++) {
                ((OptimizationListener) this.optimizationListeners.get(i4)).actionPerformed(new OptimizationEvent(4, this, this.max_iterations, i3));
            }
            System.out.println("The MCMC max_iterations is:" + this.max_iterations);
            for (int i5 = 0; i5 < this.blocks.size() && !this.haltSolution; i5++) {
                MCMCBlock mCMCBlock = (MCMCBlock) this.blocks.get(i5);
                String[] parameters = mCMCBlock.getParameters();
                for (int i6 = 0; i6 < mCMCBlock.getCycles() && !this.haltSolution; i6++) {
                    DblParamSet dblParamSet2 = (DblParamSet) this.chain.getLast();
                    HashMap hashMap = (HashMap) this.objectChain.getLast();
                    DblParamSet sample = this.sampler.sample(parameters);
                    this.sampler.setNewParams(sample);
                    alpha = this.sampler.getAlpha();
                    HashMap newObjectSamples = this.sampler.getNewObjectSamples();
                    if (alpha == null) {
                        throw new RuntimeException("Null alpha returned from sampler.");
                    }
                    if (alpha.isEmpty()) {
                        throw new RuntimeException("Empty alpha returned from sampler.");
                    }
                    alpha.show("alpha", "0.00E0");
                    String name = mCMCBlock.getName();
                    String str = name == null ? "" : "(" + name + ")";
                    if (DblMatrix.test(alpha.gt(this.unif.random(1)))) {
                        System.out.println("MCMC: Step Accepted\tBlock " + i5 + " " + str);
                        this.accept_count++;
                        this.chain.add(sample);
                        this.objectChain.add(newObjectSamples);
                        this.sampler.setPreviousParams(sample);
                    } else {
                        System.out.println("MCMC: Step Rejected\tBlock " + i5 + " " + str);
                        this.reject_count++;
                        this.chain.add(dblParamSet2);
                        this.objectChain.add(hashMap);
                    }
                    while (this.chain.size() > this.chain_length) {
                        this.chain.remove(0);
                        this.objectChain.remove(0);
                    }
                }
            }
            DefaultReportInstance defaultReportInstance = new DefaultReportInstance();
            defaultReportInstance.setIteration(i3);
            defaultReportInstance.setParams((DblParamSet) this.chain.getLast());
            defaultReportInstance.setValue(alpha);
            defaultReportInstance.setValueName("alpha");
            DefaultReporter defaultReporter = getDefaultReporter("MCMC");
            if (defaultReporter.shouldAddToReport(defaultReportInstance)) {
                defaultReporter.addToReport(defaultReportInstance);
            }
            for (int i7 = 0; i7 < this.optimizationListeners.size(); i7++) {
                ((OptimizationListener) this.optimizationListeners.get(i7)).actionPerformed(new OptimizationEvent(3, this, this.max_iterations, i3));
            }
            i3++;
            if (this.haltSolution) {
                this.stillRunning = false;
                for (int i8 = 0; i8 < this.optimizationListeners.size(); i8++) {
                    ((OptimizationListener) this.optimizationListeners.get(i8)).actionPerformed(new OptimizationEvent(2, this));
                }
                return;
            }
        }
        this.stillRunning = false;
        for (int i9 = 0; i9 < this.optimizationListeners.size(); i9++) {
            ((OptimizationListener) this.optimizationListeners.get(i9)).actionPerformed(new OptimizationEvent(2, this));
        }
    }

    public boolean isStillRunning() {
        return this.stillRunning;
    }

    public void solve() throws SnifflibGraphicsException {
        this.stillRunning = true;
        new Thread(this).start();
    }
}
