c6c9c5273f2de73111b4fad7d80fc507d66a7f69
[hive.git] / ql / src / gen / vectorization / UDAFTemplates / VectorUDAFVar.txt
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18
19 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
20
21 import java.util.ArrayList;
22 import java.util.List;
23
24 import org.apache.hadoop.hive.ql.exec.Description;
25 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
26 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
27 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
28 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
29 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
30 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
31 import org.apache.hadoop.hive.ql.metadata.HiveException;
32 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
33 import org.apache.hadoop.hive.ql.util.JavaDataModel;
34 import org.apache.hadoop.io.LongWritable;
35 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
36 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
37 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
38 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
39
40 /**
41 * <ClassName>. Vectorized implementation for VARIANCE aggregates.
42 */
43 @Description(name = "<DescriptionName>",
44     value = "<DescriptionValue>")
45 public class <ClassName> extends VectorAggregateExpression {
46
47     private static final long serialVersionUID = 1L;
48
49     /**
50     /* class for storing the current aggregate value.
51     */
52     private static final class Aggregation implements AggregationBuffer {
53
54       private static final long serialVersionUID = 1L;
55
56       transient private double sum;
57       transient private long count;
58       transient private double variance;
59
60       /**
61       * Value is explicitly (re)initialized in reset() (despite the init() bellow...)
62       */
63       transient private boolean isNull = true;
64
65       public void init() {
66         isNull = false;
67         sum = 0;
68         count = 0;
69         variance = 0;
70       }
71
72       @Override
73       public int getVariableSize() {
74         throw new UnsupportedOperationException();
75       }
76
77       @Override
78       public void reset () {
79         isNull = true;
80         sum = 0;
81         count = 0;
82         variance = 0;
83       }
84     }
85
86     private VectorExpression inputExpression;
87     transient private LongWritable resultCount;
88     transient private DoubleWritable resultSum;
89     transient private DoubleWritable resultVariance;
90     transient private Object[] partialResult;
91
92     transient private ObjectInspector soi;
93
94
95     public <ClassName>(VectorExpression inputExpression) {
96       this();
97       this.inputExpression = inputExpression;
98     }
99
100     public <ClassName>() {
101       super();
102       partialResult = new Object[3];
103       resultCount = new LongWritable();
104       resultSum = new DoubleWritable();
105       resultVariance = new DoubleWritable();
106       partialResult[0] = resultCount;
107       partialResult[1] = resultSum;
108       partialResult[2] = resultVariance;
109       initPartialResultInspector();
110     }
111
112   private void initPartialResultInspector() {
113         List<ObjectInspector> foi = new ArrayList<ObjectInspector>();
114         foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
115         foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
116         foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
117
118         List<String> fname = new ArrayList<String>();
119         fname.add("count");
120         fname.add("sum");
121         fname.add("variance");
122
123         soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
124     }
125
126     private Aggregation getCurrentAggregationBuffer(
127         VectorAggregationBufferRow[] aggregationBufferSets,
128         int aggregateIndex,
129         int row) {
130       VectorAggregationBufferRow mySet = aggregationBufferSets[row];
131       Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex);
132       return myagg;
133     }
134
135
136     @Override
137     public void aggregateInputSelection(
138       VectorAggregationBufferRow[] aggregationBufferSets,
139       int aggregateIndex,
140       VectorizedRowBatch batch) throws HiveException {
141
142       inputExpression.evaluate(batch);
143
144       <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.
145         cols[this.inputExpression.getOutputColumn()];
146
147       int batchSize = batch.size;
148
149       if (batchSize == 0) {
150         return;
151       }
152
153       <ValueType>[] vector = inputVector.vector;
154
155       if (inputVector.isRepeating) {
156         if (inputVector.noNulls || !inputVector.isNull[0]) {
157           iterateRepeatingNoNullsWithAggregationSelection(
158             aggregationBufferSets, aggregateIndex, vector[0], batchSize);
159         }
160       }
161       else if (!batch.selectedInUse && inputVector.noNulls) {
162         iterateNoSelectionNoNullsWithAggregationSelection(
163             aggregationBufferSets, aggregateIndex, vector, batchSize);
164       }
165       else if (!batch.selectedInUse) {
166         iterateNoSelectionHasNullsWithAggregationSelection(
167             aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull);
168       }
169       else if (inputVector.noNulls){
170         iterateSelectionNoNullsWithAggregationSelection(
171             aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected);
172       }
173       else {
174         iterateSelectionHasNullsWithAggregationSelection(
175             aggregationBufferSets, aggregateIndex, vector, batchSize,
176             inputVector.isNull, batch.selected);
177       }
178
179     }
180
181     private void  iterateRepeatingNoNullsWithAggregationSelection(
182         VectorAggregationBufferRow[] aggregationBufferSets,
183         int aggregateIndex,
184         double value,
185         int batchSize) {
186
187       for (int i=0; i<batchSize; ++i) {
188         Aggregation myagg = getCurrentAggregationBuffer(
189           aggregationBufferSets,
190           aggregateIndex,
191           i);
192         if (myagg.isNull) {
193           myagg.init ();
194         }
195         myagg.sum += value;
196         myagg.count += 1;
197         if(myagg.count > 1) {
198           double t = myagg.count*value - myagg.sum;
199           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
200         }
201       }
202     }
203
204     private void iterateSelectionHasNullsWithAggregationSelection(
205         VectorAggregationBufferRow[] aggregationBufferSets,
206         int aggregateIndex,
207         <ValueType>[] vector,
208         int batchSize,
209         boolean[] isNull,
210         int[] selected) {
211
212       for (int j=0; j< batchSize; ++j) {
213         Aggregation myagg = getCurrentAggregationBuffer(
214           aggregationBufferSets,
215           aggregateIndex,
216           j);
217         int i = selected[j];
218         if (!isNull[i]) {
219           double value = vector[i];
220           if (myagg.isNull) {
221             myagg.init ();
222           }
223           myagg.sum += value;
224           myagg.count += 1;
225           if(myagg.count > 1) {
226             double t = myagg.count*value - myagg.sum;
227             myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
228           }
229         }
230       }
231     }
232
233     private void iterateSelectionNoNullsWithAggregationSelection(
234         VectorAggregationBufferRow[] aggregationBufferSets,
235         int aggregateIndex,
236         <ValueType>[] vector,
237         int batchSize,
238         int[] selected) {
239
240       for (int i=0; i< batchSize; ++i) {
241         Aggregation myagg = getCurrentAggregationBuffer(
242           aggregationBufferSets,
243           aggregateIndex,
244           i);
245         double value = vector[selected[i]];
246         if (myagg.isNull) {
247           myagg.init ();
248         }
249         myagg.sum += value;
250         myagg.count += 1;
251         if(myagg.count > 1) {
252           double t = myagg.count*value - myagg.sum;
253           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
254         }
255       }
256     }
257
258     private void iterateNoSelectionHasNullsWithAggregationSelection(
259         VectorAggregationBufferRow[] aggregationBufferSets,
260         int aggregateIndex,
261         <ValueType>[] vector,
262         int batchSize,
263         boolean[] isNull) {
264
265       for(int i=0;i<batchSize;++i) {
266         if (!isNull[i]) {
267           Aggregation myagg = getCurrentAggregationBuffer(
268             aggregationBufferSets,
269             aggregateIndex,
270           i);
271           double value = vector[i];
272           if (myagg.isNull) {
273             myagg.init ();
274           }
275           myagg.sum += value;
276           myagg.count += 1;
277         if(myagg.count > 1) {
278           double t = myagg.count*value - myagg.sum;
279           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
280         }
281         }
282       }
283     }
284
285     private void iterateNoSelectionNoNullsWithAggregationSelection(
286         VectorAggregationBufferRow[] aggregationBufferSets,
287         int aggregateIndex,
288         <ValueType>[] vector,
289         int batchSize) {
290
291       for (int i=0; i<batchSize; ++i) {
292         Aggregation myagg = getCurrentAggregationBuffer(
293           aggregationBufferSets,
294           aggregateIndex,
295           i);
296         if (myagg.isNull) {
297           myagg.init ();
298         }
299         double value = vector[i];
300         myagg.sum += value;
301         myagg.count += 1;
302         if(myagg.count > 1) {
303           double t = myagg.count*value - myagg.sum;
304           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
305         }
306       }
307     }
308
309     @Override
310     public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
311     throws HiveException {
312
313       inputExpression.evaluate(batch);
314
315       <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.
316         cols[this.inputExpression.getOutputColumn()];
317
318       int batchSize = batch.size;
319
320       if (batchSize == 0) {
321         return;
322       }
323
324       Aggregation myagg = (Aggregation)agg;
325
326       <ValueType>[] vector = inputVector.vector;
327
328       if (inputVector.isRepeating) {
329         if (inputVector.noNulls) {
330           iterateRepeatingNoNulls(myagg, vector[0], batchSize);
331         }
332       }
333       else if (!batch.selectedInUse && inputVector.noNulls) {
334         iterateNoSelectionNoNulls(myagg, vector, batchSize);
335       }
336       else if (!batch.selectedInUse) {
337         iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
338       }
339       else if (inputVector.noNulls){
340         iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
341       }
342       else {
343         iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
344       }
345     }
346
347     private void  iterateRepeatingNoNulls(
348         Aggregation myagg,
349         double value,
350         int batchSize) {
351
352       if (myagg.isNull) {
353         myagg.init ();
354       }
355
356       // TODO: conjure a formula w/o iterating
357       //
358
359       myagg.sum += value;
360       myagg.count += 1;
361       if(myagg.count > 1) {
362         double t = myagg.count*value - myagg.sum;
363         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
364       }
365
366       // We pulled out i=0 so we can remove the count > 1 check in the loop
367       for (int i=1; i<batchSize; ++i) {
368         myagg.sum += value;
369         myagg.count += 1;
370         double t = myagg.count*value - myagg.sum;
371         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
372       }
373     }
374
375     private void iterateSelectionHasNulls(
376         Aggregation myagg,
377         <ValueType>[] vector,
378         int batchSize,
379         boolean[] isNull,
380         int[] selected) {
381
382       for (int j=0; j< batchSize; ++j) {
383         int i = selected[j];
384         if (!isNull[i]) {
385           double value = vector[i];
386           if (myagg.isNull) {
387             myagg.init ();
388           }
389           myagg.sum += value;
390           myagg.count += 1;
391         if(myagg.count > 1) {
392           double t = myagg.count*value - myagg.sum;
393           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
394         }
395         }
396       }
397     }
398
399     private void iterateSelectionNoNulls(
400         Aggregation myagg,
401         <ValueType>[] vector,
402         int batchSize,
403         int[] selected) {
404
405       if (myagg.isNull) {
406         myagg.init ();
407       }
408
409       double value = vector[selected[0]];
410       myagg.sum += value;
411       myagg.count += 1;
412       if(myagg.count > 1) {
413         double t = myagg.count*value - myagg.sum;
414         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
415       }
416
417       // i=0 was pulled out to remove the count > 1 check in the loop
418       //
419       for (int i=1; i< batchSize; ++i) {
420         value = vector[selected[i]];
421         myagg.sum += value;
422         myagg.count += 1;
423         double t = myagg.count*value - myagg.sum;
424         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
425       }
426     }
427
428     private void iterateNoSelectionHasNulls(
429         Aggregation myagg,
430         <ValueType>[] vector,
431         int batchSize,
432         boolean[] isNull) {
433
434       for(int i=0;i<batchSize;++i) {
435         if (!isNull[i]) {
436           double value = vector[i];
437           if (myagg.isNull) {
438             myagg.init ();
439           }
440           myagg.sum += value;
441           myagg.count += 1;
442         if(myagg.count > 1) {
443           double t = myagg.count*value - myagg.sum;
444           myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
445         }
446         }
447       }
448     }
449
450     private void iterateNoSelectionNoNulls(
451         Aggregation myagg,
452         <ValueType>[] vector,
453         int batchSize) {
454
455       if (myagg.isNull) {
456         myagg.init ();
457       }
458
459       double value = vector[0];
460       myagg.sum += value;
461       myagg.count += 1;
462
463       if(myagg.count > 1) {
464         double t = myagg.count*value - myagg.sum;
465         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
466       }
467
468       // i=0 was pulled out to remove count > 1 check
469       for (int i=1; i<batchSize; ++i) {
470         value = vector[i];
471         myagg.sum += value;
472         myagg.count += 1;
473         double t = myagg.count*value - myagg.sum;
474         myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
475       }
476     }
477
    /**
     * Allocates a fresh buffer for this aggregate.
     *
     * @return a new Aggregation; its isNull flag starts out true
     */
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      return new Aggregation();
    }
