01062a92c72c296955f806cf7c3f7b86341534fd
[hive.git] / ql / src / gen / vectorization / UDAFTemplates / VectorUDAFVarDecimal.txt
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18
19 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
20
21 import java.util.ArrayList;
22 import java.util.List;
23
24 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
25 import org.apache.hadoop.hive.common.type.HiveDecimal;
26 import org.apache.hadoop.hive.ql.exec.Description;
27 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
28 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
29 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
30 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
31 import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
32 import org.apache.hadoop.hive.ql.metadata.HiveException;
33 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
34 import org.apache.hadoop.hive.ql.util.JavaDataModel;
35 import org.apache.hadoop.io.LongWritable;
36 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
37 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
38 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
39 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
40
41 /**
42 * <ClassName>. Vectorized implementation for VARIANCE aggregates.
43 */
44 @Description(name = "<DescriptionName>",
45     value = "<DescriptionValue>")
46 public class <ClassName> extends VectorAggregateExpression {
47
48     private static final long serialVersionUID = 1L;
49
50     /**
51     /* class for storing the current aggregate value.
52     */
53     private static final class Aggregation implements AggregationBuffer {
54
55       private static final long serialVersionUID = 1L;
56
57       transient private double sum;
58       transient private long count;
59       transient private double variance;
60
61       /**
62       * Value is explicitly (re)initialized in reset()
63       */
64       transient private boolean isNull = true;
65
66       public Aggregation() {
67       }
68
69       public void init() {
70         isNull = false;
71         sum = 0f;
72         count = 0;
73         variance = 0f;
74       }
75
76       @Override
77       public int getVariableSize() {
78         throw new UnsupportedOperationException();
79       }
80
81       @Override
82       public void reset () {
83         isNull = true;
84         sum = 0f;
85         count = 0;
86         variance = 0f;
87       }
88
89       public void updateValueWithCheckAndInit(HiveDecimalWritable value, short scale) {
90         if (this.isNull) {
91           this.init();
92         }
93
94         double dval = value.getHiveDecimal().doubleValue();
95         this.sum += dval;
96         this.count += 1;
97         if(this.count > 1) {
98            double t = this.count*dval - this.sum;
99            this.variance += (t*t) / ((double)this.count*(this.count-1));
100         }
101       }
102
103       public void updateValueNoCheck(HiveDecimalWritable value, short scale) {
104         double dval = value.getHiveDecimal().doubleValue();
105         this.sum += dval;
106         this.count += 1;
107         double t = this.count*dval - this.sum;
108         this.variance += (t*t) / ((double)this.count*(this.count-1));
109       }
110
111     }
112
    // Expression that produces the decimal input column for this aggregate.
    private VectorExpression inputExpression;

    @Override
    public VectorExpression inputExpression() {
      return inputExpression;
    }

    // Reusable writables forming the (count, sum, variance) partial result
    // struct returned by evaluateOutput(); allocated once in the constructor.
    transient private LongWritable resultCount;
    transient private DoubleWritable resultSum;
    transient private DoubleWritable resultVariance;
    transient private Object[] partialResult;

    // Struct object inspector describing {count, sum, variance}.
    transient private ObjectInspector soi;
126
127     public <ClassName>(VectorExpression inputExpression) {
128       this();
129       this.inputExpression = inputExpression;
130     }
131
132     public <ClassName>() {
133       super();
134       partialResult = new Object[3];
135       resultCount = new LongWritable();
136       resultSum = new DoubleWritable();
137       resultVariance = new DoubleWritable();
138       partialResult[0] = resultCount;
139       partialResult[1] = resultSum;
140       partialResult[2] = resultVariance;
141       initPartialResultInspector();
142     }
143
144   private void initPartialResultInspector() {
145         List<ObjectInspector> foi = new ArrayList<ObjectInspector>();
146         foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
147         foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
148         foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
149
150         List<String> fname = new ArrayList<String>();
151         fname.add("count");
152         fname.add("sum");
153         fname.add("variance");
154
155         soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
156     }
157
158     private Aggregation getCurrentAggregationBuffer(
159         VectorAggregationBufferRow[] aggregationBufferSets,
160         int aggregateIndex,
161         int row) {
162       VectorAggregationBufferRow mySet = aggregationBufferSets[row];
163       Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex);
164       return myagg;
165     }
166
167
    /**
     * Aggregates a batch where each row may map to a different aggregation
     * buffer (GROUP BY mode). Dispatches to a specialized loop based on the
     * input column's repeating/selection/null layout.
     */
    @Override
    public void aggregateInputSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregateIndex,
      VectorizedRowBatch batch) throws HiveException {

      inputExpression.evaluate(batch);

      DecimalColumnVector inputVector = (DecimalColumnVector)batch.
        cols[this.inputExpression.getOutputColumn()];

      int batchSize = batch.size;

      if (batchSize == 0) {
        return;
      }

      HiveDecimalWritable[] vector = inputVector.vector;

      if (inputVector.isRepeating) {
        // A repeating value applies once per row; a repeating null adds nothing.
        if (inputVector.noNulls || !inputVector.isNull[0]) {
          iterateRepeatingNoNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex, vector[0], inputVector.scale, batchSize);
        }
      }
      else if (!batch.selectedInUse && inputVector.noNulls) {
        iterateNoSelectionNoNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex, vector, inputVector.scale, batchSize);
      }
      else if (!batch.selectedInUse) {
        iterateNoSelectionHasNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex, vector, inputVector.scale,
            batchSize, inputVector.isNull);
      }
      else if (inputVector.noNulls){
        iterateSelectionNoNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex, vector, inputVector.scale,
            batchSize, batch.selected);
      }
      else {
        iterateSelectionHasNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex, vector, inputVector.scale, batchSize,
            inputVector.isNull, batch.selected);
      }

    }
