[SPARK-26456][SQL] Cast date/timestamp to string by Date/TimestampFormatter
authorMaxim Gekk <maxim.gekk@databricks.com>
Mon, 14 Jan 2019 13:59:25 +0000 (21:59 +0800)
committerWenchen Fan <wenchen@databricks.com>
Mon, 14 Jan 2019 13:59:25 +0000 (21:59 +0800)
## What changes were proposed in this pull request?

In the PR, I propose to switch on `TimestampFormatter`/`DateFormatter` in casting dates/timestamps to strings. The changes should make the date/timestamp casting consistent to JSON/CSV datasources and time-related functions like `to_date`, `to_unix_timestamp`/`from_unixtime`.

Local formatters are moved out from `DateTimeUtils` to where they are actually used. It allows to avoid re-creation of new formatter instance per-each call. Another reason is to have separate parser for `PartitioningUtils` because default parsing pattern cannot be used (expected optional section `[.S]`).

## How was this patch tested?

It was tested by `DateTimeUtilsSuite`, `CastSuite` and `JDBC*Suite`.

Closes #23391 from MaxGekk/thread-local-date-format.

Lead-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Co-authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
13 files changed:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala

index ee463bf..ff6a68b 100644 (file)
@@ -230,12 +230,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // [[func]] assumes the input is no longer null because eval already does the null check.
   @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
