[BAHIR-233] Add SNS message support for SQS streaming source (#97) master
authorDmitryGrb <gorbatcevich.d@gmail.com>
Sat, 12 Dec 2020 19:16:46 +0000 (22:16 +0300)
committerGitHub <noreply@github.com>
Sat, 12 Dec 2020 19:16:46 +0000 (11:16 -0800)
Added messageWrapper option for SQS streaming connector
which says if this is pure s3 notification event or it is coming
from SNS topic

sql-streaming-sqs/README.md
sql-streaming-sqs/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsClient.scala
sql-streaming-sqs/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptions.scala
sql-streaming-sqs/src/test/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptionsSuite.scala

index b0555d6c283c0a85fdf84e1ce8f9108b09a4f124..8dd8a5409a722e0773b554182b4b5795e5906c38 100644 (file)
@@ -63,6 +63,18 @@ shouldSortFiles|true|whether to sort files based on timestamp while listing them
 useInstanceProfileCredentials|false|Whether to use EC2 instance profile credentials for connecting to Amazon SQS
 maxFilesPerTrigger|no default value|maximum number of files to process in a microbatch
 maxFileAge|7d|Maximum age of a file that can be found in this directory
 useInstanceProfileCredentials|false|Whether to use EC2 instance profile credentials for connecting to Amazon SQS
 maxFilesPerTrigger|no default value|maximum number of files to process in a microbatch
 maxFileAge|7d|Maximum age of a file that can be found in this directory
+messageWrapper|None| - 'None' if SQS contains plain S3 message. <br/> - 'SNS' if SQS contains S3 notification message which came from SNS. <br/> Please see 'Use multiple consumers' section for more details 
+
+## Use multiple consumers
+
+SQS cannot be read by multiple consumers. <br/> 
+If S3 path should be listen by multiple applications the following approach is recommended: S3 -> SNS -> SQS:
+1. Create multiple SQS queues. Each application listen for one SQS queue.
+2. Create 1 SNS topic
+3. Once S3 notification event is pushed to SNS topic it will be delivered to each SQS queue 
+
+Thus, one S3 path can be processed by multiple applications. <br/>
 
 ## Example
 
 
 ## Example
 
index 1d9bb97b3513867223ec7761226d68af69f90a8b..72b4568a5e762f7a25b8417d2de46cdc3fc25eb6 100644 (file)
@@ -28,7 +28,7 @@ import com.amazonaws.services.sqs.{AmazonSQS, AmazonSQSClientBuilder}
 import com.amazonaws.services.sqs.model.{DeleteMessageBatchRequestEntry, Message, ReceiveMessageRequest}
 import org.apache.hadoop.conf.Configuration
 import org.json4s.{DefaultFormats, MappingException}
 import com.amazonaws.services.sqs.model.{DeleteMessageBatchRequestEntry, Message, ReceiveMessageRequest}
 import org.apache.hadoop.conf.Configuration
 import org.json4s.{DefaultFormats, MappingException}
-import org.json4s.JsonAST.JValue
+import org.json4s.JsonAST.{JNothing, JValue}
 import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.SparkException
 import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.SparkException
@@ -131,13 +131,32 @@ class SqsClient(sourceOptions: SqsSourceOptions,
     }
   }
 
     }
   }
 