214
215     private void  iterateRepeatingNoNullsWithAggregationSelection(
216         VectorAggregationBufferRow[] aggregationBufferSets,
217         int aggregateIndex,
218         HiveDecimalWritable value,
219         short scale,
220         int batchSize) {
221
222       for (int i=0; i<batchSize; ++i) {
223         Aggregation myagg = getCurrentAggregationBuffer(
224           aggregationBufferSets,
225           aggregateIndex,
226           i);
227         myagg.updateValueWithCheckAndInit(value, scale);
228       }
229     }
230
231     private void iterateSelectionHasNullsWithAggregationSelection(
232         VectorAggregationBufferRow[] aggregationBufferSets,
233         int aggregateIndex,
234         HiveDecimalWritable[] vector,
235         short scale,
236         int batchSize,
237         boolean[] isNull,
238         int[] selected) {
239
240       for (int j=0; j< batchSize; ++j) {
241         Aggregation myagg = getCurrentAggregationBuffer(
242           aggregationBufferSets,
243           aggregateIndex,
244           j);
245         int i = selected[j];
246         if (!isNull[i]) {
247           HiveDecimalWritable value = vector[i];
248           myagg.updateValueWithCheckAndInit(value, scale);
249         }
250       }
251     }
252
253     private void iterateSelectionNoNullsWithAggregationSelection(
254         VectorAggregationBufferRow[] aggregationBufferSets,
255         int aggregateIndex,
256         HiveDecimalWritable[] vector,
257         short scale,
258         int batchSize,
259         int[] selected) {
260
261       for (int i=0; i< batchSize; ++i) {
262         Aggregation myagg = getCurrentAggregationBuffer(
263           aggregationBufferSets,
264           aggregateIndex,
265           i);
266         HiveDecimalWritable value = vector[selected[i]];
267         myagg.updateValueWithCheckAndInit(value, scale);
268       }
269     }
270
271     private void iterateNoSelectionHasNullsWithAggregationSelection(
272         VectorAggregationBufferRow[] aggregationBufferSets,
273         int aggregateIndex,
274         HiveDecimalWritable[] vector,
275         short scale,
276         int batchSize,
277         boolean[] isNull) {
278
279       for(int i=0;i<batchSize;++i) {
280         if (!isNull[i]) {
281           Aggregation myagg = getCurrentAggregationBuffer(
282             aggregationBufferSets,
283             aggregateIndex,
284           i);
285           HiveDecimalWritable value = vector[i];
286           myagg.updateValueWithCheckAndInit(value, scale);
287         }
288       }
289     }
290
291     private void iterateNoSelectionNoNullsWithAggregationSelection(
292         VectorAggregationBufferRow[] aggregationBufferSets,
293         int aggregateIndex,
294         HiveDecimalWritable[] vector,
295         short scale,
296         int batchSize) {
297
298       for (int i=0; i<batchSize; ++i) {
299         Aggregation myagg = getCurrentAggregationBuffer(
300           aggregationBufferSets,
301           aggregateIndex,
302           i);
303         HiveDecimalWritable value = vector[i];
304         myagg.updateValueWithCheckAndInit(value, scale);
305       }
306     }
307
308     @Override
309     public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
310     throws HiveException {
311
312       inputExpression.evaluate(batch);
313
314       DecimalColumnVector inputVector = (DecimalColumnVector)batch.
315         cols[this.inputExpression.getOutputColumn()];
316
317       int batchSize = batch.size;
318
319       if (batchSize == 0) {
320         return;
321       }
322
323       Aggregation myagg = (Aggregation)agg;
324
325       HiveDecimalWritable[] vector = inputVector.vector;
326
327       if (inputVector.isRepeating) {
328         if (inputVector.noNulls) {
329           iterateRepeatingNoNulls(myagg, vector[0], inputVector.scale, batchSize);
330         }
331       }
332       else if (!batch.selectedInUse && inputVector.noNulls) {
333         iterateNoSelectionNoNulls(myagg, vector, inputVector.scale, batchSize);
334       }
335       else if (!batch.selectedInUse) {
336         iterateNoSelectionHasNulls(myagg, vector, inputVector.scale, batchSize, inputVector.isNull);
337       }
338       else if (inputVector.noNulls){
339         iterateSelectionNoNulls(myagg, vector, inputVector.scale, batchSize, batch.selected);
340       }
341       else {
342         iterateSelectionHasNulls(myagg, vector, inputVector.scale, 
343           batchSize, inputVector.isNull, batch.selected);
344       }
345     }
346
347     private void  iterateRepeatingNoNulls(
348         Aggregation myagg,
349         HiveDecimalWritable value,
350         short scale,
351         int batchSize) {
352
353       // TODO: conjure a formula w/o iterating
354       //
355
356       myagg.updateValueWithCheckAndInit(value, scale);
357
358       // We pulled out i=0 so we can remove the count > 1 check in the loop
359       for (int i=1; i<batchSize; ++i) {
360         myagg.updateValueNoCheck(value, scale);
361       }
362     }
363
364     private void iterateSelectionHasNulls(
365         Aggregation myagg,
366         HiveDecimalWritable[] vector,
367         short scale,
368         int batchSize,
369         boolean[] isNull,
370         int[] selected) {
371
372       for (int j=0; j< batchSize; ++j) {
373         int i = selected[j];
374         if (!isNull[i]) {
375           HiveDecimalWritable value = vector[i];
376           myagg.updateValueWithCheckAndInit(value, scale);
377         }
378       }
379     }
380
381     private void iterateSelectionNoNulls(
382         Aggregation myagg,
383         HiveDecimalWritable[] vector,
384         short scale,
385         int batchSize,
386         int[] selected) {
387
388       if (myagg.isNull) {
389         myagg.init ();
390       }
391
392       HiveDecimalWritable value = vector[selected[0]];
393       myagg.updateValueWithCheckAndInit(value, scale);
394
395       // i=0 was pulled out to remove the count > 1 check in the loop
396       //
397       for (int i=1; i< batchSize; ++i) {
398         value = vector[selected[i]];
399         myagg.updateValueNoCheck(value, scale);
400       }
401     }
402
403     private void iterateNoSelectionHasNulls(
404         Aggregation myagg,
405         HiveDecimalWritable[] vector,
406         short scale,
407         int batchSize,
408         boolean[] isNull) {
409
410       for(int i=0;i<batchSize;++i) {
411         if (!isNull[i]) {
412           HiveDecimalWritable value = vector[i];
413           myagg.updateValueWithCheckAndInit(value, scale);
414         }
415       }
416     }
417
418     private void iterateNoSelectionNoNulls(
419         Aggregation myagg,
420         HiveDecimalWritable[] vector,
421         short scale,
422         int batchSize) {
423
424       if (myagg.isNull) {
425         myagg.init ();
426       }
427
428       HiveDecimalWritable value = vector[0];
429       myagg.updateValueWithCheckAndInit(value, scale);
430
431       // i=0 was pulled out to remove count > 1 check
432       for (int i=1; i<batchSize; ++i) {
433         value = vector[i];
434         myagg.updateValueNoCheck(value, scale);
435       }
436     }
437
    /** Creates a fresh (null-valued) aggregation buffer. */
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      return new Aggregation();
    }
