[CALCITE-563] In JDBC adapter, push bindable parameters down to the underlying JDBC...
[calcite.git] / core / src / main / java / org / apache / calcite / runtime / ResultSetEnumerable.java
index 771772f..52c11f9 100644 (file)
@@ -16,6 +16,8 @@
  */
 package org.apache.calcite.runtime;
 
+import org.apache.calcite.DataContext;
+import org.apache.calcite.avatica.SqlType;
 import org.apache.calcite.linq4j.AbstractEnumerable;
 import org.apache.calcite.linq4j.Enumerable;
 import org.apache.calcite.linq4j.Enumerator;
@@ -23,16 +25,29 @@ import org.apache.calcite.linq4j.Linq4j;
 import org.apache.calcite.linq4j.function.Function0;
 import org.apache.calcite.linq4j.function.Function1;
 import org.apache.calcite.linq4j.tree.Primitive;
+import org.apache.calcite.util.Static;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.math.BigDecimal;
+import java.net.URL;
+import java.sql.Blob;
+import java.sql.Clob;
 import java.sql.Connection;
+import java.sql.Date;
+import java.sql.NClob;
+import java.sql.PreparedStatement;
+import java.sql.Ref;
 import java.sql.ResultSet;
 import java.sql.ResultSetMetaData;
+import java.sql.RowId;
 import java.sql.SQLException;
 import java.sql.SQLFeatureNotSupportedException;
+import java.sql.SQLXML;
 import java.sql.Statement;
+import java.sql.Time;
+import java.sql.Timestamp;
 import java.sql.Types;
 import java.util.ArrayList;
 import java.util.List;