482
483     @Override
484     public void reset(AggregationBuffer agg) throws HiveException {
485       Aggregation myAgg = (Aggregation) agg;
486       myAgg.reset();
487     }
488
489     @Override
490     public Object evaluateOutput(
491         AggregationBuffer agg) throws HiveException {
492       Aggregation myagg = (Aggregation) agg;
493       if (myagg.isNull) {
494         return null;
495       }
496       else {
497         assert(0 < myagg.count);
498         resultCount.set (myagg.count);
499         resultSum.set (myagg.sum);
500         resultVariance.set (myagg.variance);
501         return partialResult;
502       }
503     }
  /**
   * Returns the struct inspector built by initPartialResultInspector().
   *
   * @return inspector with the fields "count", "sum" and "variance"
   */
  @Override
    public ObjectInspector getOutputObjectInspector() {
      return soi;
    }
508
509   @Override
510   public int getAggregationBufferFixedSize() {
511       JavaDataModel model = JavaDataModel.get();
512       return JavaDataModel.alignUp(
513         model.object() +
514         model.primitive2()*3+
515         model.primitive1(),
516         model.memoryAlign());
517   }
518
  /**
   * Does nothing; the descriptor is not used by this aggregate.
   *
   * @param desc aggregation descriptor, ignored
   */
  @Override
  public void init(AggregationDesc desc) throws HiveException {
    // No-op
  }
523
  /**
   * Returns the expression that supplies the input values.
   *
   * @return the current input expression
   */
  public VectorExpression getInputExpression() {
    return inputExpression;
  }
527
  /**
   * Replaces the expression that supplies the input values.
   *
   * @param inputExpression the new input expression
   */
  public void setInputExpression(VectorExpression inputExpression) {
    this.inputExpression = inputExpression;
  }
531 }
532