442
443     @Override
444     public void reset(AggregationBuffer agg) throws HiveException {
445       Aggregation myAgg = (Aggregation) agg;
446       myAgg.reset();
447     }
448
449     @Override
450     public Object evaluateOutput(
451         AggregationBuffer agg) throws HiveException {
452       Aggregation myagg = (Aggregation) agg;
453       if (myagg.isNull) {
454         return null;
455       }
456       else {
457         assert(0 < myagg.count);
458         resultCount.set(myagg.count);
459         resultSum.set(myagg.sum);
460         resultVariance.set(myagg.variance);
461         return partialResult;
462       }
463     }
  // Returns the struct inspector for the (count, sum, variance) partial result.
  @Override
    public ObjectInspector getOutputObjectInspector() {
      return soi;
    }
468
  @Override
  public int getAggregationBufferFixedSize() {
      // Fixed footprint of one Aggregation buffer: object header plus
      // three 8-byte fields (sum, count, variance) plus one small field
      // (isNull), rounded up to the JVM's alignment boundary.
      JavaDataModel model = JavaDataModel.get();
      return JavaDataModel.alignUp(
        model.object() +
        model.primitive2()*3+
        model.primitive1(),
        model.memoryAlign());
  }
478
  /** No configuration is read from the AggregationDesc for this aggregate. */
  @Override
  public void init(AggregationDesc desc) throws HiveException {
    // No-op
  }
483
  /** Returns the expression producing this aggregate's input column. */
  public VectorExpression getInputExpression() {
    return inputExpression;
  }

  /** Sets the expression producing this aggregate's input column. */
  public void setInputExpression(VectorExpression inputExpression) {
    this.inputExpression = inputExpression;
  }
491 }