@@ -47,6 +62,8 @@ public class ResultSetEnumerable<T> extends AbstractEnumerable<T> {
   private final DataSource dataSource;
   private final String sql;
   private final Function1<ResultSet, Function0<T>> rowBuilderFactory;
+  private final PreparedStatementEnricher preparedStatementEnricher;
+
   private static final Logger LOGGER = LoggerFactory.getLogger(
       ResultSetEnumerable.class);
 
@@ -96,10 +113,19 @@ public class ResultSetEnumerable<T> extends AbstractEnumerable<T> {
   private ResultSetEnumerable(
       DataSource dataSource,
       String sql,
-      Function1<ResultSet, Function0<T>> rowBuilderFactory) {
+      Function1<ResultSet, Function0<T>> rowBuilderFactory,
+      PreparedStatementEnricher preparedStatementEnricher) {
     this.dataSource = dataSource;
     this.sql = sql;
     this.rowBuilderFactory = rowBuilderFactory;
+    this.preparedStatementEnricher = preparedStatementEnricher;
+  }
+
+  private ResultSetEnumerable(
+      DataSource dataSource,
+      String sql,
+      Function1<ResultSet, Function0<T>> rowBuilderFactory) {
+    this(dataSource, sql, rowBuilderFactory, null);
   }
 
   /** Creates an ResultSetEnumerable. */
@@ -123,17 +149,100 @@ public class ResultSetEnumerable<T> extends AbstractEnumerable<T> {
     return new ResultSetEnumerable<>(dataSource, sql, rowBuilderFactory);
   }
 
+  /** Executes a SQL query and returns the results as an enumerator, using a
+   * row builder to convert JDBC column values into rows.
+   *
+   * <p>It uses a {@link PreparedStatement} for computing the query result,
+   * and that means that it can bind parameters. */
+  public static <T> Enumerable<T> of(
+      DataSource dataSource,
+      String sql,
+      Function1<ResultSet, Function0<T>> rowBuilderFactory,
+      PreparedStatementEnricher consumer) {
+    return new ResultSetEnumerable<>(dataSource, sql, rowBuilderFactory, consumer);
+  }
+
+  /** Called from generated code that proposes to create a
+   * {@code ResultSetEnumerable} over a prepared statement. */
+  public static PreparedStatementEnricher createEnricher(Integer[] indexes,
+      DataContext context) {
+    return preparedStatement -> {
+      for (int i = 0; i < indexes.length; i++) {
+        final int index = indexes[i];
+        setDynamicParam(preparedStatement, i + 1,
+            context.get("?" + index));
+      }
+    };
+  }
+
+  /** Assigns a value to a dynamic parameter in a prepared statement, calling
+   * the appropriate {@code setXxx} method based on the type of the value. */
+  private static void setDynamicParam(PreparedStatement preparedStatement,
+      int i, Object value) throws SQLException {
+    if (value == null) {
+      preparedStatement.setObject(i, null, SqlType.ANY.id);
+    } else if (value instanceof Timestamp) {
+      preparedStatement.setTimestamp(i, (Timestamp) value);
+    } else if (value instanceof Time) {
+      preparedStatement.setTime(i, (Time) value);
+    } else if (value instanceof String) {
+      preparedStatement.setString(i, (String) value);
+    } else if (value instanceof Integer) {
+      preparedStatement.setInt(i, (Integer) value);
+    } else if (value instanceof Double) {
+      preparedStatement.setDouble(i, (Double) value);
+    } else if (value instanceof java.sql.Array) {
+      preparedStatement.setArray(i, (java.sql.Array) value);
+    } else if (value instanceof BigDecimal) {
+      preparedStatement.setBigDecimal(i, (BigDecimal) value);
+    } else if (value instanceof Boolean) {
+      preparedStatement.setBoolean(i, (Boolean) value);
+    } else if (value instanceof Blob) {
+      preparedStatement.setBlob(i, (Blob) value);
+    } else if (value instanceof Byte) {
+      preparedStatement.setByte(i, (Byte) value);
+    } else if (value instanceof NClob) {
+      preparedStatement.setNClob(i, (NClob) value);
+    } else if (value instanceof Clob) {
+      preparedStatement.setClob(i, (Clob) value);
+    } else if (value instanceof byte[]) {
+      preparedStatement.setBytes(i, (byte[]) value);
+    } else if (value instanceof Date) {
+      preparedStatement.setDate(i, (Date) value);
+    } else if (value instanceof Float) {
+      preparedStatement.setFloat(i, (Float) value);
+    } else if (value instanceof Long) {
+      preparedStatement.setLong(i, (Long) value);
+    } else if (value instanceof Ref) {
+      preparedStatement.setRef(i, (Ref) value);
+    } else if (value instanceof RowId) {
+      preparedStatement.setRowId(i, (RowId) value);
+    } else if (value instanceof Short) {
+      preparedStatement.setShort(i, (Short) value);
+    } else if (value instanceof URL) {
+      preparedStatement.setURL(i, (URL) value);
+    } else if (value instanceof SQLXML) {
+      preparedStatement.setSQLXML(i, (SQLXML) value);
+    } else {
+      preparedStatement.setObject(i, value);
+    }
+  }
+
   public Enumerator<T> enumerator() {
+    if (preparedStatementEnricher == null) {
+      return enumeratorBasedOnStatement();
+    } else {
+      return enumeratorBasedOnPreparedStatement();
+    }
+  }
+
+  private Enumerator<T> enumeratorBasedOnStatement() {
     Connection connection = null;
     Statement statement = null;
     try {
       connection = dataSource.getConnection();
       statement = connection.createStatement();
-      try {
-        statement.setQueryTimeout(10);
-      } catch (SQLFeatureNotSupportedException e) {
-        LOGGER.debug("Failed to set query timeout.");
-      }
+      setTimeoutIfPossible(statement);
       if (statement.execute(sql)) {
         final ResultSet resultSet = statement.getResultSet();
         statement = null;
@@ -144,21 +253,59 @@ public class ResultSetEnumerable<T> extends AbstractEnumerable<T> {
         return Linq4j.singletonEnumerator((T) updateCount);
       }
     } catch (SQLException e) {
-      throw new RuntimeException("while executing SQL [" + sql + "]", e);
+      throw Static.RESOURCE.exceptionWhilePerformingQueryOnJdbcSubSchema(sql)
+          .ex(e);
     } finally {
-      if (statement != null) {
-        try {
-          statement.close();
-        } catch (SQLException e) {
-          // ignore
-        }
+      closeIfPossible(connection, statement);
+    }
+  }
+
+  private Enumerator<T> enumeratorBasedOnPreparedStatement() {
+    Connection connection = null;
+    PreparedStatement preparedStatement = null;
+    try {
+      connection = dataSource.getConnection();
+      preparedStatement = connection.prepareStatement(sql);
+      setTimeoutIfPossible(preparedStatement);
+      preparedStatementEnricher.enrich(preparedStatement);
+      if (preparedStatement.execute()) {
+        final ResultSet resultSet = preparedStatement.getResultSet();
+        preparedStatement = null;
+        connection = null;
+        return new ResultSetEnumerator<>(resultSet, rowBuilderFactory);
+      } else {
+        Integer updateCount = preparedStatement.getUpdateCount();
+        return Linq4j.singletonEnumerator((T) updateCount);
       }
-      if (connection != null) {
-        try {
-          connection.close();
-        } catch (SQLException e) {
-          // ignore
-        }
+    } catch (SQLException e) {
+      throw Static.RESOURCE.exceptionWhilePerformingQueryOnJdbcSubSchema(sql)
+          .ex(e);
+    } finally {
+      closeIfPossible(connection, preparedStatement);
+    }
+  }
+
+  private void setTimeoutIfPossible(Statement statement) throws SQLException {
+    try {
+      statement.setQueryTimeout(10);
+    } catch (SQLFeatureNotSupportedException e) {
+      LOGGER.debug("Failed to set query timeout.");
+    }
+  }
+
+  private void closeIfPossible(Connection connection, Statement statement) {
+    if (statement != null) {
+      try {
+        statement.close();
+      } catch (SQLException e) {
+        // ignore
+      }
+    }
+    if (connection != null) {
+      try {
+        connection.close();
+      } catch (SQLException e) {
+        // ignore
       }
     }
   }
@@ -254,6 +401,14 @@ public class ResultSetEnumerable<T> extends AbstractEnumerable<T> {
       };
     };
   }
+
+  /**
+   * Consumer for decorating a {@link PreparedStatement}, that is, setting
+   * its parameters.
+   */
+  public interface PreparedStatementEnricher {
+    void enrich(PreparedStatement statement) throws SQLException;
+  }
 }
 
 // End ResultSetEnumerable.java