+  private lazy val dateFormatter = DateFormatter()
+  private lazy val timestampFormatter = TimestampFormatter(timeZone)
+
   // UDFToString
   private[this] def castToString(from: DataType): Any => Any = from match {
     case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
-    case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
+    case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d)))
     case TimestampType => buildCast[Long](_,
-      t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
+      t => UTF8String.fromString(DateTimeUtils.timestampToString(timestampFormatter, t)))
     case ArrayType(et, _) =>
       buildCast[ArrayData](_, array => {
         val builder = new UTF8StringBuilder
@@ -843,12 +846,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       case BinaryType =>
         (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);"
       case DateType =>
-        (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(
-          org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
+        val df = JavaCode.global(
+          ctx.addReferenceObj("dateFormatter", dateFormatter),
+          dateFormatter.getClass)
+        (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(${df}.format($c));"""
       case TimestampType =>
-        val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass)
+        val tf = JavaCode.global(
+          ctx.addReferenceObj("timestampFormatter", timestampFormatter),
+          timestampFormatter.getClass)
         (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(
-          org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
+          org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($tf, $c));"""
       case ArrayType(et, _) =>
         (c, evPrim, evNull) => {
           val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
index b1832da..e758362 100644 (file)
@@ -562,7 +562,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
     copy(timeZoneId = Option(timeZoneId))
 
   override protected def nullSafeEval(timestamp: Any, format: Any): Any = {
-    val df = TimestampFormatter(format.toString, timeZone, Locale.US)
+    val df = TimestampFormatter(format.toString, timeZone)
     UTF8String.fromString(df.format(timestamp.asInstanceOf[Long]))
   }
 
@@ -665,7 +665,7 @@ abstract class UnixTime
   private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
   private lazy val formatter: TimestampFormatter =
     try {
-      TimestampFormatter(constFormat.toString, timeZone, Locale.US)
+      TimestampFormatter(constFormat.toString, timeZone)
     } catch {
       case NonFatal(_) => null
     }
@@ -698,7 +698,7 @@ abstract class UnixTime
           } else {
             val formatString = f.asInstanceOf[UTF8String].toString
             try {
-              TimestampFormatter(formatString, timeZone, Locale.US).parse(
+              TimestampFormatter(formatString, timeZone).parse(
                 t.asInstanceOf[UTF8String].toString) / MICROS_PER_SECOND
             } catch {
               case NonFatal(_) => null
@@ -819,7 +819,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
   private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
   private lazy val formatter: TimestampFormatter =
     try {
-      TimestampFormatter(constFormat.toString, timeZone, Locale.US)
+      TimestampFormatter(constFormat.toString, timeZone)
     } catch {
       case NonFatal(_) => null
     }
@@ -845,7 +845,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
           null
         } else {
           try {
-            UTF8String.fromString(TimestampFormatter(f.toString, timeZone, Locale.US)
+            UTF8String.fromString(TimestampFormatter(f.toString, timeZone)
               .format(time.asInstanceOf[Long] * MICROS_PER_SECOND))
           } catch {
             case NonFatal(_) => null
index c47b087..adc69ab 100644 (file)
@@ -51,7 +51,14 @@ class Iso8601DateFormatter(
 }
 
 object DateFormatter {
+  val defaultPattern: String = "yyyy-MM-dd"
+  val defaultLocale: Locale = Locale.US
+
   def apply(format: String, locale: Locale): DateFormatter = {
     new Iso8601DateFormatter(format, locale)
   }
+
+  def apply(format: String): DateFormatter = apply(format, defaultLocale)
+
+  def apply(): DateFormatter = apply(defaultPattern)
 }
index e95117f..da8899a 100644 (file)
@@ -92,32 +92,6 @@ object DateTimeUtils {
     }
   }
 
-  // `SimpleDateFormat` is not thread-safe.
-  private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
-    override def initialValue(): SimpleDateFormat = {
-      new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-    }
-  }
-
-  def getThreadLocalTimestampFormat(timeZone: TimeZone): DateFormat = {
-    val sdf = threadLocalTimestampFormat.get()
-    sdf.setTimeZone(timeZone)
-    sdf
-  }
-
-  // `SimpleDateFormat` is not thread-safe.
-  private val threadLocalDateFormat = new ThreadLocal[DateFormat] {
-    override def initialValue(): SimpleDateFormat = {
-      new SimpleDateFormat("yyyy-MM-dd", Locale.US)
-    }
-  }
-
-  def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = {
-    val sdf = threadLocalDateFormat.get()
-    sdf.setTimeZone(timeZone)
-    sdf
-  }
-
   private val computedTimeZones = new ConcurrentHashMap[String, TimeZone]
   private val computeTimeZone = new JFunction[String, TimeZone] {
     override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId)
@@ -149,24 +123,11 @@ object DateTimeUtils {
     millisLocal - getOffsetFromLocalMillis(millisLocal, timeZone)
   }
 
-  def dateToString(days: SQLDate): String =
-    getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days))
-
-  def dateToString(days: SQLDate, timeZone: TimeZone): String = {
-    getThreadLocalDateFormat(timeZone).format(toJavaDate(days))
-  }
-
-  // Converts Timestamp to string according to Hive TimestampWritable convention.
-  def timestampToString(us: SQLTimestamp): String = {
-    timestampToString(us, defaultTimeZone())
-  }
-
   // Converts Timestamp to string according to Hive TimestampWritable convention.
-  def timestampToString(us: SQLTimestamp, timeZone: TimeZone): String = {
+  def timestampToString(tf: TimestampFormatter, us: SQLTimestamp): String = {
     val ts = toJavaTimestamp(us)
     val timestampString = ts.toString
-    val timestampFormat = getThreadLocalTimestampFormat(timeZone)
-    val formatted = timestampFormat.format(ts)
+    val formatted = tf.format(us)
 
     if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
       formatted + timestampString.substring(19)
@@ -1168,7 +1129,5 @@ object DateTimeUtils {
    */
   private[util] def resetThreadLocals(): Unit = {
     threadLocalGmtCalendar.remove()
-    threadLocalTimestampFormat.remove()
-    threadLocalDateFormat.remove()
   }
 }
index 10c73b2..1374a82 100644 (file)
@@ -36,7 +36,7 @@ sealed trait TimestampFormatter extends Serializable {
   @throws(classOf[ParseException])
   @throws(classOf[DateTimeParseException])
   @throws(classOf[DateTimeException])
-  def parse(s: String): Long // returns microseconds since epoch
+  def parse(s: String): Long
   def format(us: Long): String
 }
 
@@ -74,7 +74,18 @@ class Iso8601TimestampFormatter(
 }
 
 object TimestampFormatter {
+  val defaultPattern: String = "yyyy-MM-dd HH:mm:ss"
+  val defaultLocale: Locale = Locale.US
+
   def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = {
     new Iso8601TimestampFormatter(format, timeZone, locale)
   }
+
+  def apply(format: String, timeZone: TimeZone): TimestampFormatter = {
+    apply(format, timeZone, defaultLocale)
+  }
+
+  def apply(timeZone: TimeZone): TimestampFormatter = {
+    apply(defaultPattern, timeZone, defaultLocale)
+  }
 }
index 2cb6110..e732eb0 100644 (file)
@@ -35,10 +35,11 @@ class DateTimeUtilsSuite extends SparkFunSuite {
   }
 
   test("nanoseconds truncation") {
+    val tf = TimestampFormatter(DateTimeUtils.defaultTimeZone())
     def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) {
       val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime))
       assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly")
-      assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime)
+      assert(DateTimeUtils.timestampToString(tf, parsedTimestampOp.get) === expectedParsedTime)
     }
 
     checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456")
index 2dc55e0..4d5872c 100644 (file)
@@ -29,7 +29,7 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper {
   test("parsing dates") {
     DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
       withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
-        val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+        val formatter = DateFormatter()
         val daysSinceEpoch = formatter.parse("2018-12-02")
         assert(daysSinceEpoch === 17867)
       }
@@ -39,7 +39,7 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper {
   test("format dates") {
     DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
       withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
-        val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+        val formatter = DateFormatter()
         val date = formatter.format(17867)
         assert(date === "2018-12-02")
       }
@@ -59,7 +59,7 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper {
       "5010-11-17").foreach { date =>
       DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
         withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
-          val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+          val formatter = DateFormatter()
           val days = formatter.parse(date)
           val formatted = formatter.format(days)
           assert(date === formatted)
@@ -82,7 +82,7 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper {
       1110657).foreach { days =>
       DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
         withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
-          val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+          val formatter = DateFormatter()
           val date = formatter.format(days)
           val parsed = formatter.parse(date)
           assert(days === parsed)
@@ -92,7 +92,7 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("parsing date without explicit day") {
-    val formatter = DateFormatter("yyyy MMM", Locale.US)
+    val formatter = DateFormatter("yyyy MMM")
     val daysSinceEpoch = formatter.parse("2018 Dec")
     assert(daysSinceEpoch === LocalDate.of(2018, 12, 1).toEpochDay)
   }
index 2ce3eac..d007adf 100644 (file)
@@ -41,8 +41,7 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
     DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
       val formatter = TimestampFormatter(
         "yyyy-MM-dd'T'HH:mm:ss.SSSSSS",
-        TimeZone.getTimeZone(timeZone),
-        Locale.US)
+        TimeZone.getTimeZone(timeZone))
       val microsSinceEpoch = formatter.parse(localDate)
       assert(microsSinceEpoch === expectedMicros(timeZone))
     }
@@ -62,8 +61,7 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
     DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
       val formatter = TimestampFormatter(
         "yyyy-MM-dd'T'HH:mm:ss.SSSSSS",
-        TimeZone.getTimeZone(timeZone),
-        Locale.US)
+        TimeZone.getTimeZone(timeZone))
       val timestamp = formatter.format(microsSinceEpoch)
       assert(timestamp === expectedTimestamp(timeZone))
     }
@@ -82,7 +80,7 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
         2177456523456789L,
         11858049903010203L).foreach { micros =>
         DateTimeTestUtils.outstandingTimezones.foreach { timeZone =>
-          val formatter = TimestampFormatter(pattern, timeZone, Locale.US)
+          val formatter = TimestampFormatter(pattern, timeZone)
           val timestamp = formatter.format(micros)
           val parsed = formatter.parse(timestamp)
           assert(micros === parsed)
@@ -103,7 +101,7 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
       "2039-01-01T01:02:03.456789",
       "2345-10-07T22:45:03.010203").foreach { timestamp =>
       DateTimeTestUtils.outstandingTimezones.foreach { timeZone =>
-        val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US)
+        val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone)
         val micros = formatter.parse(timestamp)
         val formatted = formatter.format(micros)
         assert(timestamp === formatted)
@@ -114,8 +112,7 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
   test(" case insensitive parsing of am and pm") {
     val formatter = TimestampFormatter(
       "yyyy MMM dd hh:mm:ss a",
-      TimeZone.getTimeZone("UTC"),
-      Locale.US)
+      TimeZone.getTimeZone("UTC"))
     val micros = formatter.parse("2009 Mar 20 11:30:01 am")
     assert(micros === TimeUnit.SECONDS.toMicros(
       LocalDateTime.of(2009, 3, 20, 11, 30, 1).toEpochSecond(ZoneOffset.UTC)))
index c90b254..d3934a0 100644 (file)
@@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -77,6 +77,10 @@ object HiveResult {
     TimestampType,
     BinaryType)
 
+  private lazy val dateFormatter = DateFormatter()
+  private lazy val timestampFormatter = TimestampFormatter(
+    DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone))
+
   /** Hive outputs fields of structs slightly differently than top level attributes. */
   private def toHiveStructString(a: (Any, DataType)): String = a match {
     case (struct: Row, StructType(fields)) =>
@@ -111,11 +115,9 @@ object HiveResult {
           toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
       }.toSeq.sorted.mkString("{", ",", "}")
     case (null, _) => "NULL"
-    case (d: Date, DateType) =>
-      DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+    case (d: Date, DateType) => dateFormatter.format(DateTimeUtils.fromJavaDate(d))
     case (t: Timestamp, TimestampType) =>
-      val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
-      DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone)
+      DateTimeUtils.timestampToString(timestampFormatter, DateTimeUtils.fromJavaTimestamp(t))
     case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
     case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
     case (interval, CalendarIntervalType) => interval.toString
index 6458b65..ee77042 100644 (file)
@@ -26,13 +26,12 @@ import scala.util.Try
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 
@@ -59,6 +58,8 @@ object PartitionSpec {
 
 object PartitioningUtils {
 
+  val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
+
   private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
   {
     require(columnNames.size == literals.size)
@@ -122,10 +123,12 @@ object PartitioningUtils {
       Map.empty[String, DataType]
     }
 
+    val dateFormatter = DateFormatter()
+    val timestampFormatter = TimestampFormatter(timestampPartitionPattern, timeZone)
     // First, we need to parse every partition's path and see if we can find partition values.
     val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
       parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
-        validatePartitionColumns, timeZone)
+        validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
     }.unzip
 
     // We create pairs of (path -> path's partition value) here
@@ -208,7 +211,9 @@ object PartitioningUtils {
       basePaths: Set[Path],
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
     val columns = ArrayBuffer.empty[(String, Literal)]
     // Old Hadoop versions don't have `Path.isRoot`
     var finished = path.getParent == null
@@ -230,7 +235,7 @@ object PartitioningUtils {
         // Once we get the string, we try to parse it and find the partition column and value.
         val maybeColumn =
           parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
-            validatePartitionColumns, timeZone)
+            validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
         maybeColumn.foreach(columns += _)
 
         // Now, we determine if we should stop.
@@ -265,7 +270,9 @@ object PartitioningUtils {
       typeInference: Boolean,
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): Option[(String, Literal)] = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Option[(String, Literal)] = {
     val equalSignIndex = columnSpec.indexOf('=')
     if (equalSignIndex == -1) {
       None
@@ -280,7 +287,12 @@ object PartitioningUtils {
         // SPARK-26188: if user provides corresponding column schema, get the column value without
         //              inference, and then cast it as user specified data type.
         val dataType = userSpecifiedDataTypes(columnName)
-        val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone)
+        val columnValueLiteral = inferPartitionColumnValue(
+          rawColumnValue,
+          false,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
         val columnValue = columnValueLiteral.eval()
         val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
         if (validatePartitionColumns && columnValue != null && castedValue == null) {
@@ -289,7 +301,12 @@ object PartitioningUtils {
         }
         Literal.create(castedValue, dataType)
       } else {
-        inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
+        inferPartitionColumnValue(
+          rawColumnValue,
+          typeInference,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
       }
       Some(columnName -> literal)
     }
@@ -442,7 +459,9 @@ object PartitioningUtils {
   private[datasources] def inferPartitionColumnValue(
       raw: String,
       typeInference: Boolean,
-      timeZone: TimeZone): Literal = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Literal = {
     val decimalTry = Try {
       // `BigDecimal` conversion can fail when the `field` is not a form of number.
       val bigDecimal = new JBigDecimal(raw)
@@ -457,7 +476,7 @@ object PartitioningUtils {
     val dateTry = Try {
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // DateType
-      DateTimeUtils.getThreadLocalDateFormat(DateTimeUtils.defaultTimeZone()).parse(raw)
+      dateFormatter.parse(raw)
       // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
       // This can happen since DateFormat.parse  may not use the entire text of the given string:
       // so if there are extra-characters after the date, it returns correctly.
@@ -474,7 +493,7 @@ object PartitioningUtils {
       val unescapedRaw = unescapePathName(raw)
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // TimestampType
-      DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw)
+      timestampFormatter.parse(unescapedRaw)
       // SPARK-23436: see comment for date
       val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval()
       // Disallow TimestampType if the cast returned null
index 51c385e..13ed105 100644 (file)
@@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
@@ -185,10 +185,11 @@ private[sql] object JDBCRelation extends Logging {
       columnType: DataType,
       timeZoneId: String): String = {
     def dateTimeToString(): String = {
-      val timeZone = DateTimeUtils.getTimeZone(timeZoneId)
       val dateTimeStr = columnType match {
-        case DateType => DateTimeUtils.dateToString(value.toInt, timeZone)
-        case TimestampType => DateTimeUtils.timestampToString(value, timeZone)
+        case DateType => DateFormatter().format(value.toInt)
+        case TimestampType =>
+          val timestampFormatter = TimestampFormatter(DateTimeUtils.getTimeZone(timeZoneId))
+          DateTimeUtils.timestampToString(timestampFormatter, value)
       }
       s"'$dateTimeStr'"
     }
index 4205b3f..bbce470 100644 (file)
 
 package org.apache.spark.sql.execution
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
+
+class HiveResultSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
-class HiveResultSuite extends SparkFunSuite {
+  test("date formatting in hive result") {
+    val date = "2018-12-28"
+    val executedPlan = Seq(Date.valueOf(date)).toDS().queryExecution.executedPlan
+    val result = HiveResult.hiveResultString(executedPlan)
+    assert(result.head == date)
+  }
+
+  test("timestamp formatting in hive result") {
+    val timestamp = "2018-12-28 01:02:03"
+    val executedPlan = Seq(Timestamp.valueOf(timestamp)).toDS().queryExecution.executedPlan
+    val result = HiveResult.hiveResultString(executedPlan)
+    assert(result.head == timestamp)
+  }
 
   test("toHiveString correctly handles UDTs") {
     val point = new ExamplePoint(50.0, 50.0)
index 8806735..864c1e9 100644 (file)
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -56,6 +56,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
   val timeZone = TimeZone.getDefault()
   val timeZoneId = timeZone.getID
+  val df = DateFormatter()
+  val tf = TimestampFormatter(timestampPartitionPattern, timeZone)
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -69,7 +71,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
   test("column type inference") {
     def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = {
-      assert(inferPartitionColumnValue(raw, true, timeZone) === literal)
+      assert(inferPartitionColumnValue(raw, true, timeZone, df, tf) === literal)
     }
 
     check("10", Literal.create(10, IntegerType))
@@ -197,13 +199,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
   test("parse partition") {
     def check(path: String, expected: Option[PartitionValues]): Unit = {
       val actual = parsePartition(new Path(path), true, Set.empty[Path],
-        Map.empty, true, timeZone)._1
+        Map.empty, true, timeZone, df, tf)._1
       assert(expected === actual)
     }
 
     def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
       val message = intercept[T] {
-        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone)
+        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone, df, tf)
       }.getMessage
 
       assert(message.contains(expected))
@@ -249,7 +251,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path/a=10")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1
 
     assert(partitionSpec1.isEmpty)
 
@@ -260,7 +264,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1
 
     assert(partitionSpec2 ==
       Option(PartitionValues(