allow ntasks to be 0 and determine whetheher to use threads based on nthreads rather... main
authorBen Baumgold <4933671+baumgold@users.noreply.github.com>
Fri, 10 Jun 2022 07:11:07 +0000 (03:11 -0400)
committerGitHub <noreply@github.com>
Fri, 10 Jun 2022 07:11:07 +0000 (01:11 -0600)
Co-authored-by: Ben Baumgold <ben.baumgold@mavensecurities.com>
src/write.jl

index 5439b95e80a66054af88120a33b5b5207785e291..ae2da6dd0685518cd89d289808f2b37e3a2d9f14 100644 (file)
@@ -46,7 +46,7 @@ Supported keyword arguments to `Arrow.write` include:
   * `largelists::Bool=false`: causes list column types to be written with Int64 offset arrays; mainly for testing purposes; by default, Int64 offsets will be used only if needed
   * `maxdepth::Int=$DEFAULT_MAX_DEPTH`: deepest allowed nested serialization level; this is provided by default to prevent accidental infinite recursion with mutually recursive data structures
   * `metadata=Arrow.getmetadata(tbl)`: the metadata that should be written as the table's schema's `custom_metadata` field; must either be `nothing` or an iterable of `<:AbstractString` pairs.
-  * `ntasks::Int`: number of concurrent threaded tasks to allow while writing input partitions out as arrow record batches; default is no limit; to disable multithreaded writing, pass `ntasks=1`
+  * `ntasks::Int`: number of buffered threaded tasks to allow while writing input partitions out as arrow record batches; default is no limit; for unbuffered writing, pass `ntasks=0`
   * `file::Bool=false`: if a an `io` argument is being written to, passing `file=true` will cause the arrow file format to be written instead of just IPC streaming
 """
 function write end
@@ -135,16 +135,13 @@ mutable struct Writer{T<:IO}
 end
 
 function Base.open(::Type{Writer}, io::T, compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}, writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where {T<:IO}
-    if ntasks < 1
-        throw(ArgumentError("ntasks keyword argument must be > 0; pass `ntasks=1` to disable multithreaded writing"))
-    end
     msgs = OrderedChannel{Message}(ntasks)
     schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64,Any}() # Lockable{DictEncoding}
     blocks = (Block[], Block[])
     # start message writing from channel
-    threaded = ntasks > 1
+    threaded = Threads.nthreads() > 1
     task = threaded ? (Threads.@spawn for msg in msgs
         Base.write(io, msg, blocks, schema, alignment)
     end) : (@async for msg in msgs