9a48171a8e1c68d7b5289ed5f8cb87ed18de28b1
[hive.git] / ql / src / gen / vectorization / UDAFTemplates / VectorUDAFMinMaxDecimal.txt
1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18
19 package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
20
21 import org.apache.hadoop.hive.common.type.HiveDecimal;
22 import org.apache.hadoop.hive.ql.exec.Description;
23 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
24 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
25 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
26 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
27 import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
28 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
29 import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
30 import org.apache.hadoop.hive.ql.metadata.HiveException;
31 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
32 import org.apache.hadoop.hive.ql.util.JavaDataModel;
33 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
34 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
35
36 /**
37 * <ClassName>. Vectorized implementation for MIN/MAX aggregates.
38 */
39 @Description(name = "<DescriptionName>",
40     value = "<DescriptionValue>")
41 public class <ClassName> extends VectorAggregateExpression {
42
43     private static final long serialVersionUID = 1L;
44
45     /**
46      * class for storing the current aggregate value.
47      */
48     static private final class Aggregation implements AggregationBuffer {
49
50       private static final long serialVersionUID = 1L;
51
52       transient private final HiveDecimalWritable value;
53
54       /**
55       * Value is explicitly (re)initialized in reset()
56       */
57       transient private boolean isNull = true;
58
59       public Aggregation() {
60         value = new HiveDecimalWritable();
61       }
62
63       public void checkValue(HiveDecimalWritable writable, short scale) {
64         HiveDecimal value = writable.getHiveDecimal();
65         if (isNull) {
66           isNull = false;
67           this.value.set(value);
68         } else if (this.value.getHiveDecimal().compareTo(value) <OperatorSymbol> 0) {
69           this.value.set(value);
70         }
71       }
72
73       @Override
74       public int getVariableSize() {
75         throw new UnsupportedOperationException();
76       }
77
78       @Override
79       public void reset () {
80         isNull = true;
81         value.set(HiveDecimal.ZERO);
82       }
83     }
84
85     private VectorExpression inputExpression;
86     private transient VectorExpressionWriter resultWriter;
87
88     public <ClassName>(VectorExpression inputExpression) {
89       this();
90       this.inputExpression = inputExpression;
91     }
92
93     public <ClassName>() {
94       super();
95     }
96
97     @Override
98     public void init(AggregationDesc desc) throws HiveException {
99       resultWriter = VectorExpressionWriterFactory.genVectorExpressionWritable(
100           desc.getParameters().get(0));
101     }
102
103     private Aggregation getCurrentAggregationBuffer(
104         VectorAggregationBufferRow[] aggregationBufferSets,
105         int aggregrateIndex,
106         int row) {
107       VectorAggregationBufferRow mySet = aggregationBufferSets[row];
108       Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregrateIndex);
109       return myagg;
110     }
111
112     @Override
113     public void aggregateInputSelection(
114       VectorAggregationBufferRow[] aggregationBufferSets,
115       int aggregrateIndex,
116       VectorizedRowBatch batch) throws HiveException {
117
118       int batchSize = batch.size;
119
120       if (batchSize == 0) {
121         return;
122       }
123
124       inputExpression.evaluate(batch);
125
126       DecimalColumnVector inputVector = (DecimalColumnVector)batch.
127         cols[this.inputExpression.getOutputColumn()];
128       HiveDecimalWritable[] vector = inputVector.vector;
129
130       if (inputVector.noNulls) {
131         if (inputVector.isRepeating) {
132           iterateNoNullsRepeatingWithAggregationSelection(
133             aggregationBufferSets, aggregrateIndex,
134             vector[0], inputVector.scale, batchSize);
135         } else {
136           if (batch.selectedInUse) {
137             iterateNoNullsSelectionWithAggregationSelection(
138               aggregationBufferSets, aggregrateIndex,
139               vector, inputVector.scale, batch.selected, batchSize);
140           } else {
141             iterateNoNullsWithAggregationSelection(
142               aggregationBufferSets, aggregrateIndex,
143               vector, inputVector.scale, batchSize);
144           }
145         }
146       } else {
147         if (inputVector.isRepeating) {
148           if (batch.selectedInUse) {
149             iterateHasNullsRepeatingSelectionWithAggregationSelection(
150               aggregationBufferSets, aggregrateIndex,
151               vector[0], inputVector.scale, batchSize, batch.selected, inputVector.isNull);
152           } else {
153             iterateHasNullsRepeatingWithAggregationSelection(
154               aggregationBufferSets, aggregrateIndex,
155               vector[0], inputVector.scale, batchSize, inputVector.isNull);
156           }
157         } else {
158           if (batch.selectedInUse) {
159             iterateHasNullsSelectionWithAggregationSelection(
160               aggregationBufferSets, aggregrateIndex,
161               vector, inputVector.scale, batchSize, batch.selected, inputVector.isNull);
162           } else {
163             iterateHasNullsWithAggregationSelection(
164               aggregationBufferSets, aggregrateIndex,
165               vector, inputVector.scale, batchSize, inputVector.isNull);
166           }
167         }
168       }
169     }
170
171     private void iterateNoNullsRepeatingWithAggregationSelection(
172       VectorAggregationBufferRow[] aggregationBufferSets,
173       int aggregrateIndex,
174       HiveDecimalWritable value,
175       short scale,
176       int batchSize) {
177
178       for (int i=0; i < batchSize; ++i) {
179         Aggregation myagg = getCurrentAggregationBuffer(
180           aggregationBufferSets,
181           aggregrateIndex,
182           i);
183         myagg.checkValue(value, scale);
184       }
185     }
186
187     private void iterateNoNullsSelectionWithAggregationSelection(
188       VectorAggregationBufferRow[] aggregationBufferSets,
189       int aggregrateIndex,
190       HiveDecimalWritable[] values,
191       short scale,
192       int[] selection,
193       int batchSize) {
194
195       for (int i=0; i < batchSize; ++i) {
196         Aggregation myagg = getCurrentAggregationBuffer(
197           aggregationBufferSets,
198           aggregrateIndex,
199           i);
200         myagg.checkValue(values[selection[i]], scale);
201       }
202     }
203
204     private void iterateNoNullsWithAggregationSelection(
205       VectorAggregationBufferRow[] aggregationBufferSets,
206       int aggregrateIndex,
207       HiveDecimalWritable[] values,
208       short scale,
209       int batchSize) {
210       for (int i=0; i < batchSize; ++i) {
211         Aggregation myagg = getCurrentAggregationBuffer(
212           aggregationBufferSets,
213           aggregrateIndex,
214           i);
215         myagg.checkValue(values[i], scale);
216       }
217     }
218
219     private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
220       VectorAggregationBufferRow[] aggregationBufferSets,
221       int aggregrateIndex,
222       HiveDecimalWritable value,
223       short scale,
224       int batchSize,
225       int[] selection,
226       boolean[] isNull) {
227
228       for (int i=0; i < batchSize; ++i) {
229         if (!isNull[selection[i]]) {
230           Aggregation myagg = getCurrentAggregationBuffer(
231             aggregationBufferSets,
232             aggregrateIndex,
233             i);
234           myagg.checkValue(value, scale);
235         }
236       }
237
238     }
239
240     private void iterateHasNullsRepeatingWithAggregationSelection(
241       VectorAggregationBufferRow[] aggregationBufferSets,
242       int aggregrateIndex,
243       HiveDecimalWritable value,
244       short scale,
245       int batchSize,
246       boolean[] isNull) {
247
248       if (isNull[0]) {
249         return;
250       }
251
252       for (int i=0; i < batchSize; ++i) {
253         Aggregation myagg = getCurrentAggregationBuffer(
254           aggregationBufferSets,
255           aggregrateIndex,
256           i);
257         myagg.checkValue(value, scale);
258       }
259     }
260
261     private void iterateHasNullsSelectionWithAggregationSelection(
262       VectorAggregationBufferRow[] aggregationBufferSets,
263       int aggregrateIndex,
264       HiveDecimalWritable[] values,
265       short scale,
266       int batchSize,
267       int[] selection,
268       boolean[] isNull) {
269
270       for (int j=0; j < batchSize; ++j) {
271         int i = selection[j];
272         if (!isNull[i]) {
273           Aggregation myagg = getCurrentAggregationBuffer(
274             aggregationBufferSets,
275             aggregrateIndex,
276             j);
277           myagg.checkValue(values[i], scale);
278         }
279       }
280    }
281
282     private void iterateHasNullsWithAggregationSelection(
283       VectorAggregationBufferRow[] aggregationBufferSets,
284       int aggregrateIndex,
285       HiveDecimalWritable[] values,
286       short scale,
287       int batchSize,
288       boolean[] isNull) {
289
290       for (int i=0; i < batchSize; ++i) {
291         if (!isNull[i]) {
292           Aggregation myagg = getCurrentAggregationBuffer(
293             aggregationBufferSets,
294             aggregrateIndex,
295             i);
296           myagg.checkValue(values[i], scale);
297         }
298       }
299    }
300
301     @Override
302     public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
303       throws HiveException {
304
305         inputExpression.evaluate(batch);
306
307         DecimalColumnVector inputVector = (DecimalColumnVector)batch.
308             cols[this.inputExpression.getOutputColumn()];
309
310         int batchSize = batch.size;
311
312         if (batchSize == 0) {
313           return;
314         }
315
316         Aggregation myagg = (Aggregation)agg;
317
318         HiveDecimalWritable[] vector = inputVector.vector;
319
320         if (inputVector.isRepeating) {
321           if (inputVector.noNulls &&
322             (myagg.isNull || (myagg.value.compareTo(vector[0]) <OperatorSymbol> 0))) {
323             myagg.isNull = false;
324             HiveDecimal value = vector[0].getHiveDecimal();
325             myagg.value.set(value);
326           }
327           return;
328         }
329
330         if (!batch.selectedInUse && inputVector.noNulls) {
331           iterateNoSelectionNoNulls(myagg, vector, inputVector.scale, batchSize);
332         }
333         else if (!batch.selectedInUse) {
334           iterateNoSelectionHasNulls(myagg, vector, inputVector.scale, 
335             batchSize, inputVector.isNull);
336         }
337         else if (inputVector.noNulls){
338           iterateSelectionNoNulls(myagg, vector, inputVector.scale, batchSize, batch.selected);
339         }
340         else {
341           iterateSelectionHasNulls(myagg, vector, inputVector.scale, 
342             batchSize, inputVector.isNull, batch.selected);
343         }
344     }
345
346     private void iterateSelectionHasNulls(
347         Aggregation myagg,
348         HiveDecimalWritable[] vector,
349         short scale,
350         int batchSize,
351         boolean[] isNull,
352         int[] selected) {
353
354       for (int j=0; j< batchSize; ++j) {
355         int i = selected[j];
356         if (!isNull[i]) {
357           HiveDecimal value = vector[i].getHiveDecimal();
358           if (myagg.isNull) {
359             myagg.isNull = false;
360             myagg.value.set(value);
361           }
362           else if (myagg.value.getHiveDecimal().compareTo(value) <OperatorSymbol> 0) {
363             myagg.value.set(value);
364           }
365         }
366       }
367     }
368
369     private void iterateSelectionNoNulls(
370         Aggregation myagg,
371         HiveDecimalWritable[] vector,
372         short scale,
373         int batchSize,
374         int[] selected) {
375
376       if (myagg.isNull) {
377         HiveDecimal value = vector[selected[0]].getHiveDecimal();
378         myagg.value.set(value);
379         myagg.isNull = false;
380       }
381
382       for (int i=0; i< batchSize; ++i) {
383         HiveDecimal value = vector[selected[i]].getHiveDecimal();
384         if (myagg.value.getHiveDecimal().compareTo(value) <OperatorSymbol> 0) {
385           myagg.value.set(value);
386         }
387       }
388     }
389
390     private void iterateNoSelectionHasNulls(
391         Aggregation myagg,
392         HiveDecimalWritable[] vector,
393         short scale,
394         int batchSize,
395         boolean[] isNull) {
396
397       for(int i=0;i<batchSize;++i) {
398         if (!isNull[i]) {
399           HiveDecimal value = vector[i].getHiveDecimal();
400           if (myagg.isNull) {
401             myagg.value.set(value);
402             myagg.isNull = false;
403           }
404           else if (myagg.value.getHiveDecimal().compareTo(value) <OperatorSymbol> 0) {
405             myagg.value.set(value);
406           }
407         }
408       }
409     }
410
411     private void iterateNoSelectionNoNulls(
412         Aggregation myagg,
413         HiveDecimalWritable[] vector,
414         short scale,
415         int batchSize) {
416       if (myagg.isNull) {
417         HiveDecimal value = vector[0].getHiveDecimal();
418         myagg.value.set(value);
419         myagg.isNull = false;
420       }
421
422       for (int i=0;i<batchSize;++i) {
423         HiveDecimal value = vector[i].getHiveDecimal();
424         if (myagg.value.getHiveDecimal().compareTo(value) <OperatorSymbol> 0) {
425           myagg.value.set(value);
426         }
427       }
428     }
429
430     @Override
431     public AggregationBuffer getNewAggregationBuffer() throws HiveException {
432       return new Aggregation();
433     }
434
435     @Override
436     public void reset(AggregationBuffer agg) throws HiveException {
437       Aggregation myAgg = (Aggregation) agg;
438       myAgg.reset();
439     }
440
441     @Override
442     public Object evaluateOutput(
443         AggregationBuffer agg) throws HiveException {
444     Aggregation myagg = (Aggregation) agg;
445       if (myagg.isNull) {
446         return null;
447       }
448       else {
449         return resultWriter.writeValue(myagg.value);
450       }
451     }
452
453     @Override
454     public ObjectInspector getOutputObjectInspector() {
455       return resultWriter.getObjectInspector();
456     }
457
458     @Override
459     public int getAggregationBufferFixedSize() {
460     JavaDataModel model = JavaDataModel.get();
461     return JavaDataModel.alignUp(
462       model.object() +
463       model.primitive2(),
464       model.memoryAlign());
465   }
466
467   public VectorExpression getInputExpression() {
468     return inputExpression;
469   }
470
471   public void setInputExpression(VectorExpression inputExpression) {
472     this.inputExpression = inputExpression;
473   }
474 }
475