001/* 002// This software is subject to the terms of the Eclipse Public License v1.0 003// Agreement, available at the following URL: 004// http://www.eclipse.org/legal/epl-v10.html. 005// You must accept the terms of that agreement to use this software. 006// 007// Copyright (C) 2005-2005 Julian Hyde 008// Copyright (C) 2005-2011 Pentaho 009// All Rights Reserved. 010*/ 011package mondrian.olap.fun; 012 013import mondrian.calc.*; 014import mondrian.calc.impl.AbstractDoubleCalc; 015import mondrian.calc.impl.ValueCalc; 016import mondrian.mdx.ResolvedFunCall; 017import mondrian.olap.*; 018 019import java.util.*; 020 021/** 022 * Abstract base class for definitions of linear regression functions. 023 * 024 * @see InterceptFunDef 025 * @see PointFunDef 026 * @see R2FunDef 027 * @see SlopeFunDef 028 * @see VarianceFunDef 029 * 030 * <h2>Correlation coefficient</h2> 031 * <p><i>Correlation coefficient</i></p> 032 * 033 * <p>The correlation coefficient, r, ranges from -1 to + 1. The 034 * nonparametric Spearman correlation coefficient, abbreviated rs, has 035 * the same range.</p> 036 * 037 * <table border="1" cellpadding="6" cellspacing="0"> 038 * <tr> 039 * <td>Value of r (or rs)</td> 040 * <td>Interpretation</td> 041 * </tr> 042 * <tr> 043 * <td valign="top">r= 0</td> 044 * 045 * <td>The two variables do not vary together at all.</td> 046 * </tr> 047 * <tr> 048 * <td valign="top">0 > r > 1</td> 049 * <td> 050 * <p>The two variables tend to increase or decrease together.</p> 051 * </td> 052 * </tr> 053 * <tr> 054 * <td valign="top">r = 1.0</td> 055 * <td> 056 * <p>Perfect correlation.</p> 057 * </td> 058 * </tr> 059 * 060 * <tr> 061 * <td valign="top">-1 > r > 0</td> 062 * <td> 063 * <p>One variable increases as the other decreases.</p> 064 * </td> 065 * </tr> 066 * 067 * <tr> 068 * <td valign="top">r = -1.0</td> 069 * <td> 070 * <p></p> 071 * <p>Perfect negative or inverse correlation.</p> 072 * </td> 073 * </tr> 074 * </table> 075 * 076 * <p>If r or rs is far from zero, there are four possible explanations:</p> 077 * <p>The X variable helps determine the value of the Y variable.</p> 078 * <ul> 079 * <li>The Y variable helps determine the value of the X variable. 080 * <li>Another variable influences both X and Y. 081 * <li>X and Y don't really correlate at all, and you just 082 * happened to observe such a strong correlation by chance. The P value 083 * determines how often this could occur. 084 * </ul> 085 * <p><i>r2 </i></p> 086 * 087 * <p>Perhaps the best way to interpret the value of r is to square it to 088 * calculate r2. Statisticians call this quantity the coefficient of 089 * determination, but scientists call it r squared. It is has a value 090 * that ranges from zero to one, and is the fraction of the variance in 091 * the two variables that is shared. For example, if r2=0.59, then 59% of 092 * the variance in X can be explained by variation in Y. Likewise, 093 * 59% of the variance in Y can be explained by (or goes along with) 094 * variation in X. More simply, 59% of the variance is shared between X 095 * and Y.</p> 096 * 097 * <p>(<a href="http://www.graphpad.com/articles/interpret/corl_n_linear_reg/correlation.htm">Source</a>). 098 * 099 * <p>Also see: <a href="http://mathworld.wolfram.com/LeastSquaresFitting.html">least squares fitting</a>. 100 */ 101 102 103public abstract class LinReg extends FunDefBase { 104 /** Code for the specific function. */ 105 final int regType; 106 107 public static final int Point = 0; 108 public static final int R2 = 1; 109 public static final int Intercept = 2; 110 public static final int Slope = 3; 111 public static final int Variance = 4; 112 113 static final Resolver InterceptResolver = 114 new ReflectiveMultiResolver( 115 "LinRegIntercept", 116 "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])", 117 "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.", 118 new String[]{"fnxn", "fnxnn"}, 119 InterceptFunDef.class); 120 121 static final Resolver PointResolver = 122 new ReflectiveMultiResolver( 123 "LinRegPoint", 124 "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])", 125 "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.", 126 new String[]{"fnnxn", "fnnxnn"}, 127 PointFunDef.class); 128 129 static final Resolver SlopeResolver = 130 new ReflectiveMultiResolver( 131 "LinRegSlope", 132 "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])", 133 "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.", 134 new String[]{"fnxn", "fnxnn"}, 135 SlopeFunDef.class); 136 137 static final Resolver R2Resolver = 138 new ReflectiveMultiResolver( 139 "LinRegR2", 140 "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])", 141 "Calculates the linear regression of a set and returns R2 (the coefficient of determination).", 142 new String[]{"fnxn", "fnxnn"}, 143 R2FunDef.class); 144 145 static final Resolver VarianceResolver = 146 new ReflectiveMultiResolver( 147 "LinRegVariance", 148 "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])", 149 "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.", 150 new String[]{"fnxn", "fnxnn"}, 151 VarianceFunDef.class); 152 153 154 public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) { 155 final ListCalc listCalc = compiler.compileList(call.getArg(0)); 156 final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1)); 157 final DoubleCalc xCalc = 158 call.getArgCount() > 2 159 ? compiler.compileDouble(call.getArg(2)) 160 : new ValueCalc(call); 161 return new LinRegCalc(call, listCalc, yCalc, xCalc, regType); 162 } 163 164 ///////////////////////////////////////////////////////////////////////// 165 // 166 // Helper 167 // 168 ///////////////////////////////////////////////////////////////////////// 169 static class Value { 170 private List xs; 171 private List ys; 172 /** 173 * The intercept for the linear regression model. Initialized 174 * following a call to accuracy. 175 */ 176 double intercept; 177 178 /** 179 * The slope for the linear regression model. Initialized following a 180 * call to accuracy. 181 */ 182 double slope; 183 184 /** the coefficient of determination */ 185 double rSquared = Double.MAX_VALUE; 186 187 /** variance = sum square diff mean / n - 1 */ 188 double variance = Double.MAX_VALUE; 189 190 Value(double intercept, double slope, List xs, List ys) { 191 this.intercept = intercept; 192 this.slope = slope; 193 this.xs = xs; 194 this.ys = ys; 195 } 196 197 public double getIntercept() { 198 return this.intercept; 199 } 200 201 public double getSlope() { 202 return this.slope; 203 } 204 205 public double getRSquared() { 206 return this.rSquared; 207 } 208 209 /** 210 * strength of the correlation 211 * 212 * @param rSquared Strength of the correlation 213 */ 214 public void setRSquared(double rSquared) { 215 this.rSquared = rSquared; 216 } 217 218 public double getVariance() { 219 return this.variance; 220 } 221 222 public void setVariance(double variance) { 223 this.variance = variance; 224 } 225 226 public String toString() { 227 return "LinReg.Value: slope of " 228 + slope 229 + " and an intercept of " + intercept 230 + ". That is, y=" 231 + intercept 232 + (slope > 0.0 ? " +" : " ") 233 + slope 234 + " * x."; 235 } 236 } 237 238 /** 239 * Definition of the <code>LinRegIntercept</code> MDX function. 240 * 241 * <p>Synopsis: 242 * 243 * <blockquote><code>LinRegIntercept(<Numeric Expression>, 244 * <Set>, <Numeric Expression>[, <Numeric 245 * Expression>])</code></blockquote> 246 */ 247 public static class InterceptFunDef extends LinReg { 248 public InterceptFunDef(FunDef funDef) { 249 super(funDef, Intercept); 250 } 251 } 252 253 /** 254 * Definition of the <code>LinRegPoint</code> MDX function. 255 * 256 * <p>Synopsis: 257 * 258 * <blockquote><code>LinRegPoint(<Numeric Expression>, 259 * <Set>, <Numeric Expression>[, <Numeric 260 * Expression>])</code></blockquote> 261 */ 262 public static class PointFunDef extends LinReg { 263 public PointFunDef(FunDef funDef) { 264 super(funDef, Point); 265 } 266 267 public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) { 268 final DoubleCalc xPointCalc = 269 compiler.compileDouble(call.getArg(0)); 270 final ListCalc listCalc = compiler.compileList(call.getArg(1)); 271 final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2)); 272 final DoubleCalc xCalc = 273 call.getArgCount() > 3 274 ? compiler.compileDouble(call.getArg(3)) 275 : new ValueCalc(call); 276 return new PointCalc( 277 call, xPointCalc, listCalc, yCalc, xCalc); 278 } 279 } 280 281 private static class PointCalc extends AbstractDoubleCalc { 282 private final DoubleCalc xPointCalc; 283 private final ListCalc listCalc; 284 private final DoubleCalc yCalc; 285 private final DoubleCalc xCalc; 286 287 public PointCalc( 288 ResolvedFunCall call, 289 DoubleCalc xPointCalc, 290 ListCalc listCalc, 291 DoubleCalc yCalc, 292 DoubleCalc xCalc) 293 { 294 super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc}); 295 this.xPointCalc = xPointCalc; 296 this.listCalc = listCalc; 297 this.yCalc = yCalc; 298 this.xCalc = xCalc; 299 } 300 301 public double evaluateDouble(Evaluator evaluator) { 302 double xPoint = xPointCalc.evaluateDouble(evaluator); 303 Value value = process(evaluator, listCalc, yCalc, xCalc); 304 if (value == null) { 305 return FunUtil.DoubleNull; 306 } 307 // use first arg to generate y position 308 double yPoint = 309 xPoint * value.getSlope() 310 + value.getIntercept(); 311 return yPoint; 312 } 313 } 314 315 /** 316 * Definition of the <code>LinRegSlope</code> MDX function. 317 * 318 * <p>Synopsis: 319 * 320 * <blockquote><code>LinRegSlope(<Numeric Expression>, 321 * <Set>, <Numeric Expression>[, <Numeric 322 * Expression>])</code></blockquote> 323 */ 324 public static class SlopeFunDef extends LinReg { 325 public SlopeFunDef(FunDef funDef) { 326 super(funDef, Slope); 327 } 328 } 329 330 /** 331 * Definition of the <code>LinRegR2</code> MDX function. 332 * 333 * <p>Synopsis: 334 * 335 * <blockquote><code>LinRegR2(<Numeric Expression>, 336 * <Set>, <Numeric Expression>[, <Numeric 337 * Expression>])</code></blockquote> 338 */ 339 public static class R2FunDef extends LinReg { 340 public R2FunDef(FunDef funDef) { 341 super(funDef, R2); 342 } 343 } 344 345 /** 346 * Definition of the <code>LinRegVariance</code> MDX function. 347 * 348 * <p>Synopsis: 349 * 350 * <blockquote><code>LinRegVariance(<Numeric Expression>, 351 * <Set>, <Numeric Expression>[, <Numeric 352 * Expression>])</code></blockquote> 353 */ 354 public static class VarianceFunDef extends LinReg { 355 public VarianceFunDef(FunDef funDef) { 356 super(funDef, Variance); 357 } 358 } 359 360 protected static void debug(String type, String msg) { 361 // comment out for no output 362// RME 363 //System.out.println(type + ": " +msg); 364 } 365 366 367 protected LinReg(FunDef funDef, int regType) { 368 super(funDef); 369 this.regType = regType; 370 } 371 372 protected static LinReg.Value process( 373 Evaluator evaluator, 374 ListCalc listCalc, 375 DoubleCalc yCalc, 376 DoubleCalc xCalc) 377 { 378 final int savepoint = evaluator.savepoint(); 379 TupleList members; 380 try { 381 evaluator.setNonEmpty(false); 382 members = listCalc.evaluateList(evaluator); 383 } finally { 384 evaluator.restore(savepoint); 385 } 386 SetWrapper[] sws; 387 try { 388 sws = 389 evaluateSet( 390 evaluator, members, new DoubleCalc[] {yCalc, xCalc}); 391 } finally { 392 evaluator.restore(savepoint); 393 } 394 SetWrapper swY = sws[0]; 395 SetWrapper swX = sws[1]; 396 397 if (swY.errorCount > 0) { 398 debug("LinReg.process", "ERROR error(s) count =" + swY.errorCount); 399 // TODO: throw exception 400 return null; 401 } else if (swY.v.size() == 0) { 402 return null; 403 } 404 405 return linearReg(swX.v, swY.v); 406 } 407 408 public static LinReg.Value accuracy(LinReg.Value value) { 409 // for variance 410 double sumErrSquared = 0.0; 411 412 double sumErr = 0.0; 413 414 // for r2 415 // data 416 double sumSquaredY = 0.0; 417 double sumY = 0.0; 418 // predicted 419 double sumSquaredYF = 0.0; 420 double sumYF = 0.0; 421 422 // Obtain the forecast values for this model 423 List yfs = forecast(value); 424 425 // Calculate the Sum of the Absolute Errors 426 Iterator ity = value.ys.iterator(); 427 Iterator ityf = yfs.iterator(); 428 while (ity.hasNext()) { 429 // Get next data point 430 Double dy = (Double) ity.next(); 431 if (dy == null) { 432 continue; 433 } 434 Double dyf = (Double) ityf.next(); 435 if (dyf == null) { 436 continue; 437 } 438 439 double y = dy.doubleValue(); 440 double yf = dyf.doubleValue(); 441 442 // Calculate error in forecast, and update sums appropriately 443 444 // the y residual or error 445 double error = yf - y; 446 447 sumErr += error; 448 sumErrSquared += error * error; 449 450 sumY += y; 451 sumSquaredY += (y * y); 452 453 sumYF =+ yf; 454 sumSquaredYF =+ (yf * yf); 455 } 456 457 458 // Initialize the accuracy indicators 459 int n = value.ys.size(); 460 461 // Variance 462 // The estimate the value of the error variance is a measure of 463 // variability of the y values about the estimated line. 464 // http://home.ubalt.edu/ntsbarsh/Business-stat/opre504.htm 465 // s2 = SSE/(n-2) = sum (y - yf)2 /(n-2) 466 if (n > 2) { 467 double variance = sumErrSquared / (n - 2); 468 469 value.setVariance(variance); 470 } 471 472 // R2 473 // R2 = 1 - (SSE/SST) 474 // SSE = sum square error = Sum((error-MSE)*(error-MSE)) 475 // MSE = mean error = Sum(error)/n 476 // SST = sum square y diff = Sum((y-MST)*(y-MST)) 477 // MST = mean y = Sum(y)/n 478 double MSE = sumErr / n; 479 double MST = sumY / n; 480 double SSE = 0.0; 481 double SST = 0.0; 482 ity = value.ys.iterator(); 483 ityf = yfs.iterator(); 484 while (ity.hasNext()) { 485 // Get next data point 486 Double dy = (Double) ity.next(); 487 if (dy == null) { 488 continue; 489 } 490 Double dyf = (Double) ityf.next(); 491 if (dyf == null) { 492 continue; 493 } 494 495 double y = dy.doubleValue(); 496 double yf = dyf.doubleValue(); 497 498 double error = yf - y; 499 SSE += (error - MSE) * (error - MSE); 500 SST += (y - MST) * (y - MST); 501 } 502 if (SST != 0.0) { 503 double rSquared = 1 - (SSE / SST); 504 505 value.setRSquared(rSquared); 506 } 507 508 509 return value; 510 } 511 512 public static LinReg.Value linearReg(List xlist, List ylist) { 513 // y and x have same number of points 514 int size = ylist.size(); 515 double sumX = 0.0; 516 double sumY = 0.0; 517 double sumXX = 0.0; 518 double sumXY = 0.0; 519 520 debug("LinReg.linearReg", "ylist.size()=" + ylist.size()); 521 debug("LinReg.linearReg", "xlist.size()=" + xlist.size()); 522 int n = 0; 523 for (int i = 0; i < size; i++) { 524 Object yo = ylist.get(i); 525 Object xo = xlist.get(i); 526 if ((yo == null) || (xo == null)) { 527 continue; 528 } 529 n++; 530 double y = ((Double) yo).doubleValue(); 531 double x = ((Double) xo).doubleValue(); 532 533 debug("LinReg.linearReg", " " + i + " (" + x + "," + y + ")"); 534 sumX += x; 535 sumY += y; 536 sumXX += x * x; 537 sumXY += x * y; 538 } 539 540 double xMean = sumX / n; 541 double yMean = sumY / n; 542 543 debug("LinReg.linearReg", "yMean=" + yMean); 544 debug( 545 "LinReg.linearReg", 546 "(n*sumXX - sumX*sumX)=" + (n * sumXX - sumX * sumX)); 547 // The regression line is the line that minimizes the variance of the 548 // errors. The mean error is zero; so, this means that it minimizes the 549 // sum of the squares errors. 550 double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX); 551 double intercept = yMean - slope * xMean; 552 553 LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist); 554 debug("LinReg.linearReg", "value=" + value); 555 556 return value; 557 } 558 559 560 public static List forecast(LinReg.Value value) { 561 List yfs = new ArrayList(value.xs.size()); 562 563 Iterator it = value.xs.iterator(); 564 while (it.hasNext()) { 565 Double d = (Double) it.next(); 566 // If the value is missing we still must put a place 567 // holder in the y axis, otherwise there is a discontinuity 568 // between the data and the fit. 569 if (d == null) { 570 yfs.add(null); 571 } else { 572 double x = d.doubleValue(); 573 double yf = value.intercept + value.slope * x; 574 yfs.add(new Double(yf)); 575 } 576 } 577 578 return yfs; 579 } 580 581 private static class LinRegCalc extends AbstractDoubleCalc { 582 private final ListCalc listCalc; 583 private final DoubleCalc yCalc; 584 private final DoubleCalc xCalc; 585 private final int regType; 586 587 public LinRegCalc( 588 ResolvedFunCall call, 589 ListCalc listCalc, 590 DoubleCalc yCalc, 591 DoubleCalc xCalc, 592 int regType) 593 { 594 super(call, new Calc[]{listCalc, yCalc, xCalc}); 595 this.listCalc = listCalc; 596 this.yCalc = yCalc; 597 this.xCalc = xCalc; 598 this.regType = regType; 599 } 600 601 public double evaluateDouble(Evaluator evaluator) { 602 Value value = process(evaluator, listCalc, yCalc, xCalc); 603 if (value == null) { 604 return FunUtil.DoubleNull; 605 } 606 switch (regType) { 607 case Intercept: 608 return value.getIntercept(); 609 case Slope: 610 return value.getSlope(); 611 case Variance: 612 return value.getVariance(); 613 case R2: 614 return value.getRSquared(); 615 default: 616 case Point: 617 throw Util.newInternal("unexpected value " + regType); 618 } 619 } 620 } 621} 622 623// End LinReg.java