/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.List;
import java.util.ArrayList;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
 * mixture model</a> distributions.
 * <p>
 * The weights supplied at construction time are normalized so that they sum
 * to 1 (up to floating-point rounding). {@link #density(double[])} is the
 * weighted sum of the component densities, and {@link #sample()} first picks a
 * component with probability equal to its normalized weight, then draws a
 * sample from that component.
 * </p>
 *
 * @param <T> Type of the mixture components.
 *
 * @version $Id: MixtureMultivariateRealDistribution.java 1416643 2012-12-03 19:37:14Z tn $
 * @since 3.1
 */
public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
    extends AbstractMultivariateRealDistribution {
    /** Normalized weight of each mixture component (the values sum to 1). */
    private final double[] weight;
    /** Mixture components, in the same order as {@link #weight}. */
    private final List<T> distribution;

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights, using a {@link Well19937c} random generator.
     *
     * @param components List of (weight, distribution) pairs from which to sample.
     * Must not be empty.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the sum of the weights overflows,
     * or if all weights are zero.
     */
    public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
        this(new Well19937c(), components);
    }

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     *
     * @param rng Random number generator.
     * @param components List of (weight, distribution) pairs from which to
     * sample. Must not be empty; the dimension of the mixture is taken from
     * the first component.
     * @throws IndexOutOfBoundsException if {@code components} is empty.
     * @throws NotPositiveException if any of the weights is negative or NaN.
     * @throws DimensionMismatchException if not all components have the same
     * number of variables.
     * @throws MathArithmeticException if the sum of the weights overflows,
     * or if all weights are zero (so that they cannot be normalized).
     */
    public MixtureMultivariateRealDistribution(RandomGenerator rng,
                                               List<Pair<Double, T>> components) {
        super(rng, components.get(0).getSecond().getDimension());

        final int numComp = components.size();
        final int dim = getDimension();
        double weightSum = 0;
        for (final Pair<Double, T> comp : components) {
            final int compDim = comp.getSecond().getDimension();
            if (compDim != dim) {
                throw new DimensionMismatchException(compDim, dim);
            }
            final double w = comp.getFirst();
            // Written as "!(w >= 0)" rather than "w < 0" so that NaN weights,
            // for which every comparison is false, are rejected as well.
            if (!(w >= 0)) {
                throw new NotPositiveException(w);
            }
            weightSum += w;
        }

        // Check for overflow.
        if (Double.isInfinite(weightSum)) {
            throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
        }
        // All weights zero: normalizing would divide by zero and turn every
        // weight into NaN, silently breaking both density() and sample().
        if (weightSum == 0) {
            throw new MathArithmeticException(LocalizedFormats.ZERO_NOT_ALLOWED);
        }

        // Store each distribution and its normalized weight.
        distribution = new ArrayList<T>(numComp);
        weight = new double[numComp];
        int i = 0;
        for (final Pair<Double, T> comp : components) {
            weight[i++] = comp.getFirst() / weightSum;
            distribution.add(comp.getSecond());
        }
    }

    /** {@inheritDoc} */
    public double density(final double[] values) {
        // Weighted sum of the component densities.
        double p = 0;
        for (int i = 0; i < weight.length; i++) {
            p += weight[i] * distribution.get(i).density(values);
        }
        return p;
    }

    /** {@inheritDoc} */
    public double[] sample() {
        // Determine which component to sample from: component i is chosen when
        // the uniform deviate falls in [cumulative(i - 1), cumulative(i)).
        // The strict comparison guarantees that a component with zero weight
        // is never selected, even when the deviate is exactly 0.
        final double randomValue = random.nextDouble();
        double sum = 0;
        for (int i = 0; i < weight.length; i++) {
            sum += weight[i];
            if (randomValue < sum) {
                return distribution.get(i).sample();
            }
        }

        // Floating-point rounding can leave the cumulative sum slightly below
        // 1, so the deviate may fall past the last boundary. Fall back to the
        // last component that actually has a non-zero weight.
        return distribution.get(lastPositiveWeightIndex()).sample();
    }

    /**
     * Finds the index of the last component whose normalized weight is
     * strictly positive.
     * <p>
     * The constructor rejects a zero or non-finite weight sum, so at least one
     * normalized weight is positive; index 0 is returned only as a last resort.
     * </p>
     *
     * @return the index of the last component with a positive weight.
     */
    private int lastPositiveWeightIndex() {
        for (int i = weight.length - 1; i > 0; i--) {
            if (weight[i] > 0) {
                return i;
            }
        }
        return 0;
    }

    /** {@inheritDoc} */
    public void reseedRandomGenerator(long seed) {
        // Seed needs to be propagated to underlying components
        // in order to maintain consistency between runs.
        super.reseedRandomGenerator(seed);

        for (int i = 0; i < distribution.size(); i++) {
            // Make each component's seed different in order to avoid
            // using the same sequence of random numbers.
            distribution.get(i).reseedRandomGenerator(i + 1 + seed);
        }
    }

    /**
     * Gets the distributions that make up the mixture model.
     * <p>
     * The weights returned are the normalized weights computed at
     * construction, not necessarily the values that were passed in.
     * A new list is built on each call, so modifying it does not affect this
     * distribution; the component distribution objects themselves are shared.
     * </p>
     *
     * @return the component distributions and associated weights.
     */
    public List<Pair<Double, T>> getComponents() {
        final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);
        for (int i = 0; i < weight.length; i++) {
            list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
        }
        return list;
    }
}