/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.statistics.distribution;

import java.util.function.DoubleSupplier;
import org.apache.commons.numbers.gamma.Erf;
import org.apache.commons.numbers.gamma.ErfDifference;
import org.apache.commons.numbers.gamma.Erfcx;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

/**
 * Implementation of the truncated normal distribution.
 *
 * <p>The probability density function of \( X \) is:
 *
 * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \]
 *
 * <p>for \( \mu \) mean of the parent normal distribution,
 * \( \sigma \) standard deviation of the parent normal distribution,
 * \( -\infty \le a \lt b \le \infty \) the truncation interval, and
 * \( x \in [a, b] \), where \( \phi \) is the probability
 * density function of the standard normal distribution and \( \Phi \)
 * is its cumulative distribution function.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">
 * Truncated normal distribution (Wikipedia)</a>
 */
public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {

    /** The max allowed value for x where (x*x) will not overflow.
     * This is a limit on computation of the moments of the truncated normal
     * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */
    private static final double MAX_X = 0x1.fffffffffffffp511;

    /** The min allowed probability range of the parent normal distribution.
     * Set to 0.0. This may be too low for accurate usage. It is a signal that
     * the truncation is invalid. */
    private static final double MIN_P = 0.0;

    /** sqrt(2). */
    private static final double ROOT2 = Constants.ROOT_TWO;
    /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */
    private static final double ROOT_2_PI = Constants.ROOT_TWO_DIV_PI;
    /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
    private static final double ROOT_PI_2 = Constants.ROOT_PI_DIV_TWO;

    /**
     * The threshold to switch to a rejection sampler. When the truncated
     * distribution covers more than this fraction of the CDF then rejection
     * sampling will be more efficient than inverse CDF sampling. Performance
     * benchmarks indicate that a normalized Gaussian sampler is up to 10 times
     * faster than inverse transform sampling using a fast random generator. See
     * STATISTICS-55.
     */
    private static final double REJECTION_THRESHOLD = 0.2;

    /** Parent normal distribution. */
    private final NormalDistribution parentNormal;
    /** Lower bound of this distribution. */
    private final double lower;
    /** Upper bound of this distribution. */
    private final double upper;

    /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to
     * normalise the probability computations. */
    private final double cdfDelta;
    /** log(cdfDelta). */
    private final double logCdfDelta;
    /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map
     * a probability into the range of the parent normal distribution. */
    private final double cdfAlpha;
    /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map
     * a probability into the range of the parent normal distribution. */
    private final double sfBeta;

    /**
     * @param parent Parent distribution.
     * @param z Probability of the parent distribution for {@code [lower, upper]}.
     * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
     * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
     */
    private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
        this.parentNormal = parent;
        this.lower = lower;
        this.upper = upper;

        cdfDelta = z;
        logCdfDelta = Math.log(cdfDelta);
        // Used to map the inverse probability.
        cdfAlpha = parentNormal.cumulativeProbability(lower);
        sfBeta = parentNormal.survivalProbability(upper);
    }

    /**
     * Creates a truncated normal distribution.
     *
     * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution,
     * and not the true mean and standard deviation of the truncated normal distribution.
     * The {@code lower} and {@code upper} bounds define the truncation of the parent
     * normal distribution.
     *
     * @param mean Mean for the parent distribution.
     * @param sd Standard deviation for the parent distribution.
     * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
     * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
     * @return the distribution
     * @throws IllegalArgumentException if {@code sd} is not strictly positive (this includes NaN);
     * if {@code lower} is not strictly less than {@code upper} (this includes a NaN bound); or if
     * the truncation covers no computable probability range in the parent distribution.
     */
    public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
        // All checks are written as negated comparisons. Any comparison with NaN
        // is false, so the negated form rejects NaN where 'sd <= 0' or
        // 'lower >= upper' would silently accept it.
        if (!(sd > 0)) {
            throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
        }
        if (!(lower < upper)) {
            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper);
        }

        // Use an instance for the parent normal distribution to maximise accuracy
        // in range computations using the error function
        final NormalDistribution parent = NormalDistribution.of(mean, sd);

        // If there is no computable range then raise an exception.
        // A NaN probability also fails this check; it would otherwise be stored
        // as the normalisation constant for every probability computation.
        final double z = parent.probability(lower, upper);
        if (!(z > MIN_P)) {
            // Map the bounds to a standard normal distribution for the message
            final double a = (lower - mean) / sd;
            final double b = (upper - mean) / sd;
            throw new DistributionException(
                "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
        }

        // Here we have a meaningful truncation. Note that excess truncation may not be optimal.
        // For example truncation close to zero where the PDF is constant can be approximated
        // using a uniform distribution.

        return new TruncatedNormalDistribution(parent, z, lower, upper);
    }

    /**
     * Gets the mean for the parent distribution.
     *
     * <p>Note that the mean is of the parent normal distribution,
     * and not the true mean of the truncated normal distribution.
     * This is the {@code mean} parameter used to construct the truncated distribution.
     *
     * @return the parent mean.
     * @see #getMean
     * @since 1.3
     */
    public double getParentMean() {
        return parentNormal.getMean();
    }

    /**
     * Gets the standard deviation for the parent distribution.
     *
     * <p>Note that the standard deviation (SD) is of the parent normal distribution,
     * and not the true standard deviation of the truncated normal distribution.
     * This is the {@code sd} parameter used to construct the truncated distribution.
     *
     * @return the parent standard deviation.
     * @since 1.3
     */
    public double getParentStandardDeviation() {
        return parentNormal.getStandardDeviation();
    }

    /** {@inheritDoc} */
    @Override
    public double density(double x) {
        if (x < lower || x > upper) {
            return 0;
        }
        return parentNormal.density(x) / cdfDelta;
    }

    /** {@inheritDoc} */
    @Override
    public double probability(double x0, double x1) {
        if (x0 > x1) {
            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
                                            x0, x1);
        }
        return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
    }

    /** {@inheritDoc} */
    @Override
    public double logDensity(double x) {
        if (x < lower || x > upper) {
            return Double.NEGATIVE_INFINITY;
        }
        return parentNormal.logDensity(x) - logCdfDelta;
    }

    /** {@inheritDoc} */
    @Override
    public double cumulativeProbability(double x) {
        if (x <= lower) {
            return 0;
        } else if (x >= upper) {
            return 1;
        }
        return parentNormal.probability(lower, x) / cdfDelta;
    }

    /** {@inheritDoc} */
    @Override
    public double survivalProbability(double x) {
        if (x <= lower) {
            return 1;
        } else if (x >= upper) {
            return 0;
        }
        return parentNormal.probability(x, upper) / cdfDelta;
    }

    /** {@inheritDoc} */
    @Override
    public double inverseCumulativeProbability(double p) {
        ArgumentUtils.checkProbability(p);
        // Exact bound
        if (p == 0) {
            return lower;
        } else if (p == 1) {
            return upper;
        }
        // Linearly map p to the range [lower, upper]
        final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
        return clipToRange(x);
    }

    /** {@inheritDoc} */
    @Override
    public double inverseSurvivalProbability(double p) {
        ArgumentUtils.checkProbability(p);
        // Exact bound
        if (p == 1) {
            return lower;
        } else if (p == 0) {
            return upper;
        }
        // Linearly map p to the range [lower, upper]
        final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
        return clipToRange(x);
    }

    /** {@inheritDoc} */
    @Override
    public Sampler createSampler(UniformRandomProvider rng) {
        // Map the bounds to a standard normal distribution
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;
        // If the truncation covers a reasonable amount of the normal distribution
        // then a rejection sampler can be used.
        double threshold = REJECTION_THRESHOLD;
        // If the truncation is entirely in the upper or lower half then adjust the
        // threshold as twice the samples can be used
        if (a >= 0 || b <= 0) {
            threshold *= 0.5;
        }

        if (cdfDelta > threshold) {
            // Create the rejection sampler
            final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
            final DoubleSupplier gen;
            // Use mirroring if possible
            if (a >= 0) {
                // Return the upper-half of the Gaussian
                gen = () -> Math.abs(sampler.sample());
            } else if (b <= 0) {
                // Return the lower-half of the Gaussian
                gen = () -> -Math.abs(sampler.sample());
            } else {
                // Return the full range of the Gaussian
                gen = sampler::sample;
            }
            // Sample in [a, b] using rejection
            return () -> {
                double x = gen.getAsDouble();
                while (x < a || x > b) {
                    x = gen.getAsDouble();
                }
                // Avoid floating-point error when mapping back
                return clipToRange(u + x * s);
            };
        }

        // Default to an inverse CDF sampler
        return super.createSampler(rng);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Represents the true mean of the truncated normal distribution rather
     * than the parent normal distribution mean.
     *
     * <p>For \( \mu \) mean of the parent normal distribution,
     * \( \sigma \) standard deviation of the parent normal distribution, and
     * \( a \lt b \) the truncation interval of the parent normal distribution
     * standardised as \( \alpha = \frac{a - \mu}{\sigma} \) and
     * \( \beta = \frac{b - \mu}{\sigma} \), the mean is:
     *
     * <p>\[ \mu + \frac{\phi(\alpha)-\phi(\beta)}{\Phi(\beta) - \Phi(\alpha)}\sigma \]
     *
     * <p>where \( \phi \) is the probability density function of the standard normal distribution
     * and \( \Phi \) is its cumulative distribution function.
     */
    @Override
    public double getMean() {
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        // Standardised truncation bounds
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;
        return u + moment1(a, b) * s;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Represents the true variance of the truncated normal distribution rather
     * than the parent normal distribution variance.
     *
     * <p>For \( \mu \) mean of the parent normal distribution,
     * \( \sigma \) standard deviation of the parent normal distribution, and
     * \( a \lt b \) the truncation interval of the parent normal distribution
     * standardised as \( \alpha = \frac{a - \mu}{\sigma} \) and
     * \( \beta = \frac{b - \mu}{\sigma} \), the variance is:
     *
     * <p>\[ \sigma^2 \left[1 + \frac{\alpha\phi(\alpha)-\beta\phi(\beta)}{\Phi(\beta) - \Phi(\alpha)} -
     *       \left( \frac{\phi(\alpha)-\phi(\beta)}{\Phi(\beta) - \Phi(\alpha)} \right)^2 \right] \]
     *
     * <p>where \( \phi \) is the probability density function of the standard normal distribution
     * and \( \Phi \) is its cumulative distribution function.
     */
    @Override
    public double getVariance() {
        final double u = parentNormal.getMean();
        final double s = parentNormal.getStandardDeviation();
        // Standardised truncation bounds
        final double a = (lower - u) / s;
        final double b = (upper - u) / s;
        return variance(a, b) * s * s;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The lower bound of the support is equal to the lower bound parameter
     * of the distribution.
     */
    @Override
    public double getSupportLowerBound() {
        return lower;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The upper bound of the support is equal to the upper bound parameter
     * of the distribution.
     */
    @Override
    public double getSupportUpperBound() {
        return upper;
    }

    /**
     * Clip the value to the range [lower, upper].
     * This is used to handle floating-point error at the support bound.
     *
     * @param x Value x
     * @return x clipped to the range
     */
    private double clipToRange(double x) {
        return clip(x, lower, upper);
    }

    /**
     * Clip the value to the range [lower, upper].
     *
     * @param x Value x
     * @param lower Lower bound (inclusive)
     * @param upper Upper bound (inclusive)
     * @return x clipped to the range
     */
    private static double clip(double x, double lower, double upper) {
        if (x <= lower) {
            return lower;
        }
        return x < upper ? x : upper;
    }

    // Calculation of variance and mean can suffer from cancellation.
    //
    // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the
    // terms of the MIT "Expat" License (see NOTICE and LICENSE).
    //
    // These formulas use the complementary error function
    //   erfcx(z) = erfc(z) * exp(z^2)
    // This avoids computation of exp terms for the Gaussian PDF and then
    // dividing by the error functions erf or erfc:
    //   exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2))
    // At large z the erfcx function is computable but exp(-0.5*z*z) and
    // erfc(z) are zero. Use of these formulas allows computation of the
    // mean and variance for the usable range of the truncated distribution
    // (cdf(a, b) != 0). The variance is not accurate when it approaches
    // machine epsilon (2^-52) at extremely narrow truncations and the
    // computation -> 0.
    //
    // See: https://github.com/cossio/TruncatedNormal.jl

    /**
     * Compute the first moment (mean) of the truncated standard normal distribution.
     *
     * <p>Assumes {@code a <= b}.
     *
     * @param a Lower bound
     * @param b Upper bound
     * @return the first moment
     */
    static double moment1(double a, double b) {
        // Assume a <= b
        if (a == b) {
            return a;
        }
        if (Math.abs(a) > Math.abs(b)) {
            // Subtract from zero to avoid generating -0.0
            return 0 - moment1(-b, -a);
        }

        // Here:
        // |a| <= |b|
        // a < b
        // 0 < b

        if (a <= -MAX_X) {
            // No truncation
            return 0;
        }
        if (b >= MAX_X) {
            // One-sided truncation
            return ROOT_2_PI / Erfcx.value(a / ROOT2);
        }

        // pdf = exp(-0.5*x*x) / sqrt(2*pi)
        // cdf = erfc(-x/sqrt(2)) / 2
        // Compute:
        // -(pdf(b) - pdf(a)) / cdf(b, a)
        // Note:
        // exp(-0.5*b*b) - exp(-0.5*a*a)
        // Use cancellation of powers:
        // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a)
        // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a)

        // dx = 0.5*(b*b-a*a); the exponent used below is -dx.
        // Computed as a product of sum and difference to limit cancellation.
        final double dx = 0.5 * (b + a) * (b - a);
        final double m;
        if (a <= 0) {
            // Opposite signs
            m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
        } else {
            final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
            if (z == 0) {
                // Occurs when a and b have large magnitudes and are very close
                return (a + b) * 0.5;
            }
            m = ROOT_2_PI * Math.expm1(-dx) / z;
        }

        // Clip to the range
        return clip(m, a, b);
    }

    /**
     * Compute the second moment of the truncated standard normal distribution.
     *
     * <p>Assumes {@code a <= b}.
     *
     * @param a Lower bound
     * @param b Upper bound
     * @return the second moment
     */
    private static double moment2(double a, double b) {
        // Assume a < b.
        // a == b is handled in the variance method
        if (Math.abs(a) > Math.abs(b)) {
            return moment2(-b, -a);
        }

        // Here:
        // |a| <= |b|
        // a < b
        // 0 < b

        if (a <= -MAX_X) {
            // No truncation
            return 1;
        }
        if (b >= MAX_X) {
            // One-sided truncation.
            // For a -> inf : moment2 -> a*a
            // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms
            // cancel. z > 6.71e7, a > 9.49e7
            return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
        }

        // pdf = exp(-0.5*x*x) / sqrt(2*pi)
        // cdf = erfc(-x/sqrt(2)) / 2
        // Compute:
        // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a)
        // = (cdf(b, a) - b*pdf(b) + a*pdf(a)) / cdf(b, a)

        // Note:
        // For z -> 0:
        //   sqrt(pi / 2) * erf(z / sqrt(2)) -> z
        //   z * Math.exp(-0.5 * z * z) -> z
        // Both computations below have cancellation as b -> 0 and the
        // second moment is not computable as the fraction P/Q
        // since P < ulp(Q). This always occurs when b < MIN_X
        // if MIN_X is set at the point where
        // exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi).
        // This is JDK dependent due to variations in Math.exp.
        // For b < MIN_X the second moment can be approximated using
        // a uniform distribution: (b^3 - a^3) / (3b - 3a).
        // In practice it also occurs when b > MIN_X since any a < MIN_X
        // is effectively zero for part of the computation. A
        // threshold to transition to a uniform distribution
        // approximation is a compromise. Also note it will not
        // correct computation when (b-a) is small and is far from 0.
        // Thus the second moment is left to be inaccurate for
        // small ranges (b-a) and the variance -> 0 when the true
        // variance is close to or below machine epsilon.

        double m;

        if (a <= 0) {
            // Opposite signs
            final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
            final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
            final double fa = ea - a * Math.exp(-0.5 * a * a);
            final double fb = eb - b * Math.exp(-0.5 * b * b);
            // Assume fb >= fa && eb >= ea
            // If fb <= fa this is a tiny range around 0
            m = (fb - fa) / (eb - ea);
            // Clip to the range
            m = clip(m, 0, 1);
        } else {
            final double dx = 0.5 * (b + a) * (b - a);
            final double ex = Math.exp(-dx);
            final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
            final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
            final double fa = ea + a;
            final double fb = eb + b;
            m = (fa - fb * ex) / (ea - eb * ex);
            // Clip to the range
            m = clip(m, a * a, b * b);
        }
        return m;
    }

    /**
     * Compute the variance of the truncated standard normal distribution.
     *
     * <p>Assumes {@code a <= b}.
     *
     * @param a Lower bound
     * @param b Upper bound
     * @return the variance
     */
    static double variance(double a, double b) {
        if (a == b) {
            return 0;
        }

        final double m1 = moment1(a, b);
        double m2 = moment2(a, b);
        // variance = m2 - m1*m1
        // rearrange x^2 - y^2 as (x-y)(x+y)
        m2 = Math.sqrt(m2);
        final double variance = (m2 - m1) * (m2 + m1);

        // Detect floating-point error.
        if (variance >= 1) {
            // Note:
            // Extreme truncations in the tails can compute a variance above 1,
            // for example if m2 is infinite: m2 - m1*m1 > 1
            // Detect no truncation as the terms a and b lie far either side of zero;
            // otherwise return 0 to indicate very small unknown variance.
            return a < -1 && b > 1 ? 1 : 0;
        } else if (variance <= 0) {
            // Floating-point error can create negative variance so return 0.
            return 0;
        }

        return variance;
    }
}