Computing the standard deviation of a Number array in Java

Submitted by: @import:stackexchange-codereview

Problem

(See the next iteration.)

I have this funky class method for computing standard deviation from an array of Number objects using the Stream API:

StandardDeviation.java:

package net.coderodde.util;

import java.util.Arrays;

public class StandardDeviation {

    public static double computeStandardDeviation(Number... collection) {
        if (collection.length == 0) {
            return Double.NaN;
        }

        final double average =
                Arrays.stream(collection)
                      .mapToDouble((x) -> x.doubleValue())
                      .summaryStatistics()
                      .getAverage();

        final double rawSum = 
                Arrays.stream(collection)
                      .mapToDouble((x) -> Math.pow(x.doubleValue() - average,
                                                   2.0))
                      .sum();

        return Math.sqrt(rawSum / (collection.length - 1));
    }

    public static void main(String[] args) {
        // Mix 'em all!
        double sd = computeStandardDeviation((byte) 1, 
                                             (short) 2, 
                                             3, 
                                             4L, 
                                             5.0f, 
                                             6.0);

        System.out.println(sd);
    }
}


Please, tell me anything that comes to mind.

Solution


  • You are traversing the collection twice to determine the standard deviation when you could do it in a single pass.



  • Also, rounding errors can accumulate quickly when summing the Math.pow(x.doubleValue() - average, 2.0) terms. It would be best to use the Kahan summation algorithm (which the Stream API uses for DoubleStream#sum()).



  • In the lambda expression (x) -> x.doubleValue(), you don't need the parentheses around (x). You can just write x -> x.doubleValue(). You could also use a method reference, which avoids the lambda altogether: Number::doubleValue.
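
To make the second and third points concrete, here is a minimal sketch (not the JDK's implementation) that compares a naive loop with Kahan summation. The names KahanSketch, naiveSum and kahanSum are made up for illustration; the conversion to double uses the Number::doubleValue method reference.

```java
import java.util.Arrays;

final class KahanSketch {

    // Plain left-to-right summation: every addition may drop low-order bits.
    static double naiveSum(Number... values) {
        double sum = 0.0;
        for (Number value : values) {
            sum += value.doubleValue();
        }
        return sum;
    }

    // Kahan (compensated) summation: the bits lost by each addition are
    // remembered in 'compensation' and fed back into the next term.
    static double kahanSum(Number... values) {
        double[] doubles = Arrays.stream(values)
                                 .mapToDouble(Number::doubleValue)
                                 .toArray();
        double sum = 0.0;
        double compensation = 0.0;
        for (double value : doubles) {
            double adjusted = value - compensation;
            double next = sum + adjusted;
            compensation = (next - sum) - adjusted; // what was added in excess
            sum = next;
        }
        return sum;
    }

    public static void main(String[] args) {
        // One large value followed by many values too small to change it
        // when added one at a time.
        Number[] values = new Number[1001];
        values[0] = 1.0;
        Arrays.fill(values, 1, values.length, 1e-16);

        System.out.println("naive: " + naiveSum(values));
        System.out.println("kahan: " + kahanSum(values));
    }
}
```

With this input the naive loop never moves away from 1.0, because each 1e-16 is smaller than half the spacing between adjacent doubles near 1.0, while the compensated loop ends up very close to the true sum.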



On Stack Overflow, I wrote an answer which calculates the standard deviation in a single pass with compensation. It is parallel-friendly:

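// Nested class; the enclosing file needs:
//   import java.util.DoubleSummaryStatistics;
//   import java.util.stream.Collector;
static class DoubleStatistics extends DoubleSummaryStatistics {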

    private double sumOfSquare = 0.0d;
    private double sumOfSquareCompensation; // Low order bits of sum
    private double simpleSumOfSquare; // Used to compute right sum for
                                        // non-finite inputs

    @Override
    public void accept(double value) {
        super.accept(value);
        double squareValue = value * value;
        simpleSumOfSquare += squareValue;
        sumOfSquareWithCompensation(squareValue);
    }

    public DoubleStatistics combine(DoubleStatistics other) {
        super.combine(other);
        simpleSumOfSquare += other.simpleSumOfSquare;
        sumOfSquareWithCompensation(other.sumOfSquare);
        // The compensation holds the excess already added, so it is subtracted
        sumOfSquareWithCompensation(-other.sumOfSquareCompensation);
        return this;
    }

    private void sumOfSquareWithCompensation(double value) {
        double tmp = value - sumOfSquareCompensation;
        double velvel = sumOfSquare + tmp; // Little wolf of rounding error
        sumOfSquareCompensation = (velvel - sumOfSquare) - tmp;
        sumOfSquare = velvel;
    }

    public double getSumOfSquare() {
        // The compensation holds the excess already added, so it is subtracted
        double tmp = sumOfSquare - sumOfSquareCompensation;
        if (Double.isNaN(tmp) && Double.isInfinite(simpleSumOfSquare)) {
            return simpleSumOfSquare;
        }
        return tmp;
    }

    public final double getStandardDeviation() {
        long count = getCount();
        double sumOfSquare = getSumOfSquare();
        double average = getAverage();
        return count > 0 ? Math.sqrt((sumOfSquare - count * Math.pow(average, 2)) / (count - 1)) : 0.0d;
    }

    public static Collector<Double, ?, DoubleStatistics> collector() {
        return Collector.of(DoubleStatistics::new, DoubleStatistics::accept, DoubleStatistics::combine);
    }

}


It has the same logic as DoubleSummaryStatistics but extended to calculate the sum of squares.
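
The formula in getStandardDeviation() rests on the following identity. The notation is mine, not part of the class: n stands for getCount(), \bar{x} for getAverage() and \sum x_i^2 for getSumOfSquare().

```latex
\[
\sum_{i=1}^{n} (x_i - \bar{x})^2
  = \sum_{i=1}^{n} x_i^2 - 2\bar{x}\sum_{i=1}^{n} x_i + n\bar{x}^2
  = \sum_{i=1}^{n} x_i^2 - n\bar{x}^2
  \qquad\text{since } \sum_{i=1}^{n} x_i = n\bar{x},
\]
\[
s = \sqrt{\frac{\sum_{i=1}^{n} x_i^2 - n\bar{x}^2}{n - 1}} .
\]
```

Because only the count, the sum and the sum of squares are needed, everything can be accumulated in one pass. The price is that the numerator subtracts two quantities that can be nearly equal, so precision may suffer when the mean is large compared to the spread of the values.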

With such a class, you can then have:

public static double computeStandardDeviation(Number... collection) {
    return Arrays.stream(collection)
                 .map(Number::doubleValue)
                 .collect(DoubleStatistics.collector())
                 .getStandardDeviation();
}
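
As a sanity check of the single-pass idea, here is a stripped-down, self-contained sketch that compares it with the two-pass computation from the question, both sequentially and in parallel. It is not the DoubleStatistics class above: the names SinglePassSketch and Moments are hypothetical, and the accumulator deliberately leaves out the Kahan compensation to stay short.

```java
import java.util.Arrays;
import java.util.stream.Collector;
import java.util.stream.Stream;

final class SinglePassSketch {

    // Hypothetical accumulator: count, sum and sum of squares, no compensation.
    static final class Moments {
        long count;
        double sum;
        double sumOfSquares;

        void accept(double value) {
            count++;
            sum += value;
            sumOfSquares += value * value;
        }

        // Merges a partial result; this is what makes parallel streams work.
        Moments combine(Moments other) {
            count += other.count;
            sum += other.sum;
            sumOfSquares += other.sumOfSquares;
            return this;
        }

        double sampleStandardDeviation() {
            if (count < 2) {
                return Double.NaN; // undefined for fewer than two values
            }
            double mean = sum / count;
            return Math.sqrt((sumOfSquares - count * mean * mean) / (count - 1));
        }
    }

    static Collector<Number, Moments, Double> collector() {
        return Collector.of(
                Moments::new,
                (moments, number) -> moments.accept(number.doubleValue()),
                Moments::combine,
                Moments::sampleStandardDeviation);
    }

    // Two passes, as in the question: first the mean, then the deviations.
    static double twoPass(Number... values) {
        if (values.length < 2) {
            return Double.NaN;
        }
        double mean = Arrays.stream(values)
                            .mapToDouble(Number::doubleValue)
                            .average()
                            .orElse(Double.NaN);
        double squaredDeviations = Arrays.stream(values)
                                         .mapToDouble(Number::doubleValue)
                                         .map(x -> (x - mean) * (x - mean))
                                         .sum();
        return Math.sqrt(squaredDeviations / (values.length - 1));
    }

    // One pass through the collector, optionally on a parallel stream.
    static double singlePass(boolean parallel, Number... values) {
        Stream<Number> stream = Arrays.stream(values);
        if (parallel) {
            stream = stream.parallel();
        }
        return stream.collect(collector());
    }

    public static void main(String[] args) {
        Number[] values = {(byte) 1, (short) 2, 3, 4L, 5.0f, 6.0};
        System.out.println("two-pass:               " + twoPass(values));
        System.out.println("single-pass sequential: " + singlePass(false, values));
        System.out.println("single-pass parallel:   " + singlePass(true, values));
    }
}
```

For the mixed input from the question, all three calls should agree on roughly 1.8708, the square root of 3.5.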

Context

StackExchange Code Review Q#125409, answer score: 6
