[SSHD-877] Release failed sockets during acceptor bind attempt
authorLyor Goldstein <lgoldstein@apache.org>
Fri, 21 Dec 2018 10:34:07 +0000 (12:34 +0200)
committerLyor Goldstein <lgoldstein@apache.org>
Sun, 23 Dec 2018 06:00:41 +0000 (08:00 +0200)
sshd-common/src/main/java/org/apache/sshd/common/util/io/IoUtils.java
sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Acceptor.java
sshd-core/src/main/java/org/apache/sshd/common/io/nio2/Nio2Connector.java
sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoAcceptor.java

index 8f583ca..40ef9b0 100644 (file)
@@ -139,12 +139,31 @@ public final class IoUtils {
      *
      * @param closeables The {@link Closeable}s to close
      * @return The <U>first</U> {@link IOException} that occurred during closing
-     * of a resource - if more than one exception occurred, they are added as
-     * suppressed exceptions to the first one
+     * of a resource - {@code null} if not exception. If more than one exception
+     * occurred, they are added as suppressed exceptions to the first one
      * @see Throwable#getSuppressed()
      */
     @SuppressWarnings("ThrowableResultOfMethodCallIgnored")
     public static IOException closeQuietly(Closeable... closeables) {
+        return closeQuietly(GenericUtils.isEmpty(closeables) ? Collections.emptyList() : Arrays.asList(closeables));
+    }
+
+    /**
+     * Closes a bunch of resources suppressing any {@link IOException}s their
+     * {@link Closeable#close()} method may have thrown
+     *
+     * @param closeables The {@link Closeable}s to close
+     * @return The <U>first</U> {@link IOException} that occurred during closing
+     * of a resource - {@code null} if not exception. If more than one exception
+     * occurred, they are added as suppressed exceptions to the first one
+     * @see Throwable#getSuppressed()
+     */
+    @SuppressWarnings("ThrowableResultOfMethodCallIgnored")
+    public static IOException closeQuietly(Collection<? extends Closeable> closeables) {
+        if (GenericUtils.isEmpty(closeables)) {
+            return null;
+        }
+
         IOException err = null;
         for (Closeable c : closeables) {
             try {
index 3f7e8a8..141c896 100644 (file)
@@ -23,7 +23,9 @@ import java.net.SocketAddress;
 import java.nio.channels.AsynchronousChannelGroup;
 import java.nio.channels.AsynchronousServerSocketChannel;
 import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.channels.ClosedChannelException;
 import java.nio.channels.CompletionHandler;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
@@ -37,7 +39,9 @@ import org.apache.sshd.common.FactoryManager;
 import org.apache.sshd.common.io.IoAcceptor;
 import org.apache.sshd.common.io.IoHandler;
 import org.apache.sshd.common.io.IoServiceEventListener;
+import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.ValidateUtils;
+import org.apache.sshd.common.util.io.IoUtils;
 
 /**
  * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
@@ -53,26 +57,107 @@ public class Nio2Acceptor extends Nio2Service implements IoAcceptor {
 
     @Override
     public void bind(Collection<? extends SocketAddress> addresses) throws IOException {
+        if (GenericUtils.isEmpty(addresses)) {
+            return;
+        }
+
         AsynchronousChannelGroup group = getChannelGroup();
-        for (SocketAddress address : addresses) {
-            if (log.isDebugEnabled()) {
-                log.debug("Binding Nio2Acceptor to address {}", address);
+        Collection<java.io.Closeable> bound = new ArrayList<>(addresses.size());
+        try {
+            boolean debugEnabled = log.isDebugEnabled();
+            for (SocketAddress address : addresses) {
+                if (debugEnabled) {
+                    log.debug("bind({}) binding to address", address);
+                }
+
+                try {
+                    AsynchronousServerSocketChannel asyncChannel =
+                        openAsynchronousServerSocketChannel(address, group);
+                    // In case it or the other bindings fail
+                    java.io.Closeable protector =
+                        protectInProgressBinding(address, asyncChannel);
+                    bound.add(protector);
+
+                    AsynchronousServerSocketChannel socket = setSocketOptions(asyncChannel);
+                    socket.bind(address, backlog);
+
+                    SocketAddress local = socket.getLocalAddress();
+                    if (debugEnabled) {
+                        log.debug("bind({}) bound to {}", address, local);
+                    }
+
+                    AsynchronousServerSocketChannel prev = channels.put(local, socket);
+                    if (prev != null) {
+                        if (debugEnabled) {
+                            log.debug("bind({}) replaced previous channel ({}) for {}",
+                                address, prev.getLocalAddress(), local);
+                        }
+                    }
+
+                    CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> handler =
+                        ValidateUtils.checkNotNull(createSocketCompletionHandler(channels, socket),
+                            "No completion handler created for address=%s[%s]",
+                            address, local);
+                    socket.accept(local, handler);
+                } catch (IOException | RuntimeException e) {
+                    log.error("bind({}) - failed ({}) to bind: {}",
+                        address, e.getClass().getSimpleName(), e.getMessage());
+                    if (debugEnabled) {
+                        log.debug("bind(" + address + ") failure details", e);
+                    }
+                    throw e;
+                }
             }
 
-            AsynchronousServerSocketChannel asyncChannel = openAsynchronousServerSocketChannel(address, group);
-            AsynchronousServerSocketChannel socket = setSocketOptions(asyncChannel);
-            socket.bind(address, backlog);
-            SocketAddress local = socket.getLocalAddress();
-            channels.put(local, socket);
-
-            CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> handler =
-                ValidateUtils.checkNotNull(createSocketCompletionHandler(channels, socket),
-                    "No completion handler created for address=%s",
-                    address);
-            socket.accept(local, handler);
+            bound.clear();  // avoid auto-close at finally clause
+        } finally {
+            IOException err = IoUtils.closeQuietly(bound);
+            if (err != null) {
+                throw err;
+            }
         }
     }
 
+    protected java.io.Closeable protectInProgressBinding(
+            SocketAddress address, AsynchronousServerSocketChannel asyncChannel) {
+        boolean debugEnabled = log.isDebugEnabled();
+
+        return new java.io.Closeable() {
+            @Override
+            @SuppressWarnings("synthetic-access")
+            public void close() throws IOException {
+                try {
+                    try {
+                        SocketAddress local = asyncChannel.getLocalAddress();
+                        // make sure bound channel
+                        if (local != null) {
+                            if (debugEnabled) {
+                                log.debug("protectInProgressBinding({}) remove {} binding", address, local);
+                            }
+                            channels.remove(local);
+                        }
+                    } finally {
+                        if (debugEnabled) {
+                            log.debug("protectInProgressBinding({}) auto-close", address);
+                        }
+
+                        asyncChannel.close();
+                    }
+                } catch (ClosedChannelException e) {
+                    // ignore if already closed
+                    if (debugEnabled) {
+                        log.debug("protectInProgressBinding(" + address + ") ignore close channel exception", e);
+                    }
+                }
+            }
+
+            @Override
+            public String toString() {
+                return "protectInProgressBinding(" + address + ")";
+            }
+        };
+    }
+
     protected AsynchronousServerSocketChannel openAsynchronousServerSocketChannel(
             SocketAddress address, AsynchronousChannelGroup group)
                 throws IOException {
index 4cebf57..faf9f03 100644 (file)
@@ -45,7 +45,8 @@ public class Nio2Connector extends Nio2Service implements IoConnector {
     }
 
     @Override
-    public IoConnectFuture connect(SocketAddress address, AttributeRepository context, SocketAddress localAddress) {
+    public IoConnectFuture connect(
+            SocketAddress address, AttributeRepository context, SocketAddress localAddress) {
         boolean debugEnabled = log.isDebugEnabled();
         if (debugEnabled) {
             log.debug("Connecting to {}", address);
index 1d4e7f6..96351f6 100644 (file)
@@ -23,6 +23,7 @@ import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.Map;
@@ -110,26 +111,79 @@ public class NettyIoAcceptor extends NettyIoService implements IoAcceptor {
 
     @Override
     public void bind(Collection<? extends SocketAddress> addresses) throws IOException {
-        for (SocketAddress address : addresses) {
-            bind(address);
+        if (GenericUtils.isEmpty(addresses)) {
+            return;
+        }
+
+        Collection<Channel> bound = new ArrayList<>(addresses.size());
+        try {
+            for (SocketAddress address : addresses) {
+                Channel channel = bindInternal(address);
+                bound.add(channel);
+            }
+
+            bound.clear();  // disable auto close at finally clause
+        } finally {
+            for (Channel channel : bound) {
+                closeChannel(channel);
+            }
         }
     }
 
     @Override
     public void bind(SocketAddress address) throws IOException {
+        bindInternal(address);
+    }
+
+    protected Channel bindInternal(SocketAddress address) throws IOException {
         InetSocketAddress inetAddress = (InetSocketAddress) address;
+        boolean debugEnabled = log.isDebugEnabled();
+        if (debugEnabled) {
+            log.debug("bindInternal({}) binding", address);
+        }
+
         ChannelFuture f = bootstrap.bind(inetAddress);
         Channel channel = f.channel();
         channelGroup.add(channel);
         try {
             f.sync();
+
             SocketAddress bound = channel.localAddress();
-            boundAddresses.put(bound, channel);
+            if (debugEnabled) {
+                log.debug("bindInternal({}) bound to {}", address, bound);
+            }
+
+            Channel prev = boundAddresses.put(bound, channel);
+            if (prev != null) {
+                if (debugEnabled) {
+                    log.debug("bindInternal({}) replaced entry of {} - previous={}",
+                        address, bound, prev.localAddress());
+                }
+            }
+
             channel.closeFuture().addListener(fut -> boundAddresses.remove(bound));
+
+            // disable auto close at finally clause
+            Channel returnValue = channel;
+            channel = null;
+            return returnValue;
         } catch (InterruptedException e) {
-            throw (InterruptedIOException) new InterruptedIOException().initCause(e);
-        } catch (Exception e) {
-            throw new IOException(e);
+            log.error("bindInternal({}) interrupted ({}): {}",
+                address, e.getClass().getSimpleName(), e.getMessage());
+            if (debugEnabled) {
+                log.debug("bindInternal(" + address + ") failure details", e);
+            }
+
+            throw (InterruptedIOException) new InterruptedIOException(e.getMessage()).initCause(e);
+        } finally {
+            closeChannel(channel);
+        }
+    }
+
+    protected void closeChannel(Channel channel) {
+        if (channel != null) {
+            channelGroup.remove(channel);
+            channel.close();
         }
     }