4393c3b063fe68892570fe22194622e736a9f082
[hive.git] / ql / src / gen / vectorization / UDAFTemplates / VectorUDAFAvg.txt
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18
19 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
20
21 import java.util.ArrayList;
22 import java.util.List;
23
24 import org.apache.hadoop.hive.ql.exec.Description;
25 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
26 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
27 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
28 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
29 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
30 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
31 import org.apache.hadoop.hive.ql.metadata.HiveException;
32 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
33 import org.apache.hadoop.hive.ql.util.JavaDataModel;
34 import org.apache.hadoop.io.LongWritable;
35 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
36 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
37 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
38 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
39 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
40
41 /**
42  * Generated from template VectorUDAFAvg.txt.
43  */
44 @Description(name = "avg",
45     value = "_FUNC_(expr) - Returns the average value of expr (vectorized, type: <ValueType>)")
46 public class <ClassName> extends VectorAggregateExpression {
47
48     private static final long serialVersionUID = 1L;
49     
50     /** class for storing the current aggregate value. */
51     static class Aggregation implements AggregationBuffer {
52
53       private static final long serialVersionUID = 1L;
54
55       transient private double sum;
56       transient private long count;
57
58       /**
59       * Value is explicitly (re)initialized in reset()
60       */
61       transient private boolean isNull = true;
62       
63       public void sumValue(<ValueType> value) {
64         if (isNull) {
65           sum = value; 
66           count = 1;
67           isNull = false;
68         } else {
69           sum += value;
70           count++;
71         }
72       }
73
74       @Override
75       public int getVariableSize() {
76         throw new UnsupportedOperationException();
77       }
78       
79       @Override
80       public void reset () {
81         isNull = true;
82         sum = 0;
83         count = 0L;
84       }
85     }
86     
87     private VectorExpression inputExpression;
88
89     @Override
90     public VectorExpression inputExpression() {
91       return inputExpression;
92     }
93
94     transient private Object[] partialResult;
95     transient private LongWritable resultCount;
96     transient private DoubleWritable resultSum;
97     transient private StructObjectInspector soi;
98         
99     public <ClassName>(VectorExpression inputExpression) {
100       this();
101       this.inputExpression = inputExpression;
102     }
103
104     public <ClassName>() {
105       super();
106       partialResult = new Object[2];
107       resultCount = new LongWritable();
108       resultSum = new DoubleWritable();
109       partialResult[0] = resultCount;
110       partialResult[1] = resultSum;
111       initPartialResultInspector();
112     }
113
114     private void initPartialResultInspector() {
115         List<ObjectInspector> foi = new ArrayList<ObjectInspector>();
116         foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
117         foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
118         List<String> fname = new ArrayList<String>();
119         fname.add("count");
120         fname.add("sum");
121         soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
122     }
123     
124     private Aggregation getCurrentAggregationBuffer(
125         VectorAggregationBufferRow[] aggregationBufferSets,
126         int bufferIndex,
127         int row) {
128       VectorAggregationBufferRow mySet = aggregationBufferSets[row];
129       Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(bufferIndex);
130       return myagg;
131     }
132     
133     @Override
134     public void aggregateInputSelection(
135       VectorAggregationBufferRow[] aggregationBufferSets,
136       int bufferIndex, 
137       VectorizedRowBatch batch) throws HiveException {
138       
139       int batchSize = batch.size;
140       
141       if (batchSize == 0) {
142         return;
143       }
144       
145       inputExpression.evaluate(batch);
146       
147        <InputColumnVectorType> inputVector = ( <InputColumnVectorType>)batch.
148         cols[this.inputExpression.getOutputColumn()];
149       <ValueType>[] vector = inputVector.vector;
150
151       if (inputVector.noNulls) {
152         if (inputVector.isRepeating) {
153           iterateNoNullsRepeatingWithAggregationSelection(
154             aggregationBufferSets, bufferIndex,
155             vector[0], batchSize);
156         } else {
157           if (batch.selectedInUse) {
158             iterateNoNullsSelectionWithAggregationSelection(
159               aggregationBufferSets, bufferIndex,
160               vector, batch.selected, batchSize);
161           } else {
162             iterateNoNullsWithAggregationSelection(
163               aggregationBufferSets, bufferIndex,
164               vector, batchSize);
165           }
166         }
167       } else {
168         if (inputVector.isRepeating) {
169           if (batch.selectedInUse) {
170             iterateHasNullsRepeatingSelectionWithAggregationSelection(
171               aggregationBufferSets, bufferIndex,
172               vector[0], batchSize, batch.selected, inputVector.isNull);
173           } else {
174             iterateHasNullsRepeatingWithAggregationSelection(
175               aggregationBufferSets, bufferIndex,
176               vector[0], batchSize, inputVector.isNull);
177           }
178         } else {
179           if (batch.selectedInUse) {
180             iterateHasNullsSelectionWithAggregationSelection(
181               aggregationBufferSets, bufferIndex,
182               vector, batchSize, batch.selected, inputVector.isNull);
183           } else {
184             iterateHasNullsWithAggregationSelection(
185               aggregationBufferSets, bufferIndex,
186               vector, batchSize, inputVector.isNull);
187           }
188         }
189       }
190     }
191
192     private void iterateNoNullsRepeatingWithAggregationSelection(
193       VectorAggregationBufferRow[] aggregationBufferSets,
194       int bufferIndex,
195       <ValueType> value,
196       int batchSize) {
197
198       for (int i=0; i < batchSize; ++i) {
199         Aggregation myagg = getCurrentAggregationBuffer(
200           aggregationBufferSets, 
201           bufferIndex,
202           i);
203         myagg.sumValue(value);
204       }
205     } 
206
207     private void iterateNoNullsSelectionWithAggregationSelection(
208       VectorAggregationBufferRow[] aggregationBufferSets,
209       int bufferIndex,
210       <ValueType>[] values,
211       int[] selection,
212       int batchSize) {
213       
214       for (int i=0; i < batchSize; ++i) {
215         Aggregation myagg = getCurrentAggregationBuffer(
216           aggregationBufferSets, 
217           bufferIndex,
218           i);
219         myagg.sumValue(values[selection[i]]);
220       }
221     }
222
223     private void iterateNoNullsWithAggregationSelection(
224       VectorAggregationBufferRow[] aggregationBufferSets,
225       int bufferIndex,
226       <ValueType>[] values,
227       int batchSize) {
228       for (int i=0; i < batchSize; ++i) {
229         Aggregation myagg = getCurrentAggregationBuffer(
230           aggregationBufferSets, 
231           bufferIndex,
232           i);
233         myagg.sumValue(values[i]);
234       }
235     }
236
237     private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
238       VectorAggregationBufferRow[] aggregationBufferSets,
239       int bufferIndex,
240       <ValueType> value,
241       int batchSize,
242       int[] selection,
243       boolean[] isNull) {
244
245       if (isNull[0]) {
246         return;
247       }
248       
249       for (int i=0; i < batchSize; ++i) {
250         Aggregation myagg = getCurrentAggregationBuffer(
251           aggregationBufferSets,
252           bufferIndex,
253           i);
254         myagg.sumValue(value);
255       }
256       
257     }
258
259     private void iterateHasNullsRepeatingWithAggregationSelection(
260       VectorAggregationBufferRow[] aggregationBufferSets,
261       int bufferIndex,
262       <ValueType> value,
263       int batchSize,
264       boolean[] isNull) {
265
266       if (isNull[0]) {
267         return;
268       }
269
270       for (int i=0; i < batchSize; ++i) {
271         Aggregation myagg = getCurrentAggregationBuffer(
272           aggregationBufferSets,
273           bufferIndex,
274           i);
275         myagg.sumValue(value);
276       }
277     }
278
279     private void iterateHasNullsSelectionWithAggregationSelection(
280       VectorAggregationBufferRow[] aggregationBufferSets,
281       int bufferIndex,
282       <ValueType>[] values,
283       int batchSize,
284       int[] selection,
285       boolean[] isNull) {
286
287       for (int j=0; j < batchSize; ++j) {
288         int i = selection[j];
289         if (!isNull[i]) {
290           Aggregation myagg = getCurrentAggregationBuffer(
291             aggregationBufferSets, 
292             bufferIndex,
293             j);
294           myagg.sumValue(values[i]);
295         }
296       }
297    }
298
299     private void iterateHasNullsWithAggregationSelection(
300       VectorAggregationBufferRow[] aggregationBufferSets,
301       int bufferIndex,
302       <ValueType>[] values,
303       int batchSize,
304       boolean[] isNull) {
305
306       for (int i=0; i < batchSize; ++i) {
307         if (!isNull[i]) {
308           Aggregation myagg = getCurrentAggregationBuffer(
309             aggregationBufferSets, 
310             bufferIndex,
311             i);
312           myagg.sumValue(values[i]);
313         }
314       }
315    }
316
317     
318     @Override
319     public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) 
320         throws HiveException {
321         
322         inputExpression.evaluate(batch);
323         
324         <InputColumnVectorType> inputVector = 
325             (<InputColumnVectorType>)batch.cols[this.inputExpression.getOutputColumn()];
326         
327         int batchSize = batch.size;
328         
329         if (batchSize == 0) {
330           return;
331         }
332         
333         Aggregation myagg = (Aggregation)agg;
334   
335         <ValueType>[] vector = inputVector.vector;
336         
337         if (inputVector.isRepeating) {
338           if (inputVector.noNulls) {
339             if (myagg.isNull) {
340               myagg.isNull = false;
341               myagg.sum = 0;
342               myagg.count = 0;
343             }
344             myagg.sum += vector[0]*batchSize;
345             myagg.count += batchSize;
346           }
347           return;
348         }
349         
350         if (!batch.selectedInUse && inputVector.noNulls) {
351           iterateNoSelectionNoNulls(myagg, vector, batchSize);
352         }
353         else if (!batch.selectedInUse) {
354           iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
355         }
356         else if (inputVector.noNulls){
357           iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
358         }
359         else {
360           iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
361         }
362     }
363   
364     private void iterateSelectionHasNulls(
365         Aggregation myagg, 
366         <ValueType>[] vector, 
367         int batchSize,
368         boolean[] isNull, 
369         int[] selected) {
370       
371       for (int j=0; j< batchSize; ++j) {
372         int i = selected[j];
373         if (!isNull[i]) {
374           <ValueType> value = vector[i];
375           if (myagg.isNull) {
376             myagg.isNull = false;
377             myagg.sum = 0;
378             myagg.count = 0;
379           }
380           myagg.sum += value;
381           myagg.count += 1;
382         }
383       }
384     }
385
386     private void iterateSelectionNoNulls(
387         Aggregation myagg, 
388         <ValueType>[] vector, 
389         int batchSize, 
390         int[] selected) {
391       
392       if (myagg.isNull) {
393         myagg.isNull = false;
394         myagg.sum = 0;
395         myagg.count = 0;
396       }
397       
398       for (int i=0; i< batchSize; ++i) {
399         <ValueType> value = vector[selected[i]];
400         myagg.sum += value;
401         myagg.count += 1;
402       }
403     }
404
405     private void iterateNoSelectionHasNulls(
406         Aggregation myagg, 
407         <ValueType>[] vector, 
408         int batchSize,
409         boolean[] isNull) {
410       
411       for(int i=0;i<batchSize;++i) {
412         if (!isNull[i]) {
413           <ValueType> value = vector[i];
414           if (myagg.isNull) { 
415             myagg.isNull = false;
416             myagg.sum = 0;
417             myagg.count = 0;
418           }
419           myagg.sum += value;
420           myagg.count += 1;
421         }
422       }
423     }
424
425     private void iterateNoSelectionNoNulls(
426         Aggregation myagg, 
427         <ValueType>[] vector, 
428         int batchSize) {
429       if (myagg.isNull) {
430         myagg.isNull = false;
431         myagg.sum = 0;
432         myagg.count = 0;
433       }
434       
435       for (int i=0;i<batchSize;++i) {
436         <ValueType> value = vector[i];
437         myagg.sum += value;
438         myagg.count += 1;
439       }
440     }
441
442     @Override
443     public AggregationBuffer getNewAggregationBuffer() throws HiveException {
444       return new Aggregation();
445     }
446
447     @Override
448     public void reset(AggregationBuffer agg) throws HiveException {
449       Aggregation myAgg = (Aggregation) agg;
450       myAgg.reset();
451     }
452
453     @Override
454     public Object evaluateOutput(
455         AggregationBuffer agg) throws HiveException {
456       Aggregation myagg = (Aggregation) agg;
457       if (myagg.isNull) {
458         return null;
459       }
460       else {
461         assert(0 < myagg.count);
462         resultCount.set (myagg.count);
463         resultSum.set (myagg.sum);
464         return partialResult;
465       }
466     }
467     
468   @Override
469     public ObjectInspector getOutputObjectInspector() {
470     return soi;
471   }     
472
473   @Override
474   public int getAggregationBufferFixedSize() {
475     JavaDataModel model = JavaDataModel.get();
476     return JavaDataModel.alignUp(
477       model.object() +
478       model.primitive2() * 2,
479       model.memoryAlign());
480   }
481
482   @Override
483   public void init(AggregationDesc desc) throws HiveException {
484     // No-op
485   }
486   
487   public VectorExpression getInputExpression() {
488     return inputExpression;
489   }
490
491   public void setInputExpression(VectorExpression inputExpression) {
492     this.inputExpression = inputExpression;
493   }
494 }
495