+  private def tryToParseSNS(parsedBody: JValue): JValue = {
+    implicit val formats = DefaultFormats
+    parsedBody \ "Message" match {
+      case JNothing => throw new MappingException("Original message does not look like SNS one. " +
+        "Please check your setup and make sure it is S3 notification event coming from SNS")
+      case value => parse(value.extract[String])
+    }
+  }
+
+  private def extractS3Message(parsedBody: JValue): JValue = {
+    sourceOptions.messageWrapper match {
+      case sourceOptions.S3MessageWrapper.None => parsedBody
+      case sourceOptions.S3MessageWrapper.SNS => tryToParseSNS(parsedBody)
+    }
+  }
+
   private def parseSqsMessages(messageList: Seq[Message]): Seq[(String, Long, String)] = {
     val errorMessages = scala.collection.mutable.ListBuffer[String]()
     val parsedMessages = messageList.foldLeft(Seq[(String, Long, String)]()) { (list, message) =>
       implicit val formats = DefaultFormats
       try {
         val messageReceiptHandle = message.getReceiptHandle
   private def parseSqsMessages(messageList: Seq[Message]): Seq[(String, Long, String)] = {
     val errorMessages = scala.collection.mutable.ListBuffer[String]()
     val parsedMessages = messageList.foldLeft(Seq[(String, Long, String)]()) { (list, message) =>
       implicit val formats = DefaultFormats
       try {
         val messageReceiptHandle = message.getReceiptHandle
-        val messageJson = parse(message.getBody).extract[JValue]
+
+        val parsedBody: JValue = parse(message.getBody).extract[JValue]
+        val messageJson = extractS3Message(parsedBody)
+
         val bucketName = (
           messageJson \ "Records" \ "s3" \ "bucket" \ "name").extract[Array[String]].head
         val eventName = (messageJson \ "Records" \ "eventName").extract[Array[String]].head
         val bucketName = (
           messageJson \ "Records" \ "s3" \ "bucket" \ "name").extract[Array[String]].head
         val eventName = (messageJson \ "Records" \ "eventName").extract[Array[String]].head
index a4c0cc1ae67afdec7af8de7fa69e92de5125137b..0c2eda0513fbfd15c53c874dcca248dd7ca3431b 100644 (file)
@@ -28,6 +28,15 @@ import org.apache.spark.util.Utils
  */
 class SqsSourceOptions(parameters: CaseInsensitiveMap[String]) extends Logging {
 
  */
 class SqsSourceOptions(parameters: CaseInsensitiveMap[String]) extends Logging {
 
+  object S3MessageWrapper extends Enumeration {
+    type MessageFormat = Value
+    val None, SNS = Value
+
+    def withNameOpt(opt: String): Option[Value] = {
+      values.find(_.toString.toLowerCase == opt.toLowerCase)
+    }
+  }
+
   def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
 
   val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
   def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
 
   val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
@@ -92,6 +101,13 @@ class SqsSourceOptions(parameters: CaseInsensitiveMap[String]) extends Logging {
     throw new IllegalArgumentException("Specifying file format is mandatory with sqs source")
   }
 
     throw new IllegalArgumentException("Specifying file format is mandatory with sqs source")
   }
 
+  val messageWrapper: S3MessageWrapper.Value = parameters.get("messageWrapper").map( str =>
+    S3MessageWrapper.withNameOpt(str).getOrElse({
+      throw new IllegalArgumentException(s"Invalid value '$str' for option 'messageWrapper', " +
+        s"must be one of [${S3MessageWrapper.values.mkString(", ")}]")
+    })
+  ).getOrElse(S3MessageWrapper.None)
+
   val ignoreFileDeletion: Boolean = withBooleanParameter("ignoreFileDeletion", false)
 
    /**
   val ignoreFileDeletion: Boolean = withBooleanParameter("ignoreFileDeletion", false)
 
    /**
index 6382fb1f87e96875f3845dd7661e8e1e81db94c4..e43bae32aea9ce539bd2cc275690d1b2f498d102 100644 (file)
@@ -74,6 +74,8 @@ class SqsSourceOptionsSuite extends StreamTest {
       "for option 'shouldSortFiles', must be true or false")
     testBadOptions("useInstanceProfileCredentials" -> "x")("Invalid value 'x' " +
       "for option 'useInstanceProfileCredentials', must be true or false")
       "for option 'shouldSortFiles', must be true or false")
     testBadOptions("useInstanceProfileCredentials" -> "x")("Invalid value 'x' " +
       "for option 'useInstanceProfileCredentials', must be true or false")
+    testBadOptions("messageWrapper" -> "x")("Invalid value 'x' " +
+      "for option 'messageWrapper', must be one of [none, sns]")
 
   }
 
 
   }