forked from Rust-related/RustPython
Merge pull request #3552 from fanninpm/statistics-3.10
Update statistics.py to CPython 3.10
This commit is contained in:
415
Lib/statistics.py
vendored
415
Lib/statistics.py
vendored
@@ -73,6 +73,30 @@ second argument to the four "spread" functions to avoid recalculating it:
|
||||
2.5
|
||||
|
||||
|
||||
Statistics for relations between two inputs
|
||||
-------------------------------------------
|
||||
|
||||
================== ====================================================
|
||||
Function Description
|
||||
================== ====================================================
|
||||
covariance Sample covariance for two variables.
|
||||
correlation Pearson's correlation coefficient for two variables.
|
||||
linear_regression Intercept and slope for simple linear regression.
|
||||
================== ====================================================
|
||||
|
||||
Calculate covariance, Pearson's correlation, and simple linear regression
|
||||
for two inputs:
|
||||
|
||||
>>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
>>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
|
||||
>>> covariance(x, y)
|
||||
0.75
|
||||
>>> correlation(x, y) #doctest: +ELLIPSIS
|
||||
0.31622776601...
|
||||
>>> linear_regression(x, y)
|
||||
LinearRegression(slope=0.1, intercept=1.5)
|
||||
|
||||
|
||||
Exceptions
|
||||
----------
|
||||
|
||||
@@ -83,9 +107,12 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
|
||||
__all__ = [
|
||||
'NormalDist',
|
||||
'StatisticsError',
|
||||
'correlation',
|
||||
'covariance',
|
||||
'fmean',
|
||||
'geometric_mean',
|
||||
'harmonic_mean',
|
||||
'linear_regression',
|
||||
'mean',
|
||||
'median',
|
||||
'median_grouped',
|
||||
@@ -106,11 +133,11 @@ import random
|
||||
|
||||
from fractions import Fraction
|
||||
from decimal import Decimal
|
||||
from itertools import groupby
|
||||
from itertools import groupby, repeat
|
||||
from bisect import bisect_left, bisect_right
|
||||
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
|
||||
from operator import itemgetter
|
||||
from collections import Counter
|
||||
from collections import Counter, namedtuple
|
||||
|
||||
# === Exceptions ===
|
||||
|
||||
@@ -120,21 +147,17 @@ class StatisticsError(ValueError):
|
||||
|
||||
# === Private utilities ===
|
||||
|
||||
def _sum(data, start=0):
|
||||
"""_sum(data [, start]) -> (type, sum, count)
|
||||
def _sum(data):
|
||||
"""_sum(data) -> (type, sum, count)
|
||||
|
||||
Return a high-precision sum of the given numeric data as a fraction,
|
||||
together with the type to be converted to and the count of items.
|
||||
|
||||
If optional argument ``start`` is given, it is added to the total.
|
||||
If ``data`` is empty, ``start`` (defaulting to 0) is returned.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
|
||||
(<class 'float'>, Fraction(11, 1), 5)
|
||||
>>> _sum([3, 2.25, 4.5, -0.5, 0.25])
|
||||
(<class 'float'>, Fraction(19, 2), 5)
|
||||
|
||||
Some sources of round-off error will be avoided:
|
||||
|
||||
@@ -157,13 +180,12 @@ def _sum(data, start=0):
|
||||
allowed.
|
||||
"""
|
||||
count = 0
|
||||
n, d = _exact_ratio(start)
|
||||
partials = {d: n}
|
||||
partials = {}
|
||||
partials_get = partials.get
|
||||
T = _coerce(int, type(start))
|
||||
T = int
|
||||
for typ, values in groupby(data, type):
|
||||
T = _coerce(T, typ) # or raise TypeError
|
||||
for n,d in map(_exact_ratio, values):
|
||||
for n, d in map(_exact_ratio, values):
|
||||
count += 1
|
||||
partials[d] = partials_get(d, 0) + n
|
||||
if None in partials:
|
||||
@@ -173,8 +195,7 @@ def _sum(data, start=0):
|
||||
assert not _isfinite(total)
|
||||
else:
|
||||
# Sum all the partial sums using builtin sum.
|
||||
# FIXME is this faster if we sum them in order of the denominator?
|
||||
total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
|
||||
total = sum(Fraction(n, d) for d, n in partials.items())
|
||||
return (T, total, count)
|
||||
|
||||
|
||||
@@ -225,27 +246,19 @@ def _exact_ratio(x):
|
||||
x is expected to be an int, Fraction, Decimal or float.
|
||||
"""
|
||||
try:
|
||||
# Optimise the common case of floats. We expect that the most often
|
||||
# used numeric type will be builtin floats, so try to make this as
|
||||
# fast as possible.
|
||||
if type(x) is float or type(x) is Decimal:
|
||||
return x.as_integer_ratio()
|
||||
try:
|
||||
# x may be an int, Fraction, or Integral ABC.
|
||||
return (x.numerator, x.denominator)
|
||||
except AttributeError:
|
||||
try:
|
||||
# x may be a float or Decimal subclass.
|
||||
return x.as_integer_ratio()
|
||||
except AttributeError:
|
||||
# Just give up?
|
||||
pass
|
||||
return x.as_integer_ratio()
|
||||
except AttributeError:
|
||||
pass
|
||||
except (OverflowError, ValueError):
|
||||
# float NAN or INF.
|
||||
assert not _isfinite(x)
|
||||
return (x, None)
|
||||
msg = "can't convert type '{}' to numerator/denominator"
|
||||
raise TypeError(msg.format(type(x).__name__))
|
||||
try:
|
||||
# x may be an Integral ABC.
|
||||
return (x.numerator, x.denominator)
|
||||
except AttributeError:
|
||||
msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def _convert(value, T):
|
||||
@@ -261,7 +274,7 @@ def _convert(value, T):
|
||||
return T(value)
|
||||
except TypeError:
|
||||
if issubclass(T, Decimal):
|
||||
return T(value.numerator)/T(value.denominator)
|
||||
return T(value.numerator) / T(value.denominator)
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -277,8 +290,8 @@ def _find_lteq(a, x):
|
||||
def _find_rteq(a, l, x):
|
||||
'Locate the rightmost value exactly equal to x'
|
||||
i = bisect_right(a, x, lo=l)
|
||||
if i != (len(a)+1) and a[i-1] == x:
|
||||
return i-1
|
||||
if i != (len(a) + 1) and a[i - 1] == x:
|
||||
return i - 1
|
||||
raise ValueError
|
||||
|
||||
|
||||
@@ -315,7 +328,7 @@ def mean(data):
|
||||
raise StatisticsError('mean requires at least one data point')
|
||||
T, total, count = _sum(data)
|
||||
assert count == n
|
||||
return _convert(total/n, T)
|
||||
return _convert(total / n, T)
|
||||
|
||||
|
||||
def fmean(data):
|
||||
@@ -361,40 +374,39 @@ def geometric_mean(data):
|
||||
return exp(fmean(map(log, data)))
|
||||
except ValueError:
|
||||
raise StatisticsError('geometric mean requires a non-empty dataset '
|
||||
' containing positive numbers') from None
|
||||
'containing positive numbers') from None
|
||||
|
||||
|
||||
def harmonic_mean(data):
|
||||
def harmonic_mean(data, weights=None):
|
||||
"""Return the harmonic mean of data.
|
||||
|
||||
The harmonic mean, sometimes called the subcontrary mean, is the
|
||||
reciprocal of the arithmetic mean of the reciprocals of the data,
|
||||
and is often appropriate when averaging quantities which are rates
|
||||
or ratios, for example speeds. Example:
|
||||
The harmonic mean is the reciprocal of the arithmetic mean of the
|
||||
reciprocals of the data. It can be used for averaging ratios or
|
||||
rates, for example speeds.
|
||||
|
||||
Suppose an investor purchases an equal value of shares in each of
|
||||
three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
|
||||
What is the average P/E ratio for the investor's portfolio?
|
||||
Suppose a car travels 40 km/hr for 5 km and then speeds-up to
|
||||
60 km/hr for another 5 km. What is the average speed?
|
||||
|
||||
>>> harmonic_mean([2.5, 3, 10]) # For an equal investment portfolio.
|
||||
3.6
|
||||
>>> harmonic_mean([40, 60])
|
||||
48.0
|
||||
|
||||
Using the arithmetic mean would give an average of about 5.167, which
|
||||
is too high.
|
||||
Suppose a car travels 40 km/hr for 5 km, and when traffic clears,
|
||||
speeds-up to 60 km/hr for the remaining 30 km of the journey. What
|
||||
is the average speed?
|
||||
|
||||
>>> harmonic_mean([40, 60], weights=[5, 30])
|
||||
56.0
|
||||
|
||||
If ``data`` is empty, or any element is less than zero,
|
||||
``harmonic_mean`` will raise ``StatisticsError``.
|
||||
"""
|
||||
# For a justification for using harmonic mean for P/E ratios, see
|
||||
# http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
|
||||
# http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
|
||||
if iter(data) is data:
|
||||
data = list(data)
|
||||
errmsg = 'harmonic mean does not support negative values'
|
||||
n = len(data)
|
||||
if n < 1:
|
||||
raise StatisticsError('harmonic_mean requires at least one data point')
|
||||
elif n == 1:
|
||||
elif n == 1 and weights is None:
|
||||
x = data[0]
|
||||
if isinstance(x, (numbers.Real, Decimal)):
|
||||
if x < 0:
|
||||
@@ -402,13 +414,23 @@ def harmonic_mean(data):
|
||||
return x
|
||||
else:
|
||||
raise TypeError('unsupported type')
|
||||
if weights is None:
|
||||
weights = repeat(1, n)
|
||||
sum_weights = n
|
||||
else:
|
||||
if iter(weights) is weights:
|
||||
weights = list(weights)
|
||||
if len(weights) != n:
|
||||
raise StatisticsError('Number of weights does not match data size')
|
||||
_, sum_weights, _ = _sum(w for w in _fail_neg(weights, errmsg))
|
||||
try:
|
||||
T, total, count = _sum(1/x for x in _fail_neg(data, errmsg))
|
||||
data = _fail_neg(data, errmsg)
|
||||
T, total, count = _sum(w / x if w else 0 for w, x in zip(weights, data))
|
||||
except ZeroDivisionError:
|
||||
return 0
|
||||
assert count == n
|
||||
return _convert(n/total, T)
|
||||
|
||||
if total <= 0:
|
||||
raise StatisticsError('Weighted sum must be positive')
|
||||
return _convert(sum_weights / total, T)
|
||||
|
||||
# FIXME: investigate ways to calculate medians without sorting? Quickselect?
|
||||
def median(data):
|
||||
@@ -428,11 +450,11 @@ def median(data):
|
||||
n = len(data)
|
||||
if n == 0:
|
||||
raise StatisticsError("no median for empty data")
|
||||
if n%2 == 1:
|
||||
return data[n//2]
|
||||
if n % 2 == 1:
|
||||
return data[n // 2]
|
||||
else:
|
||||
i = n//2
|
||||
return (data[i - 1] + data[i])/2
|
||||
i = n // 2
|
||||
return (data[i - 1] + data[i]) / 2
|
||||
|
||||
|
||||
def median_low(data):
|
||||
@@ -451,10 +473,10 @@ def median_low(data):
|
||||
n = len(data)
|
||||
if n == 0:
|
||||
raise StatisticsError("no median for empty data")
|
||||
if n%2 == 1:
|
||||
return data[n//2]
|
||||
if n % 2 == 1:
|
||||
return data[n // 2]
|
||||
else:
|
||||
return data[n//2 - 1]
|
||||
return data[n // 2 - 1]
|
||||
|
||||
|
||||
def median_high(data):
|
||||
@@ -473,7 +495,7 @@ def median_high(data):
|
||||
n = len(data)
|
||||
if n == 0:
|
||||
raise StatisticsError("no median for empty data")
|
||||
return data[n//2]
|
||||
return data[n // 2]
|
||||
|
||||
|
||||
def median_grouped(data, interval=1):
|
||||
@@ -510,15 +532,15 @@ def median_grouped(data, interval=1):
|
||||
return data[0]
|
||||
# Find the value at the midpoint. Remember this corresponds to the
|
||||
# centre of the class interval.
|
||||
x = data[n//2]
|
||||
x = data[n // 2]
|
||||
for obj in (x, interval):
|
||||
if isinstance(obj, (str, bytes)):
|
||||
raise TypeError('expected number but got %r' % obj)
|
||||
try:
|
||||
L = x - interval/2 # The lower limit of the median interval.
|
||||
L = x - interval / 2 # The lower limit of the median interval.
|
||||
except TypeError:
|
||||
# Mixed type. For now we just coerce to float.
|
||||
L = float(x) - float(interval)/2
|
||||
L = float(x) - float(interval) / 2
|
||||
|
||||
# Uses bisection search to search for x in data with log(n) time complexity
|
||||
# Find the position of leftmost occurrence of x in data
|
||||
@@ -528,7 +550,7 @@ def median_grouped(data, interval=1):
|
||||
l2 = _find_rteq(data, l1, x)
|
||||
cf = l1
|
||||
f = l2 - l1 + 1
|
||||
return L + interval*(n/2 - cf)/f
|
||||
return L + interval * (n / 2 - cf) / f
|
||||
|
||||
|
||||
def mode(data):
|
||||
@@ -554,8 +576,7 @@ def mode(data):
|
||||
If *data* is empty, ``mode`` raises StatisticsError.
|
||||
|
||||
"""
|
||||
data = iter(data)
|
||||
pairs = Counter(data).most_common(1)
|
||||
pairs = Counter(iter(data)).most_common(1)
|
||||
try:
|
||||
return pairs[0][0]
|
||||
except IndexError:
|
||||
@@ -597,7 +618,7 @@ def multimode(data):
|
||||
# For sample data where there is a positive probability for values
|
||||
# beyond the range of the data, the R6 exclusive method is a
|
||||
# reasonable choice. Consider a random sample of nine values from a
|
||||
# population with a uniform distribution from 0.0 to 100.0. The
|
||||
# population with a uniform distribution from 0.0 to 1.0. The
|
||||
# distribution of the third ranked sample point is described by
|
||||
# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
|
||||
# mean=0.300. Only the latter (which corresponds with R6) gives the
|
||||
@@ -643,9 +664,8 @@ def quantiles(data, *, n=4, method='exclusive'):
|
||||
m = ld - 1
|
||||
result = []
|
||||
for i in range(1, n):
|
||||
j = i * m // n
|
||||
delta = i*m - j*n
|
||||
interpolated = (data[j] * (n - delta) + data[j+1] * delta) / n
|
||||
j, delta = divmod(i * m, n)
|
||||
interpolated = (data[j] * (n - delta) + data[j + 1] * delta) / n
|
||||
result.append(interpolated)
|
||||
return result
|
||||
if method == 'exclusive':
|
||||
@@ -655,7 +675,7 @@ def quantiles(data, *, n=4, method='exclusive'):
|
||||
j = i * m // n # rescale i to m/n
|
||||
j = 1 if j < 1 else ld-1 if j > ld-1 else j # clamp to 1 .. ld-1
|
||||
delta = i*m - j*n # exact integer math
|
||||
interpolated = (data[j-1] * (n - delta) + data[j] * delta) / n
|
||||
interpolated = (data[j - 1] * (n - delta) + data[j] * delta) / n
|
||||
result.append(interpolated)
|
||||
return result
|
||||
raise ValueError(f'Unknown method: {method!r}')
|
||||
@@ -685,14 +705,20 @@ def _ss(data, c=None):
|
||||
if c is not None:
|
||||
T, total, count = _sum((x-c)**2 for x in data)
|
||||
return (T, total)
|
||||
c = mean(data)
|
||||
T, total, count = _sum((x-c)**2 for x in data)
|
||||
# The following sum should mathematically equal zero, but due to rounding
|
||||
# error may not.
|
||||
U, total2, count2 = _sum((x-c) for x in data)
|
||||
assert T == U and count == count2
|
||||
total -= total2**2/len(data)
|
||||
assert not total < 0, 'negative sum of square deviations: %f' % total
|
||||
T, total, count = _sum(data)
|
||||
mean_n, mean_d = (total / count).as_integer_ratio()
|
||||
partials = Counter()
|
||||
for n, d in map(_exact_ratio, data):
|
||||
diff_n = n * mean_d - d * mean_n
|
||||
diff_d = d * mean_d
|
||||
partials[diff_d * diff_d] += diff_n * diff_n
|
||||
if None in partials:
|
||||
# The sum will be a NAN or INF. We can ignore all the finite
|
||||
# partials, and just look at this special one.
|
||||
total = partials[None]
|
||||
assert not _isfinite(total)
|
||||
else:
|
||||
total = sum(Fraction(n, d) for d, n in partials.items())
|
||||
return (T, total)
|
||||
|
||||
|
||||
@@ -740,7 +766,7 @@ def variance(data, xbar=None):
|
||||
if n < 2:
|
||||
raise StatisticsError('variance requires at least two data points')
|
||||
T, ss = _ss(data, xbar)
|
||||
return _convert(ss/(n-1), T)
|
||||
return _convert(ss / (n - 1), T)
|
||||
|
||||
|
||||
def pvariance(data, mu=None):
|
||||
@@ -784,7 +810,7 @@ def pvariance(data, mu=None):
|
||||
if n < 1:
|
||||
raise StatisticsError('pvariance requires at least one data point')
|
||||
T, ss = _ss(data, mu)
|
||||
return _convert(ss/n, T)
|
||||
return _convert(ss / n, T)
|
||||
|
||||
|
||||
def stdev(data, xbar=None):
|
||||
@@ -796,6 +822,9 @@ def stdev(data, xbar=None):
|
||||
1.0810874155219827
|
||||
|
||||
"""
|
||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
||||
# remains because there are two rounding steps. The first occurs in
|
||||
# the _convert() step for variance(), the second occurs in math.sqrt().
|
||||
var = variance(data, xbar)
|
||||
try:
|
||||
return var.sqrt()
|
||||
@@ -812,6 +841,9 @@ def pstdev(data, mu=None):
|
||||
0.986893273527251
|
||||
|
||||
"""
|
||||
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
|
||||
# remains because there are two rounding steps. The first occurs in
|
||||
# the _convert() step for pvariance(), the second occurs in math.sqrt().
|
||||
var = pvariance(data, mu)
|
||||
try:
|
||||
return var.sqrt()
|
||||
@@ -819,6 +851,119 @@ def pstdev(data, mu=None):
|
||||
return math.sqrt(var)
|
||||
|
||||
|
||||
# === Statistics for relations between two inputs ===
|
||||
|
||||
# See https://en.wikipedia.org/wiki/Covariance
|
||||
# https://en.wikipedia.org/wiki/Pearson_correlation_coefficient
|
||||
# https://en.wikipedia.org/wiki/Simple_linear_regression
|
||||
|
||||
|
||||
def covariance(x, y, /):
    """Covariance

    Return the sample covariance of two inputs *x* and *y*.  Covariance
    is a measure of the joint variability of two inputs.

    Both inputs must support len() and hold the same number of data
    points (at least two); otherwise StatisticsError is raised.  The
    arithmetic is carried out with math.fsum, so the result is a float.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
    >>> covariance(x, y)
    0.75
    >>> z = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> covariance(x, z)
    -7.5
    >>> covariance(z, x)
    -7.5

    """
    size = len(x)
    if size != len(y):
        raise StatisticsError('covariance requires that both inputs have same number of data points')
    if size < 2:
        raise StatisticsError('covariance requires at least two data points')
    mean_x = fsum(x) / size
    mean_y = fsum(y) / size
    # Sum of the products of paired deviations from the respective means.
    cross_deviations = fsum(
        (a - mean_x) * (b - mean_y) for a, b in zip(x, y)
    )
    # Divide by size - 1 for the sample (not population) covariance.
    return cross_deviations / (size - 1)
|
||||
|
||||
|
||||
def correlation(x, y, /):
    """Pearson's correlation coefficient

    Return Pearson's correlation coefficient *r* for two inputs.  The
    coefficient takes values between -1 and +1.  It measures the strength
    and direction of the linear relationship, where +1 means a very
    strong, positive linear relationship, -1 a very strong, negative
    linear relationship, and 0 no linear relationship.

    Both inputs must support len() and hold the same number of data
    points (at least two).  StatisticsError is raised when the lengths
    differ, when there are fewer than two points, or when at least one
    of the inputs is constant.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> correlation(x, x)
    1.0
    >>> correlation(x, y)
    -1.0

    """
    from sys import float_info

    n = len(x)
    if len(y) != n:
        raise StatisticsError('correlation requires that both inputs have same number of data points')
    if n < 2:
        raise StatisticsError('correlation requires at least two data points')
    xbar = fsum(x) / n
    ybar = fsum(y) / n
    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    sxx = fsum((xi - xbar) ** 2.0 for xi in x)
    syy = fsum((yi - ybar) ** 2.0 for yi in y)
    # sqrt(sxx * syy) is kept for ordinary magnitudes so existing results
    # are unchanged.  The product, however, can underflow (to zero or a
    # subnormal) or overflow to infinity even though sxx and syy are both
    # representable; that used to produce a bogus "constant input" error
    # or a silent 0.0.  In those cases take the square roots first.
    product = sxx * syy
    if float_info.min <= product < float('inf'):
        denominator = sqrt(product)
    else:
        denominator = sqrt(sxx) * sqrt(syy)
    try:
        return sxy / denominator
    except ZeroDivisionError:
        # The ZeroDivisionError is an implementation detail; hide it from
        # the traceback.
        raise StatisticsError('at least one of the inputs is constant') from None
|
||||
|
||||
|
||||
LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept'))


def linear_regression(x, y, /):
    """Slope and intercept for simple linear regression.

    Return the slope and intercept of simple linear regression
    parameters estimated using ordinary least squares.  Simple linear
    regression describes the relationship between an independent variable
    *x* and a dependent variable *y* in terms of a linear function:

        y = slope * x + intercept + noise

    where *slope* and *intercept* are the regression parameters that are
    estimated, and noise represents the variability of the data that was
    not explained by the linear regression (it is equal to the
    difference between predicted and actual values of the dependent
    variable).

    The parameters are returned as a named tuple.

    Both inputs must support len() and hold the same number of data
    points (at least two).  StatisticsError is raised when the lengths
    differ, when there are fewer than two points, or when *x* is
    constant.

    >>> x = [1, 2, 3, 4, 5]
    >>> noise = NormalDist().samples(5, seed=42)
    >>> y = [3 * x[i] + 2 + noise[i] for i in range(5)]
    >>> linear_regression(x, y)  #doctest: +ELLIPSIS
    LinearRegression(slope=3.09078914170..., intercept=1.75684970486...)

    """
    n = len(x)
    if len(y) != n:
        raise StatisticsError('linear regression requires that both inputs have same number of data points')
    if n < 2:
        raise StatisticsError('linear regression requires at least two data points')
    xbar = fsum(x) / n
    ybar = fsum(y) / n
    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    sxx = fsum((xi - xbar) ** 2.0 for xi in x)
    try:
        slope = sxy / sxx   # equivalent to:  covariance(x, y) / variance(x)
    except ZeroDivisionError:
        # The ZeroDivisionError is an implementation detail; hide it from
        # the traceback.
        raise StatisticsError('x is constant') from None
    # The fitted line passes through the point of means (xbar, ybar).
    intercept = ybar - slope * xbar
    return LinearRegression(slope=slope, intercept=intercept)
|
||||
|
||||
|
||||
## Normal Distribution #####################################################
|
||||
|
||||
|
||||
@@ -896,6 +1041,13 @@ def _normal_dist_inv_cdf(p, mu, sigma):
|
||||
return mu + (x * sigma)
|
||||
|
||||
|
||||
# If available, use C implementation
|
||||
try:
|
||||
from _statistics import _normal_dist_inv_cdf
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class NormalDist:
|
||||
"Normal distribution of a random variable"
|
||||
# https://en.wikipedia.org/wiki/Normal_distribution
|
||||
@@ -986,7 +1138,7 @@ class NormalDist:
|
||||
if not isinstance(other, NormalDist):
|
||||
raise TypeError('Expected another NormalDist instance')
|
||||
X, Y = self, other
|
||||
if (Y._sigma, Y._mu) < (X._sigma, X._mu): # sort to assure commutativity
|
||||
if (Y._sigma, Y._mu) < (X._sigma, X._mu): # sort to assure commutativity
|
||||
X, Y = Y, X
|
||||
X_var, Y_var = X.variance, Y.variance
|
||||
if not X_var or not Y_var:
|
||||
@@ -1001,6 +1153,17 @@ class NormalDist:
|
||||
x2 = (a - b) / dv
|
||||
return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))
|
||||
|
||||
def zscore(self, x):
    """Compute the Standard Score.  (x - mean) / stdev

    Expresses *x* as the number of standard deviations it lies above
    (positive) or below (negative) the mean of this distribution.
    Raises StatisticsError when sigma is zero.
    """
    # https://www.statisticshowto.com/probability-and-statistics/z-score/
    sigma = self._sigma
    if sigma:
        return (x - self._mu) / sigma
    raise StatisticsError('zscore() not defined when sigma is zero')
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
"Arithmetic mean of the normal distribution."
|
||||
@@ -1102,79 +1265,3 @@ class NormalDist:
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'
|
||||
|
||||
# If available, use C implementation
|
||||
try:
|
||||
from _statistics import _normal_dist_inv_cdf
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Show math operations computed analytically in comparison
|
||||
# to a monte carlo simulation of the same operations
|
||||
|
||||
from math import isclose
|
||||
from operator import add, sub, mul, truediv
|
||||
from itertools import repeat
|
||||
import doctest
|
||||
|
||||
g1 = NormalDist(10, 20)
|
||||
g2 = NormalDist(-5, 25)
|
||||
|
||||
# Test scaling by a constant
|
||||
assert (g1 * 5 / 5).mean == g1.mean
|
||||
assert (g1 * 5 / 5).stdev == g1.stdev
|
||||
|
||||
n = 100_000
|
||||
G1 = g1.samples(n)
|
||||
G2 = g2.samples(n)
|
||||
|
||||
for func in (add, sub):
|
||||
print(f'\nTest {func.__name__} with another NormalDist:')
|
||||
print(func(g1, g2))
|
||||
print(NormalDist.from_samples(map(func, G1, G2)))
|
||||
|
||||
const = 11
|
||||
for func in (add, sub, mul, truediv):
|
||||
print(f'\nTest {func.__name__} with a constant:')
|
||||
print(func(g1, const))
|
||||
print(NormalDist.from_samples(map(func, G1, repeat(const))))
|
||||
|
||||
const = 19
|
||||
for func in (add, sub, mul):
|
||||
print(f'\nTest constant with {func.__name__}:')
|
||||
print(func(const, g1))
|
||||
print(NormalDist.from_samples(map(func, repeat(const), G1)))
|
||||
|
||||
def assert_close(G1, G2):
    # Check that two distributions agree on both mean and stdev to within
    # a relative tolerance of 1%.  The mean check must compare G1 against
    # G2; comparing G1.mean with itself could never fail.
    assert isclose(G1.mean, G2.mean, rel_tol=0.01), (G1, G2)
    assert isclose(G1.stdev, G2.stdev, rel_tol=0.01), (G1, G2)
|
||||
|
||||
X = NormalDist(-105, 73)
|
||||
Y = NormalDist(31, 47)
|
||||
s = 32.75
|
||||
n = 100_000
|
||||
|
||||
S = NormalDist.from_samples([x + s for x in X.samples(n)])
|
||||
assert_close(X + s, S)
|
||||
|
||||
S = NormalDist.from_samples([x - s for x in X.samples(n)])
|
||||
assert_close(X - s, S)
|
||||
|
||||
S = NormalDist.from_samples([x * s for x in X.samples(n)])
|
||||
assert_close(X * s, S)
|
||||
|
||||
S = NormalDist.from_samples([x / s for x in X.samples(n)])
|
||||
assert_close(X / s, S)
|
||||
|
||||
S = NormalDist.from_samples([x + y for x, y in zip(X.samples(n),
|
||||
Y.samples(n))])
|
||||
assert_close(X + Y, S)
|
||||
|
||||
S = NormalDist.from_samples([x - y for x, y in zip(X.samples(n),
|
||||
Y.samples(n))])
|
||||
assert_close(X - Y, S)
|
||||
|
||||
print(doctest.testmod())
|
||||
|
||||
231
Lib/test/test_statistics.py
vendored
231
Lib/test/test_statistics.py
vendored
@@ -14,11 +14,11 @@ import pickle
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from test import support
|
||||
from test.support import import_helper
|
||||
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from test import support
|
||||
from test.support import import_helper
|
||||
|
||||
|
||||
# Module to be tested.
|
||||
@@ -179,8 +179,10 @@ class _DoNothing:
|
||||
# We prefer this for testing numeric values that may not be exactly equal,
|
||||
# and avoid using TestCase.assertAlmostEqual, because it sucks :-)
|
||||
|
||||
py_statistics = import_helper.import_fresh_module('statistics', blocked=['_statistics'])
|
||||
c_statistics = import_helper.import_fresh_module('statistics', fresh=['_statistics'])
|
||||
py_statistics = import_helper.import_fresh_module('statistics',
|
||||
blocked=['_statistics'])
|
||||
c_statistics = import_helper.import_fresh_module('statistics',
|
||||
fresh=['_statistics'])
|
||||
|
||||
|
||||
class TestModules(unittest.TestCase):
|
||||
@@ -1006,6 +1008,10 @@ class ConvertTest(unittest.TestCase):
|
||||
x = statistics._convert(nan, type(nan))
|
||||
self.assertTrue(_nan_equal(x, nan))
|
||||
|
||||
def test_invalid_input_type(self):
|
||||
with self.assertRaises(TypeError):
|
||||
statistics._convert(None, float)
|
||||
|
||||
|
||||
class FailNegTest(unittest.TestCase):
|
||||
"""Test _fail_neg private function."""
|
||||
@@ -1035,6 +1041,50 @@ class FailNegTest(unittest.TestCase):
|
||||
self.assertEqual(errmsg, msg)
|
||||
|
||||
|
||||
class FindLteqTest(unittest.TestCase):
    # Test _find_lteq private function.

    def test_invalid_input_values(self):
        # Each (sequence, target) pair has no element equal to the target,
        # so the lookup is expected to raise ValueError.
        rejected = (
            ([], 1),
            ([1, 2], 3),
            ([1, 3], 2),
        )
        for a, x in rejected:
            with self.subTest(a=a, x=x), self.assertRaises(ValueError):
                statistics._find_lteq(a, x)

    def test_locate_successfully(self):
        # In every case the expected index is that of the first (leftmost)
        # occurrence of the target in the sorted sequence.
        located = (
            ([1, 1, 1, 2, 3], 1, 0),
            ([0, 1, 1, 1, 2, 3], 1, 1),
            ([1, 2, 3, 3, 3], 3, 2),
        )
        for a, x, expected_i in located:
            with self.subTest(a=a, x=x):
                found = statistics._find_lteq(a, x)
                self.assertEqual(expected_i, found)
|
||||
|
||||
|
||||
class FindRteqTest(unittest.TestCase):
    # Test _find_rteq private function.

    def test_invalid_input_values(self):
        # Each (sequence, lower bound, target) triple is expected to raise
        # ValueError.  Every case runs in its own subTest so that one
        # failure neither hides the others nor loses its parameters,
        # matching the other parameterised tests in this file.
        for a, l, x in [
            ([1], 2, 1),
            ([1, 3], 0, 2)
        ]:
            with self.subTest(a=a, l=l, x=x):
                with self.assertRaises(ValueError):
                    statistics._find_rteq(a, l, x)

    def test_locate_successfully(self):
        # In every case the expected index is that of the last (rightmost)
        # occurrence of the target in the sorted sequence.
        for a, l, x, expected_i in [
            ([1, 1, 1, 2, 3], 0, 1, 2),
            ([0, 1, 1, 1, 2, 3], 0, 1, 3),
            ([1, 2, 3, 3, 3], 0, 3, 4)
        ]:
            with self.subTest(a=a, l=l, x=x):
                self.assertEqual(expected_i, statistics._find_rteq(a, l, x))
|
||||
|
||||
|
||||
# === Tests for public functions ===
|
||||
|
||||
class UnivariateCommonMixin:
|
||||
@@ -1199,20 +1249,14 @@ class TestSum(NumericTestCase):
|
||||
# Override test for empty data.
|
||||
for data in ([], (), iter([])):
|
||||
self.assertEqual(self.func(data), (int, Fraction(0), 0))
|
||||
self.assertEqual(self.func(data, 23), (int, Fraction(23), 0))
|
||||
self.assertEqual(self.func(data, 2.3), (float, Fraction(2.3), 0))
|
||||
|
||||
def test_ints(self):
|
||||
self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]),
|
||||
(int, Fraction(60), 8))
|
||||
self.assertEqual(self.func([4, 2, 3, -8, 7], 1000),
|
||||
(int, Fraction(1008), 5))
|
||||
|
||||
def test_floats(self):
|
||||
self.assertEqual(self.func([0.25]*20),
|
||||
(float, Fraction(5.0), 20))
|
||||
self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5),
|
||||
(float, Fraction(3.125), 4))
|
||||
|
||||
def test_fractions(self):
|
||||
self.assertEqual(self.func([Fraction(1, 1000)]*500),
|
||||
@@ -1233,14 +1277,6 @@ class TestSum(NumericTestCase):
|
||||
data = [random.uniform(-100, 1000) for _ in range(1000)]
|
||||
self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16)
|
||||
|
||||
def test_start_argument(self):
|
||||
# Test that the optional start argument works correctly.
|
||||
data = [random.uniform(1, 1000) for _ in range(100)]
|
||||
t = self.func(data)[1]
|
||||
self.assertEqual(t+42, self.func(data, 42)[1])
|
||||
self.assertEqual(t-23, self.func(data, -23)[1])
|
||||
self.assertEqual(t+Fraction(1e20), self.func(data, 1e20)[1])
|
||||
|
||||
def test_strings_fail(self):
|
||||
# Sum of strings should fail.
|
||||
self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
|
||||
@@ -1480,6 +1516,18 @@ class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
|
||||
with self.subTest(values=values):
|
||||
self.assertRaises(exc, self.func, values)
|
||||
|
||||
def test_invalid_type_error(self):
|
||||
# Test error is raised when input contains invalid type(s)
|
||||
for data in [
|
||||
['3.14'], # single string
|
||||
['1', '2', '3'], # multiple strings
|
||||
[1, '2', 3, '4', 5], # mixed strings and valid integers
|
||||
[2.3, 3.4, 4.5, '5.6'] # only one string and valid floats
|
||||
]:
|
||||
with self.subTest(data=data):
|
||||
with self.assertRaises(TypeError):
|
||||
self.func(data)
|
||||
|
||||
def test_ints(self):
|
||||
# Test harmonic mean with ints.
|
||||
data = [2, 4, 4, 8, 16, 16]
|
||||
@@ -1541,6 +1589,27 @@ class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
|
||||
actual = self.func(data*2)
|
||||
self.assertApproxEqual(actual, expected)
|
||||
|
||||
def test_with_weights(self):
|
||||
self.assertEqual(self.func([40, 60], [5, 30]), 56.0) # common case
|
||||
self.assertEqual(self.func([40, 60],
|
||||
weights=[5, 30]), 56.0) # keyword argument
|
||||
self.assertEqual(self.func(iter([40, 60]),
|
||||
iter([5, 30])), 56.0) # iterator inputs
|
||||
self.assertEqual(
|
||||
self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]),
|
||||
self.func([Fraction(10, 3)] * 5 +
|
||||
[Fraction(23, 5)] * 2 +
|
||||
[Fraction(7, 2)] * 10))
|
||||
self.assertEqual(self.func([10], [7]), 10) # n=1 fast path
|
||||
with self.assertRaises(TypeError):
|
||||
self.func([1, 2, 3], [1, (), 3]) # non-numeric weight
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
self.func([1, 2, 3], [1, 2]) # wrong number of weights
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
self.func([10], [0]) # no non-zero weights
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
self.func([10, 20], [0, 0]) # no non-zero weights
|
||||
|
||||
|
||||
class TestMedian(NumericTestCase, AverageMixin):
|
||||
# Common tests for median and all median.* functions.
|
||||
@@ -1832,10 +1901,13 @@ class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
|
||||
|
||||
def test_counter_data(self):
    """Check that a ``Counter`` argument is iterated like any other iterable.

    The surrounding diff hunk places this method in ``TestMode``.
    """
    # Test that a Counter is treated like any other iterable.
    data = collections.Counter([1, 1, 1, 2])
    # Since the keys of the counter are treated as data points, not the
    # counts, this should return the first mode encountered, 1
    self.assertEqual(self.func(data), 1)

    # We're making sure mode() first calls iter() on its input.
    # The concern is that a Counter of a Counter returns the original
    # unchanged rather than counting its keys.
    c = collections.Counter(a=1, b=2)
    # If iter() is called, mode(c) loops over the keys, ['a', 'b'],
    # all the counts will be 1, and the first encountered mode is 'a'.
    self.assertEqual(self.func(c), 'a')
|
||||
|
||||
|
||||
class TestMultiMode(unittest.TestCase):
|
||||
@@ -2000,6 +2072,13 @@ class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
||||
self.assertEqual(result, exact)
|
||||
self.assertIsInstance(result, Decimal)
|
||||
|
||||
def test_accuracy_bug_20499(self):
    """Regression test: the result for ``[0, 0, 1]`` must be exactly 2/9.

    The surrounding diff hunk places this method in ``TestPVariance``.
    Exact equality (not approximate) is intentional: the test guards the
    accuracy of the final float result.
    """
    data = [0, 0, 1]
    exact = 2 / 9
    result = self.func(data)
    self.assertEqual(result, exact)
    # Integer input must yield a float, not a Fraction or Decimal.
    self.assertIsInstance(result, float)
|
||||
|
||||
|
||||
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
||||
# Tests for sample variance.
|
||||
@@ -2040,6 +2119,13 @@ class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
|
||||
self.assertEqual(self.func(data), 0.5)
|
||||
self.assertEqual(self.func(data, xbar=2.0), 1.0)
|
||||
|
||||
def test_accuracy_bug_20499(self):
    """Regression test: the result for ``[0, 0, 2]`` must be exactly 4/3.

    The surrounding diff hunk places this method in ``TestVariance``.
    Exact equality (not approximate) is intentional: the test guards the
    accuracy of the final float result.
    """
    data = [0, 0, 2]
    exact = 4 / 3
    result = self.func(data)
    self.assertEqual(result, exact)
    # Integer input must yield a float, not a Fraction or Decimal.
    self.assertIsInstance(result, float)
|
||||
|
||||
class TestPStdev(VarianceStdevMixin, NumericTestCase):
|
||||
# Tests for population standard deviation.
|
||||
def setUp(self):
|
||||
@@ -2078,6 +2164,7 @@ class TestStdev(VarianceStdevMixin, NumericTestCase):
|
||||
self.assertEqual(self.func(data, xbar=2.0), 1.0)
|
||||
|
||||
class TestGeometricMean(unittest.TestCase):
|
||||
|
||||
def test_basics(self):
|
||||
geometric_mean = statistics.geometric_mean
|
||||
self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0)
|
||||
@@ -2329,6 +2416,84 @@ class TestQuantiles(unittest.TestCase):
|
||||
quantiles([10, None, 30], n=4) # data is non-numeric
|
||||
|
||||
|
||||
class TestBivariateStatistics(unittest.TestCase):
|
||||
|
||||
def test_unequal_size_error(self):
|
||||
for x, y in [
|
||||
([1, 2, 3], [1, 2]),
|
||||
([1, 2], [1, 2, 3]),
|
||||
]:
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.covariance(x, y)
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.correlation(x, y)
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.linear_regression(x, y)
|
||||
|
||||
def test_small_sample_error(self):
|
||||
for x, y in [
|
||||
([], []),
|
||||
([], [1, 2,]),
|
||||
([1, 2,], []),
|
||||
([1,], [1,]),
|
||||
([1,], [1, 2,]),
|
||||
([1, 2,], [1,]),
|
||||
]:
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.covariance(x, y)
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.correlation(x, y)
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.linear_regression(x, y)
|
||||
|
||||
|
||||
class TestCorrelationAndCovariance(unittest.TestCase):
|
||||
|
||||
def test_results(self):
|
||||
for x, y, result in [
|
||||
([1, 2, 3], [1, 2, 3], 1),
|
||||
([1, 2, 3], [-1, -2, -3], -1),
|
||||
([1, 2, 3], [3, 2, 1], -1),
|
||||
([1, 2, 3], [1, 2, 1], 0),
|
||||
([1, 2, 3], [1, 3, 2], 0.5),
|
||||
]:
|
||||
self.assertAlmostEqual(statistics.correlation(x, y), result)
|
||||
self.assertAlmostEqual(statistics.covariance(x, y), result)
|
||||
|
||||
def test_different_scales(self):
|
||||
x = [1, 2, 3]
|
||||
y = [10, 30, 20]
|
||||
self.assertAlmostEqual(statistics.correlation(x, y), 0.5)
|
||||
self.assertAlmostEqual(statistics.covariance(x, y), 5)
|
||||
|
||||
y = [.1, .2, .3]
|
||||
self.assertAlmostEqual(statistics.correlation(x, y), 1)
|
||||
self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
|
||||
|
||||
|
||||
class TestLinearRegression(unittest.TestCase):
|
||||
|
||||
def test_constant_input_error(self):
|
||||
x = [1, 1, 1,]
|
||||
y = [1, 2, 3,]
|
||||
with self.assertRaises(statistics.StatisticsError):
|
||||
statistics.linear_regression(x, y)
|
||||
|
||||
def test_results(self):
|
||||
for x, y, true_intercept, true_slope in [
|
||||
([1, 2, 3], [0, 0, 0], 0, 0),
|
||||
([1, 2, 3], [1, 2, 3], 0, 1),
|
||||
([1, 2, 3], [100, 100, 100], 100, 0),
|
||||
([1, 2, 3], [12, 14, 16], 10, 2),
|
||||
([1, 2, 3], [-1, -2, -3], 0, -1),
|
||||
([1, 2, 3], [21, 22, 23], 20, 1),
|
||||
([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1),
|
||||
]:
|
||||
slope, intercept = statistics.linear_regression(x, y)
|
||||
self.assertAlmostEqual(intercept, true_intercept)
|
||||
self.assertAlmostEqual(slope, true_slope)
|
||||
|
||||
|
||||
class TestNormalDist:
|
||||
|
||||
# General note on precision: The pdf(), cdf(), and overlap() methods
|
||||
@@ -2480,8 +2645,6 @@ class TestNormalDist:
|
||||
self.assertEqual(X.cdf(float('Inf')), 1.0)
|
||||
self.assertTrue(math.isnan(X.cdf(float('NaN'))))
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@support.skip_if_pgo_task
|
||||
def test_inv_cdf(self):
|
||||
NormalDist = self.module.NormalDist
|
||||
@@ -2623,6 +2786,21 @@ class TestNormalDist:
|
||||
with self.assertRaises(self.module.StatisticsError):
|
||||
NormalDist(1, 0).overlap(X) # left operand sigma is zero
|
||||
|
||||
def test_zscore(self):
    """Check NormalDist.zscore() results and its argument validation."""
    NormalDist = self.module.NormalDist
    X = NormalDist(100, 15)

    # (x - 100) / 15 for points above, below and at the mean.
    self.assertEqual(X.zscore(142), 2.8)
    self.assertEqual(X.zscore(58), -2.8)
    self.assertEqual(X.zscore(100), 0.0)

    with self.assertRaises(TypeError):
        X.zscore()                              # too few arguments
    with self.assertRaises(TypeError):
        X.zscore(1, 1)                          # too many arguments
    with self.assertRaises(TypeError):
        X.zscore(None)                          # non-numeric type
    with self.assertRaises(self.module.StatisticsError):
        NormalDist(1, 0).zscore(100)            # sigma is zero
|
||||
|
||||
def test_properties(self):
|
||||
X = self.module.NormalDist(100, 15)
|
||||
self.assertEqual(X.mean, 100)
|
||||
@@ -2740,6 +2918,11 @@ class TestNormalDistPython(unittest.TestCase, TestNormalDist):
|
||||
def tearDown(self):
|
||||
sys.modules['statistics'] = statistics
|
||||
|
||||
# TODO: RUSTPYTHON, ValueError: math domain error
|
||||
@unittest.expectedFailure
|
||||
def test_inv_cdf(self): # TODO: RUSTPYTHON, remove when this passes
|
||||
super().test_inv_cdf() # TODO: RUSTPYTHON, remove when this passes
|
||||
|
||||
|
||||
@unittest.skipUnless(c_statistics, 'requires _statistics')
|
||||
class TestNormalDistC(unittest.TestCase, TestNormalDist):
|
||||
|
||||
@@ -22,6 +22,7 @@ mod random;
|
||||
// mod re;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod socket;
|
||||
mod statistics;
|
||||
#[cfg(unix)]
|
||||
mod syslog;
|
||||
mod unicodedata;
|
||||
@@ -85,6 +86,7 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
|
||||
"pyexpat" => pyexpat::make_module,
|
||||
"_platform" => platform::make_module,
|
||||
"_random" => random::make_module,
|
||||
"_statistics" => statistics::make_module,
|
||||
"unicodedata" => unicodedata::make_module,
|
||||
"zlib" => zlib::make_module,
|
||||
// crate::vm::sysmodule::sysconfigdata_name() => sysconfigdata::make_module,
|
||||
|
||||
131
stdlib/src/statistics.rs
Normal file
131
stdlib/src/statistics.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
pub(crate) use statistics::make_module;
|
||||
|
||||
#[pymodule(name = "_statistics")]
|
||||
mod statistics {
|
||||
use rustpython_vm::{PyResult, VirtualMachine};
|
||||
|
||||
/*
|
||||
* There is no closed-form solution to the inverse CDF for the normal
|
||||
* distribution, so we use a rational approximation instead:
|
||||
* Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
|
||||
* Normal Distribution". Applied Statistics. Blackwell Publishing. 37
|
||||
* (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
|
||||
*/
|
||||
|
||||
#[pyfunction(name = "_normal_dist_inv_cdf")]
|
||||
fn normal_dist_inv_cdf(p: f64, mu: f64, sigma: f64, vm: &VirtualMachine) -> PyResult<f64> {
|
||||
if p <= 0.0 || p >= 1.0 || sigma <= 0.0 {
|
||||
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
|
||||
}
|
||||
|
||||
let q = p - 0.5;
|
||||
let num: f64;
|
||||
let den: f64;
|
||||
#[allow(clippy::excessive_precision)]
|
||||
if q.abs() <= 0.425 {
|
||||
let r = 0.180625 - q * q;
|
||||
// Hash sum-55.8831928806149014439
|
||||
num = (((((((2.5090809287301226727e+3 * r + 3.3430575583588128105e+4) * r
|
||||
+ 6.7265770927008700853e+4)
|
||||
* r
|
||||
+ 4.5921953931549871457e+4)
|
||||
* r
|
||||
+ 1.3731693765509461125e+4)
|
||||
* r
|
||||
+ 1.9715909503065514427e+3)
|
||||
* r
|
||||
+ 1.3314166789178437745e+2)
|
||||
* r
|
||||
+ 3.3871328727963666080e+0)
|
||||
* q;
|
||||
den = ((((((5.2264952788528545610e+3 * r + 2.8729085735721942674e+4) * r
|
||||
+ 3.9307895800092710610e+4)
|
||||
* r
|
||||
+ 2.1213794301586595867e+4)
|
||||
* r
|
||||
+ 5.3941960214247511077e+3)
|
||||
* r
|
||||
+ 6.8718700749205790830e+2)
|
||||
* r
|
||||
+ 4.2313330701600911252e+1)
|
||||
* r
|
||||
+ 1.0;
|
||||
if den == 0.0 {
|
||||
return Err(
|
||||
vm.new_value_error("inv_cdf undefined for these parameters".to_string())
|
||||
);
|
||||
}
|
||||
let x = num / den;
|
||||
return Ok(mu + (x * sigma));
|
||||
}
|
||||
let r = if q <= 0.0 { p } else { 1.0 - p };
|
||||
if r <= 0.0 || r >= 1.0 {
|
||||
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
|
||||
}
|
||||
let r = (-(r.ln())).sqrt();
|
||||
#[allow(clippy::excessive_precision)]
|
||||
if r <= 5.0 {
|
||||
let r = r - 1.6;
|
||||
// Hash sum-49.33206503301610289036
|
||||
num = ((((((7.74545014278341407640e-4 * r + 2.27238449892691845833e-2) * r
|
||||
+ 2.41780725177450611770e-1)
|
||||
* r
|
||||
+ 1.27045825245236838258e+0)
|
||||
* r
|
||||
+ 3.64784832476320460504e+0)
|
||||
* r
|
||||
+ 5.76949722146069140550e+0)
|
||||
* r
|
||||
+ 4.63033784615654529590e+0)
|
||||
* r
|
||||
+ 1.42343711074968357734e+0;
|
||||
den = ((((((1.05075007164441684324e-9 * r + 5.47593808499534494600e-4) * r
|
||||
+ 1.51986665636164571966e-2)
|
||||
* r
|
||||
+ 1.48103976427480074590e-1)
|
||||
* r
|
||||
+ 6.89767334985100004550e-1)
|
||||
* r
|
||||
+ 1.67638483018380384940e+0)
|
||||
* r
|
||||
+ 2.05319162663775882187e+0)
|
||||
* r
|
||||
+ 1.0;
|
||||
} else {
|
||||
let r = r - 5.0;
|
||||
// Hash sum-47.52583317549289671629
|
||||
num = ((((((2.01033439929228813265e-7 * r + 2.71155556874348757815e-5) * r
|
||||
+ 1.24266094738807843860e-3)
|
||||
* r
|
||||
+ 2.65321895265761230930e-2)
|
||||
* r
|
||||
+ 2.96560571828504891230e-1)
|
||||
* r
|
||||
+ 1.78482653991729133580e+0)
|
||||
* r
|
||||
+ 5.46378491116411436990e+0)
|
||||
* r
|
||||
+ 6.65790464350110377720e+0;
|
||||
den = ((((((2.04426310338993978564e-15 * r + 1.42151175831644588870e-7) * r
|
||||
+ 1.84631831751005468180e-5)
|
||||
* r
|
||||
+ 7.86869131145613259100e-4)
|
||||
* r
|
||||
+ 1.48753612908506148525e-2)
|
||||
* r
|
||||
+ 1.36929880922735805310e-1)
|
||||
* r
|
||||
+ 5.99832206555887937690e-1)
|
||||
* r
|
||||
+ 1.0;
|
||||
}
|
||||
if den == 0.0 {
|
||||
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
|
||||
}
|
||||
let mut x = num / den;
|
||||
if q < 0.0 {
|
||||
x = -x;
|
||||
}
|
||||
Ok(mu + (x * sigma))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user