cc7e54dfd754cf611bace0c654855071edb84d2a
[hive.git] / ql / src / gen / vectorization / UDAFTemplates / VectorUDAFSum.txt
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18
19 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
20
21 import org.apache.hadoop.hive.ql.exec.Description;
22 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
23 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
24 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
25 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
26 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
27 import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
28 import org.apache.hadoop.hive.ql.metadata.HiveException;
29 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
30 import org.apache.hadoop.hive.ql.util.JavaDataModel;
31 import org.apache.hadoop.io.LongWritable;
32 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
33 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
34 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
35
36 /**
37 * <ClassName>. Vectorized implementation for SUM aggregates. 
38 */
39 @Description(name = "sum", 
40     value = "_FUNC_(expr) - Returns the sum value of expr (vectorized, type: <ValueType>)")
41 public class <ClassName> extends VectorAggregateExpression {
42    
43     private static final long serialVersionUID = 1L;
44     
45     /** 
46      * class for storing the current aggregate value.
47      */
48     private static final class Aggregation implements AggregationBuffer {
49
50       private static final long serialVersionUID = 1L;
51
52       transient private <ValueType> sum;
53
54       /**
55       * Value is explicitly (re)initialized in reset()
56       */
57       transient private boolean isNull = true;
58       
59       public void sumValue(<ValueType> value) {
60         if (isNull) {
61           sum = value;
62           isNull = false;
63         } else {
64           sum += value;
65         }
66       }
67
68       @Override
69       public int getVariableSize() {
70         throw new UnsupportedOperationException();
71       }
72
73       @Override
74       public void reset () {
75         isNull = true;
76         sum = 0;;
77       }
78     }
79     
80     private VectorExpression inputExpression;
81     transient private final <OutputType> result;
82     
83     public <ClassName>(VectorExpression inputExpression) {
84       this();
85       this.inputExpression = inputExpression;
86     }
87
88     public <ClassName>() {
89       super();
90       result = new <OutputType>();
91     }
92
93     private Aggregation getCurrentAggregationBuffer(
94         VectorAggregationBufferRow[] aggregationBufferSets,
95         int aggregateIndex,
96         int row) {
97       VectorAggregationBufferRow mySet = aggregationBufferSets[row];
98       Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex);
99       return myagg;
100     }
101     
102     @Override
103     public void aggregateInputSelection(
104       VectorAggregationBufferRow[] aggregationBufferSets,
105       int aggregateIndex, 
106       VectorizedRowBatch batch) throws HiveException {
107       
108       int batchSize = batch.size;
109       
110       if (batchSize == 0) {
111         return;
112       }
113       
114       inputExpression.evaluate(batch);
115       
116       <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.
117         cols[this.inputExpression.getOutputColumn()];
118       <ValueType>[] vector = inputVector.vector;
119
120       if (inputVector.noNulls) {
121         if (inputVector.isRepeating) {
122           iterateNoNullsRepeatingWithAggregationSelection(
123             aggregationBufferSets, aggregateIndex,
124             vector[0], batchSize);
125         } else {
126           if (batch.selectedInUse) {
127             iterateNoNullsSelectionWithAggregationSelection(
128               aggregationBufferSets, aggregateIndex,
129               vector, batch.selected, batchSize);
130           } else {
131             iterateNoNullsWithAggregationSelection(
132               aggregationBufferSets, aggregateIndex,
133               vector, batchSize);
134           }
135         }
136       } else {
137         if (inputVector.isRepeating) {
138           if (batch.selectedInUse) {
139             iterateHasNullsRepeatingSelectionWithAggregationSelection(
140               aggregationBufferSets, aggregateIndex,
141               vector[0], batchSize, batch.selected, inputVector.isNull);
142           } else {
143             iterateHasNullsRepeatingWithAggregationSelection(
144               aggregationBufferSets, aggregateIndex,
145               vector[0], batchSize, inputVector.isNull);
146           }
147         } else {
148           if (batch.selectedInUse) {
149             iterateHasNullsSelectionWithAggregationSelection(
150               aggregationBufferSets, aggregateIndex,
151               vector, batchSize, batch.selected, inputVector.isNull);
152           } else {
153             iterateHasNullsWithAggregationSelection(
154               aggregationBufferSets, aggregateIndex,
155               vector, batchSize, inputVector.isNull);
156           }
157         }
158       }
159     }
160
161     private void iterateNoNullsRepeatingWithAggregationSelection(
162       VectorAggregationBufferRow[] aggregationBufferSets,
163       int aggregateIndex,
164       <ValueType> value,
165       int batchSize) {
166
167       for (int i=0; i < batchSize; ++i) {
168         Aggregation myagg = getCurrentAggregationBuffer(
169           aggregationBufferSets, 
170           aggregateIndex,
171           i);
172         myagg.sumValue(value);
173       }
174     } 
175
176     private void iterateNoNullsSelectionWithAggregationSelection(
177       VectorAggregationBufferRow[] aggregationBufferSets,
178       int aggregateIndex,
179       <ValueType>[] values,
180       int[] selection,
181       int batchSize) {
182       
183       for (int i=0; i < batchSize; ++i) {
184         Aggregation myagg = getCurrentAggregationBuffer(
185           aggregationBufferSets, 
186           aggregateIndex,
187           i);
188         myagg.sumValue(values[selection[i]]);
189       }
190     }
191
192     private void iterateNoNullsWithAggregationSelection(
193       VectorAggregationBufferRow[] aggregationBufferSets,
194       int aggregateIndex,
195       <ValueType>[] values,
196       int batchSize) {
197       for (int i=0; i < batchSize; ++i) {
198         Aggregation myagg = getCurrentAggregationBuffer(
199           aggregationBufferSets, 
200           aggregateIndex,
201           i);
202         myagg.sumValue(values[i]);
203       }
204     }
205
206     private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
207       VectorAggregationBufferRow[] aggregationBufferSets,
208       int aggregateIndex,
209       <ValueType> value,
210       int batchSize,
211       int[] selection,
212       boolean[] isNull) {
213
214       if (isNull[0]) {
215         return;
216       }
217
218       for (int i=0; i < batchSize; ++i) {
219         Aggregation myagg = getCurrentAggregationBuffer(
220           aggregationBufferSets,
221           aggregateIndex,
222           i);
223         myagg.sumValue(value);
224       }
225       
226     }
227
228     private void iterateHasNullsRepeatingWithAggregationSelection(
229       VectorAggregationBufferRow[] aggregationBufferSets,
230       int aggregateIndex,
231       <ValueType> value,
232       int batchSize,
233       boolean[] isNull) {
234
235       if (isNull[0]) {
236         return;
237       }
238
239       for (int i=0; i < batchSize; ++i) {
240         Aggregation myagg = getCurrentAggregationBuffer(
241           aggregationBufferSets,
242           aggregateIndex,
243           i);
244         myagg.sumValue(value);
245       }
246     }
247
248     private void iterateHasNullsSelectionWithAggregationSelection(
249       VectorAggregationBufferRow[] aggregationBufferSets,
250       int aggregateIndex,
251       <ValueType>[] values,
252       int batchSize,
253       int[] selection,
254       boolean[] isNull) {
255
256       for (int j=0; j < batchSize; ++j) {
257         int i = selection[j];
258         if (!isNull[i]) {
259           Aggregation myagg = getCurrentAggregationBuffer(
260             aggregationBufferSets, 
261             aggregateIndex,
262             j);
263           myagg.sumValue(values[i]);
264         }
265       }
266    }
267
268     private void iterateHasNullsWithAggregationSelection(
269       VectorAggregationBufferRow[] aggregationBufferSets,
270       int aggregateIndex,
271       <ValueType>[] values,
272       int batchSize,
273       boolean[] isNull) {
274
275       for (int i=0; i < batchSize; ++i) {
276         if (!isNull[i]) {
277           Aggregation myagg = getCurrentAggregationBuffer(
278             aggregationBufferSets, 
279             aggregateIndex,
280             i);
281           myagg.sumValue(values[i]);
282         }
283       }
284    }
285     
286     
287     @Override
288     public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) 
289     throws HiveException {
290       
291       inputExpression.evaluate(batch);
292       
293       <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch.
294           cols[this.inputExpression.getOutputColumn()];
295       
296       int batchSize = batch.size;
297       
298       if (batchSize == 0) {
299         return;
300       }
301       
302       Aggregation myagg = (Aggregation)agg;
303
304       <ValueType>[] vector = inputVector.vector;
305       
306       if (inputVector.isRepeating) {
307         if (inputVector.noNulls) {
308         if (myagg.isNull) {
309           myagg.isNull = false;
310           myagg.sum = 0;
311         }
312         myagg.sum += vector[0]*batchSize;
313       }
314         return;
315       }
316       
317       if (!batch.selectedInUse && inputVector.noNulls) {
318         iterateNoSelectionNoNulls(myagg, vector, batchSize);
319       }
320       else if (!batch.selectedInUse) {
321         iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
322       }
323       else if (inputVector.noNulls){
324         iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
325       }
326       else {
327         iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
328       }
329     }
330   
331     private void iterateSelectionHasNulls(
332         Aggregation myagg, 
333         <ValueType>[] vector, 
334         int batchSize,
335         boolean[] isNull, 
336         int[] selected) {
337       
338       for (int j=0; j< batchSize; ++j) {
339         int i = selected[j];
340         if (!isNull[i]) {
341           <ValueType> value = vector[i];
342           if (myagg.isNull) {
343             myagg.isNull = false;
344             myagg.sum = 0;
345           }
346           myagg.sum += value;
347         }
348       }
349     }
350
351     private void iterateSelectionNoNulls(
352         Aggregation myagg, 
353         <ValueType>[] vector, 
354         int batchSize, 
355         int[] selected) {
356       
357       if (myagg.isNull) {
358         myagg.sum = 0;
359         myagg.isNull = false;
360       }
361       
362       for (int i=0; i< batchSize; ++i) {
363         <ValueType> value = vector[selected[i]];
364         myagg.sum += value;
365       }
366     }
367
368     private void iterateNoSelectionHasNulls(
369         Aggregation myagg, 
370         <ValueType>[] vector, 
371         int batchSize,
372         boolean[] isNull) {
373       
374       for(int i=0;i<batchSize;++i) {
375         if (!isNull[i]) {
376           <ValueType> value = vector[i];
377           if (myagg.isNull) { 
378             myagg.sum = 0;
379             myagg.isNull = false;
380           }
381           myagg.sum += value;
382         }
383       }
384     }
385
386     private void iterateNoSelectionNoNulls(
387         Aggregation myagg, 
388         <ValueType>[] vector, 
389         int batchSize) {
390       if (myagg.isNull) {
391         myagg.sum = 0;
392         myagg.isNull = false;
393       }
394       
395       for (int i=0;i<batchSize;++i) {
396         <ValueType> value = vector[i];
397         myagg.sum += value;
398       }
399     }
400
401     @Override
402     public AggregationBuffer getNewAggregationBuffer() throws HiveException {
403       return new Aggregation();
404     }
405
406     @Override
407     public void reset(AggregationBuffer agg) throws HiveException {
408       Aggregation myAgg = (Aggregation) agg;
409       myAgg.reset();
410     }
411
412     @Override
413     public Object evaluateOutput(AggregationBuffer agg) throws HiveException {
414       Aggregation myagg = (Aggregation) agg;
415       if (myagg.isNull) {
416         return null;
417       }
418       else {
419         result.set(myagg.sum);
420         return result;
421       }
422     }
423     
424     @Override
425     public ObjectInspector getOutputObjectInspector() {
426       return <OutputTypeInspector>;
427     }
428
429   @Override
430   public int getAggregationBufferFixedSize() {
431       JavaDataModel model = JavaDataModel.get();
432       return JavaDataModel.alignUp(
433         model.object(),
434         model.memoryAlign());
435   }
436
437   @Override
438   public void init(AggregationDesc desc) throws HiveException {
439     // No-op
440   }
441
442   public VectorExpression getInputExpression() {
443     return inputExpression;
444   }
445
446   public void setInputExpression(VectorExpression inputExpression) {
447     this.inputExpression = inputExpression;
448   }
449 }
450