001/*
002// This software is subject to the terms of the Eclipse Public License v1.0
003// Agreement, available at the following URL:
004// http://www.eclipse.org/legal/epl-v10.html.
005// You must accept the terms of that agreement to use this software.
006//
007// Copyright (C) 2005-2005 Julian Hyde
008// Copyright (C) 2005-2011 Pentaho
009// All Rights Reserved.
010*/
011package mondrian.olap.fun;
012
013import mondrian.calc.*;
014import mondrian.calc.impl.AbstractDoubleCalc;
015import mondrian.calc.impl.ValueCalc;
016import mondrian.mdx.ResolvedFunCall;
017import mondrian.olap.*;
018
019import java.util.*;
020
021/**
022 * Abstract base class for definitions of linear regression functions.
023 *
 * <p>See also {@link InterceptFunDef}, {@link PointFunDef},
 * {@link R2FunDef}, {@link SlopeFunDef} and {@link VarianceFunDef}.</p>
029 *
030 * <h2>Correlation coefficient</h2>
031 * <p><i>Correlation coefficient</i></p>
032 *
 * <p>The correlation coefficient, r, ranges from -1 to +1. The
034 * nonparametric Spearman correlation coefficient, abbreviated rs, has
035 * the same range.</p>
036 *
037 * <table border="1" cellpadding="6" cellspacing="0">
038 *   <tr>
039 *     <td>Value of r (or rs)</td>
040 *     <td>Interpretation</td>
041 *   </tr>
042 *   <tr>
043 *     <td valign="top">r= 0</td>
044 *
045 *     <td>The two variables do not vary together at all.</td>
046 *   </tr>
047 *   <tr>
 *     <td valign="top">0 &lt; r &lt; 1</td>
049 *     <td>
050 *       <p>The two variables tend to increase or decrease together.</p>
051 *     </td>
052 *   </tr>
053 *   <tr>
054 *     <td valign="top">r = 1.0</td>
055 *     <td>
056 *       <p>Perfect correlation.</p>
057 *     </td>
058 *   </tr>
059 *
060 *   <tr>
 *     <td valign="top">-1 &lt; r &lt; 0</td>
062 *     <td>
063 *       <p>One variable increases as the other decreases.</p>
064 *     </td>
065 *   </tr>
066 *
067 *   <tr>
068 *     <td valign="top">r = -1.0</td>
069 *     <td>
070 *       <p></p>
071 *       <p>Perfect negative or inverse correlation.</p>
072 *     </td>
073 *   </tr>
074 * </table>
075 *
 * <p>If r or rs is far from zero, there are four possible explanations:</p>
 * <ul>
 *   <li>The X variable helps determine the value of the Y variable.
079 *   <li>The Y variable helps determine the value of the X variable.
080 *   <li>Another variable influences both X and Y.
081 *   <li>X and Y don't really correlate at all, and you just
082 *       happened to observe such a strong correlation by chance. The P value
083 *       determines how often this could occur.
084 * </ul>
 * <p><i>r<sup>2</sup></i></p>
086 *
 * <p>Perhaps the best way to interpret the value of r is to square it to
 * calculate r<sup>2</sup>. Statisticians call this quantity the
 * coefficient of determination, but scientists call it r squared. It has
 * a value that ranges from zero to one, and is the fraction of the
 * variance in the two variables that is shared. For example, if
 * r<sup>2</sup> = 0.59, then 59% of the variance in X can be explained by
 * variation in Y. Likewise, 59% of the variance in Y can be explained by
 * (or goes along with) variation in X. More simply, 59% of the variance is
 * shared between X and Y.</p>
096 *
097 * <p>(<a href="http://www.graphpad.com/articles/interpret/corl_n_linear_reg/correlation.htm">Source</a>).
098 *
099 * <p>Also see: <a href="http://mathworld.wolfram.com/LeastSquaresFitting.html">least squares fitting</a>.
100 */
101
102
103public abstract class LinReg extends FunDefBase {
104    /** Code for the specific function. */
105    final int regType;
106
107    public static final int Point = 0;
108    public static final int R2 = 1;
109    public static final int Intercept = 2;
110    public static final int Slope = 3;
111    public static final int Variance = 4;
112
113    static final Resolver InterceptResolver =
114        new ReflectiveMultiResolver(
115            "LinRegIntercept",
116            "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])",
117            "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.",
118            new String[]{"fnxn", "fnxnn"},
119            InterceptFunDef.class);
120
121    static final Resolver PointResolver =
122        new ReflectiveMultiResolver(
123            "LinRegPoint",
124            "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])",
125            "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.",
126            new String[]{"fnnxn", "fnnxnn"},
127            PointFunDef.class);
128
129    static final Resolver SlopeResolver =
130        new ReflectiveMultiResolver(
131            "LinRegSlope",
132            "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])",
133            "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.",
134            new String[]{"fnxn", "fnxnn"},
135            SlopeFunDef.class);
136
137    static final Resolver R2Resolver =
138        new ReflectiveMultiResolver(
139            "LinRegR2",
140            "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])",
141            "Calculates the linear regression of a set and returns R2 (the coefficient of determination).",
142            new String[]{"fnxn", "fnxnn"},
143            R2FunDef.class);
144
145    static final Resolver VarianceResolver =
146        new ReflectiveMultiResolver(
147            "LinRegVariance",
148            "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])",
149            "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.",
150            new String[]{"fnxn", "fnxnn"},
151            VarianceFunDef.class);
152
153
154    public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
155        final ListCalc listCalc = compiler.compileList(call.getArg(0));
156        final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1));
157        final DoubleCalc xCalc =
158            call.getArgCount() > 2
159            ? compiler.compileDouble(call.getArg(2))
160            : new ValueCalc(call);
161        return new LinRegCalc(call, listCalc, yCalc, xCalc, regType);
162    }
163
164    /////////////////////////////////////////////////////////////////////////
165    //
166    // Helper
167    //
168    /////////////////////////////////////////////////////////////////////////
169    static class Value {
170        private List xs;
171        private List ys;
172        /**
173         * The intercept for the linear regression model. Initialized
174         * following a call to accuracy.
175         */
176        double intercept;
177
178        /**
179         * The slope for the linear regression model. Initialized following a
180         * call to accuracy.
181         */
182        double slope;
183
184         /** the coefficient of determination */
185        double rSquared = Double.MAX_VALUE;
186
187        /** variance = sum square diff mean / n - 1 */
188        double variance = Double.MAX_VALUE;
189
190        Value(double intercept, double slope, List xs, List ys) {
191            this.intercept = intercept;
192            this.slope = slope;
193            this.xs = xs;
194            this.ys = ys;
195        }
196
197        public double getIntercept() {
198            return this.intercept;
199        }
200
201        public double getSlope() {
202            return this.slope;
203        }
204
205        public double getRSquared() {
206            return this.rSquared;
207        }
208
209        /**
210         * strength of the correlation
211         *
212         * @param rSquared Strength of the correlation
213         */
214        public void setRSquared(double rSquared) {
215            this.rSquared = rSquared;
216        }
217
218        public double getVariance() {
219            return this.variance;
220        }
221
222        public void setVariance(double variance) {
223            this.variance = variance;
224        }
225
226        public String toString() {
227            return "LinReg.Value: slope of "
228                + slope
229                + " and an intercept of " + intercept
230                + ". That is, y="
231                + intercept
232                + (slope > 0.0 ? " +" : " ")
233                + slope
234                + " * x.";
235        }
236    }
237
238    /**
239     * Definition of the <code>LinRegIntercept</code> MDX function.
240     *
241     * <p>Synopsis:
242     *
243     * <blockquote><code>LinRegIntercept(&lt;Numeric Expression&gt;,
244     * &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric
245     * Expression&gt;])</code></blockquote>
246     */
247    public static class InterceptFunDef extends LinReg {
248        public InterceptFunDef(FunDef funDef) {
249            super(funDef, Intercept);
250        }
251    }
252
253    /**
254     * Definition of the <code>LinRegPoint</code> MDX function.
255     *
256     * <p>Synopsis:
257     *
258     * <blockquote><code>LinRegPoint(&lt;Numeric Expression&gt;,
259     * &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric
260     * Expression&gt;])</code></blockquote>
261     */
262    public static class PointFunDef extends LinReg {
263        public PointFunDef(FunDef funDef) {
264            super(funDef, Point);
265        }
266
267        public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
268            final DoubleCalc xPointCalc =
269                compiler.compileDouble(call.getArg(0));
270            final ListCalc listCalc = compiler.compileList(call.getArg(1));
271            final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2));
272            final DoubleCalc xCalc =
273                call.getArgCount() > 3
274                ? compiler.compileDouble(call.getArg(3))
275                : new ValueCalc(call);
276            return new PointCalc(
277                call, xPointCalc, listCalc, yCalc, xCalc);
278        }
279    }
280
281    private static class PointCalc extends AbstractDoubleCalc {
282        private final DoubleCalc xPointCalc;
283        private final ListCalc listCalc;
284        private final DoubleCalc yCalc;
285        private final DoubleCalc xCalc;
286
287        public PointCalc(
288            ResolvedFunCall call,
289            DoubleCalc xPointCalc,
290            ListCalc listCalc,
291            DoubleCalc yCalc,
292            DoubleCalc xCalc)
293        {
294            super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc});
295            this.xPointCalc = xPointCalc;
296            this.listCalc = listCalc;
297            this.yCalc = yCalc;
298            this.xCalc = xCalc;
299        }
300
301        public double evaluateDouble(Evaluator evaluator) {
302            double xPoint = xPointCalc.evaluateDouble(evaluator);
303            Value value = process(evaluator, listCalc, yCalc, xCalc);
304            if (value == null) {
305                return FunUtil.DoubleNull;
306            }
307            // use first arg to generate y position
308            double yPoint =
309                xPoint * value.getSlope()
310                + value.getIntercept();
311            return yPoint;
312        }
313    }
314
315    /**
316     * Definition of the <code>LinRegSlope</code> MDX function.
317     *
318     * <p>Synopsis:
319     *
320     * <blockquote><code>LinRegSlope(&lt;Numeric Expression&gt;,
321     * &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric
322     * Expression&gt;])</code></blockquote>
323     */
324    public static class SlopeFunDef extends LinReg {
325        public SlopeFunDef(FunDef funDef) {
326            super(funDef, Slope);
327        }
328    }
329
330    /**
331     * Definition of the <code>LinRegR2</code> MDX function.
332     *
333     * <p>Synopsis:
334     *
335     * <blockquote><code>LinRegR2(&lt;Numeric Expression&gt;,
336     * &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric
337     * Expression&gt;])</code></blockquote>
338     */
339    public static class R2FunDef extends LinReg {
340        public R2FunDef(FunDef funDef) {
341            super(funDef, R2);
342        }
343    }
344
345    /**
346     * Definition of the <code>LinRegVariance</code> MDX function.
347     *
348     * <p>Synopsis:
349     *
350     * <blockquote><code>LinRegVariance(&lt;Numeric Expression&gt;,
351     * &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric
352     * Expression&gt;])</code></blockquote>
353     */
354    public static class VarianceFunDef extends LinReg {
355        public VarianceFunDef(FunDef funDef) {
356            super(funDef, Variance);
357        }
358    }
359
360    protected static void debug(String type, String msg) {
361        // comment out for no output
362// RME
363        //System.out.println(type + ": " +msg);
364    }
365
366
367    protected LinReg(FunDef funDef, int regType) {
368        super(funDef);
369        this.regType = regType;
370    }
371
372    protected static LinReg.Value process(
373        Evaluator evaluator,
374        ListCalc listCalc,
375        DoubleCalc yCalc,
376        DoubleCalc xCalc)
377    {
378        final int savepoint = evaluator.savepoint();
379        TupleList members;
380        try {
381            evaluator.setNonEmpty(false);
382            members = listCalc.evaluateList(evaluator);
383        } finally {
384            evaluator.restore(savepoint);
385        }
386        SetWrapper[] sws;
387        try {
388            sws =
389                evaluateSet(
390                    evaluator, members, new DoubleCalc[] {yCalc, xCalc});
391        } finally {
392            evaluator.restore(savepoint);
393        }
394        SetWrapper swY = sws[0];
395        SetWrapper swX = sws[1];
396
397        if (swY.errorCount > 0) {
398            debug("LinReg.process", "ERROR error(s) count ="  + swY.errorCount);
399            // TODO: throw exception
400            return null;
401        } else if (swY.v.size() == 0) {
402            return null;
403        }
404
405        return linearReg(swX.v, swY.v);
406    }
407
408    public static LinReg.Value accuracy(LinReg.Value value) {
409        // for variance
410        double sumErrSquared = 0.0;
411
412        double sumErr = 0.0;
413
414        // for r2
415        // data
416        double sumSquaredY = 0.0;
417        double sumY = 0.0;
418        // predicted
419        double sumSquaredYF = 0.0;
420        double sumYF = 0.0;
421
422        // Obtain the forecast values for this model
423        List yfs = forecast(value);
424
425        // Calculate the Sum of the Absolute Errors
426        Iterator ity = value.ys.iterator();
427        Iterator ityf = yfs.iterator();
428        while (ity.hasNext()) {
429            // Get next data point
430            Double dy = (Double) ity.next();
431            if (dy == null) {
432                continue;
433            }
434            Double dyf = (Double) ityf.next();
435            if (dyf == null) {
436                continue;
437            }
438
439            double y = dy.doubleValue();
440            double yf = dyf.doubleValue();
441
442            // Calculate error in forecast, and update sums appropriately
443
444            // the y residual or error
445            double error = yf - y;
446
447            sumErr += error;
448            sumErrSquared += error * error;
449
450            sumY += y;
451            sumSquaredY += (y * y);
452
453            sumYF =+ yf;
454            sumSquaredYF =+ (yf * yf);
455        }
456
457
458        // Initialize the accuracy indicators
459        int n = value.ys.size();
460
461        // Variance
462        // The estimate the value of the error variance is a measure of
463        // variability of the y values about the estimated line.
464        // http://home.ubalt.edu/ntsbarsh/Business-stat/opre504.htm
465        // s2 = SSE/(n-2) = sum (y - yf)2 /(n-2)
466        if (n > 2) {
467            double variance = sumErrSquared / (n - 2);
468
469            value.setVariance(variance);
470        }
471
472        // R2
473        // R2 = 1 - (SSE/SST)
474        // SSE = sum square error = Sum((error-MSE)*(error-MSE))
475        // MSE = mean error = Sum(error)/n
476        // SST = sum square y diff = Sum((y-MST)*(y-MST))
477        // MST = mean y = Sum(y)/n
478        double MSE = sumErr / n;
479        double MST = sumY / n;
480        double SSE = 0.0;
481        double SST = 0.0;
482        ity = value.ys.iterator();
483        ityf = yfs.iterator();
484        while (ity.hasNext()) {
485            // Get next data point
486            Double dy = (Double) ity.next();
487            if (dy == null) {
488                continue;
489            }
490            Double dyf = (Double) ityf.next();
491            if (dyf == null) {
492                continue;
493            }
494
495            double y = dy.doubleValue();
496            double yf = dyf.doubleValue();
497
498            double error = yf - y;
499            SSE += (error - MSE) * (error - MSE);
500            SST += (y - MST) * (y - MST);
501        }
502        if (SST != 0.0) {
503            double rSquared =  1 - (SSE / SST);
504
505            value.setRSquared(rSquared);
506        }
507
508
509        return value;
510    }
511
512    public static LinReg.Value linearReg(List xlist, List ylist) {
513        // y and x have same number of points
514        int size = ylist.size();
515        double sumX = 0.0;
516        double sumY = 0.0;
517        double sumXX = 0.0;
518        double sumXY = 0.0;
519
520        debug("LinReg.linearReg", "ylist.size()=" + ylist.size());
521        debug("LinReg.linearReg", "xlist.size()=" + xlist.size());
522        int n = 0;
523        for (int i = 0; i < size; i++) {
524            Object yo = ylist.get(i);
525            Object xo = xlist.get(i);
526            if ((yo == null) || (xo == null)) {
527                continue;
528            }
529            n++;
530            double y = ((Double) yo).doubleValue();
531            double x = ((Double) xo).doubleValue();
532
533            debug("LinReg.linearReg", " " + i + " (" + x + "," + y + ")");
534            sumX += x;
535            sumY += y;
536            sumXX += x * x;
537            sumXY += x * y;
538        }
539
540        double xMean = sumX / n;
541        double yMean = sumY / n;
542
543        debug("LinReg.linearReg", "yMean=" + yMean);
544        debug(
545            "LinReg.linearReg",
546            "(n*sumXX - sumX*sumX)=" + (n * sumXX - sumX * sumX));
547        // The regression line is the line that minimizes the variance of the
548        // errors. The mean error is zero; so, this means that it minimizes the
549        // sum of the squares errors.
550        double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
551        double intercept = yMean - slope * xMean;
552
553        LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist);
554        debug("LinReg.linearReg", "value=" + value);
555
556        return value;
557    }
558
559
560    public static List forecast(LinReg.Value value) {
561        List yfs = new ArrayList(value.xs.size());
562
563        Iterator it = value.xs.iterator();
564        while (it.hasNext()) {
565            Double d = (Double) it.next();
566            // If the value is missing we still must put a place
567            // holder in the y axis, otherwise there is a discontinuity
568            // between the data and the fit.
569            if (d == null) {
570                yfs.add(null);
571            } else {
572                double x = d.doubleValue();
573                double yf = value.intercept + value.slope * x;
574                yfs.add(new Double(yf));
575            }
576        }
577
578        return yfs;
579    }
580
581    private static class LinRegCalc extends AbstractDoubleCalc {
582        private final ListCalc listCalc;
583        private final DoubleCalc yCalc;
584        private final DoubleCalc xCalc;
585        private final int regType;
586
587        public LinRegCalc(
588            ResolvedFunCall call,
589            ListCalc listCalc,
590            DoubleCalc yCalc,
591            DoubleCalc xCalc,
592            int regType)
593        {
594            super(call, new Calc[]{listCalc, yCalc, xCalc});
595            this.listCalc = listCalc;
596            this.yCalc = yCalc;
597            this.xCalc = xCalc;
598            this.regType = regType;
599        }
600
601        public double evaluateDouble(Evaluator evaluator) {
602            Value value = process(evaluator, listCalc, yCalc, xCalc);
603            if (value == null) {
604                return FunUtil.DoubleNull;
605            }
606            switch (regType) {
607            case Intercept:
608                return value.getIntercept();
609            case Slope:
610                return value.getSlope();
611            case Variance:
612                return value.getVariance();
613            case R2:
614                return value.getRSquared();
615            default:
616            case Point:
617                throw Util.newInternal("unexpected value " + regType);
618            }
619        }
620    }
621}
622
623// End LinReg.java