1+ /*
2+ * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+ *
5+ * This code is free software; you can redistribute it and/or modify it
6+ * under the terms of the GNU General Public License version 2 only, as
7+ * published by the Free Software Foundation.
8+ *
9+ * This code is distributed in the hope that it will be useful, but WITHOUT
10+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+ * version 2 for more details (a copy is included in the LICENSE file that
13+ * accompanied this code).
14+ *
15+ * You should have received a copy of the GNU General Public License version
16+ * 2 along with this work; if not, write to the Free Software Foundation,
17+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+ *
19+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+ * or visit www.oracle.com if you need additional information or have any
21+ * questions.
22+ */
23+
24+ /*
25+ * @test
26+ * @bug 8358764
27+ * @summary Test closing a socket while a thread is blocked in read. The connection
28+ * should be closed gracefuly so that the peer reads EOF.
29+ * @run junit PeerReadsAfterAsyncClose
30+ */
31+
32+ import java .io .IOException ;
33+ import java .net .InetAddress ;
34+ import java .net .InetSocketAddress ;
35+ import java .net .ServerSocket ;
36+ import java .net .Socket ;
37+ import java .net .SocketException ;
38+ import java .nio .ByteBuffer ;
39+ import java .nio .channels .ClosedChannelException ;
40+ import java .nio .channels .SocketChannel ;
41+ import java .util .Arrays ;
42+ import java .util .Objects ;
43+ import java .util .concurrent .ThreadFactory ;
44+ import java .util .concurrent .atomic .AtomicBoolean ;
45+ import java .util .stream .Stream ;
46+
47+ import org .junit .jupiter .params .ParameterizedTest ;
48+ import org .junit .jupiter .params .provider .MethodSource ;
49+ import static org .junit .jupiter .api .Assertions .*;
50+
51+ class PeerReadsAfterAsyncClose {
52+
53+ static Stream <ThreadFactory > factories () {
54+ return Stream .of (Thread .ofPlatform ().factory (), Thread .ofVirtual ().factory ());
55+ }
56+
57+ /**
58+ * Close SocketChannel while a thread is blocked reading from the channel's socket.
59+ */
60+ @ ParameterizedTest
61+ @ MethodSource ("factories" )
62+ void testCloseDuringSocketChannelRead (ThreadFactory factory ) throws Exception {
63+ var loopback = InetAddress .getLoopbackAddress ();
64+ try (var listener = new ServerSocket ()) {
65+ listener .bind (new InetSocketAddress (loopback , 0 ));
66+
67+ try (SocketChannel sc = SocketChannel .open (listener .getLocalSocketAddress ());
68+ Socket peer = listener .accept ()) {
69+
70+ // start thread to read from channel
71+ var cceThrown = new AtomicBoolean ();
72+ Thread thread = factory .newThread (() -> {
73+ try {
74+ sc .read (ByteBuffer .allocate (1 ));
75+ fail ();
76+ } catch (ClosedChannelException e ) {
77+ cceThrown .set (true );
78+ } catch (Throwable e ) {
79+ e .printStackTrace ();
80+ }
81+ });
82+ thread .start ();
83+ try {
84+ // close SocketChannel when thread sampled in implRead
85+ onReach (thread , "sun.nio.ch.SocketChannelImpl.implRead" , () -> {
86+ try {
87+ sc .close ();
88+ } catch (IOException ignore ) { }
89+ });
90+
91+ // peer should read EOF
92+ int n = peer .getInputStream ().read ();
93+ assertEquals (-1 , n );
94+ } finally {
95+ thread .join ();
96+ }
97+ assertEquals (true , cceThrown .get (), "ClosedChannelException not thrown" );
98+ }
99+ }
100+ }
101+
102+ /**
103+ * Close Socket while a thread is blocked reading from the socket.
104+ */
105+ @ ParameterizedTest
106+ @ MethodSource ("factories" )
107+ void testCloseDuringSocketUntimedRead (ThreadFactory factory ) throws Exception {
108+ testCloseDuringSocketRead (factory , 0 );
109+ }
110+
111+ /**
112+ * Close Socket while a thread is blocked reading from the socket with a timeout.
113+ */
114+ @ ParameterizedTest
115+ @ MethodSource ("factories" )
116+ void testCloseDuringSockeTimedRead (ThreadFactory factory ) throws Exception {
117+ testCloseDuringSocketRead (factory , 60_000 );
118+ }
119+
120+ private void testCloseDuringSocketRead (ThreadFactory factory , int timeout ) throws Exception {
121+ var loopback = InetAddress .getLoopbackAddress ();
122+ try (var listener = new ServerSocket ()) {
123+ listener .bind (new InetSocketAddress (loopback , 0 ));
124+
125+ try (Socket s = new Socket (loopback , listener .getLocalPort ());
126+ Socket peer = listener .accept ()) {
127+
128+ // start thread to read from socket
129+ var seThrown = new AtomicBoolean ();
130+ Thread thread = factory .newThread (() -> {
131+ try {
132+ s .setSoTimeout (timeout );
133+ s .getInputStream ().read ();
134+ fail ();
135+ } catch (SocketException e ) {
136+ seThrown .set (true );
137+ } catch (Throwable e ) {
138+ e .printStackTrace ();
139+ }
140+ });
141+ thread .start ();
142+ try {
143+ // close Socket when thread sampled in implRead
144+ onReach (thread , "sun.nio.ch.NioSocketImpl.implRead" , () -> {
145+ try {
146+ s .close ();
147+ } catch (IOException ignore ) { }
148+ });
149+
150+ // peer should read EOF
151+ int n = peer .getInputStream ().read ();
152+ assertEquals (-1 , n );
153+ } finally {
154+ thread .join ();
155+ }
156+ assertEquals (true , seThrown .get (), "SocketException not thrown" );
157+ }
158+ }
159+ }
160+
161+ /**
162+ * Runs the given action when the given target thread is sampled at the given
163+ * location. The location takes the form "{@code c.m}" where
164+ * {@code c} is the fully qualified class name and {@code m} is the method name.
165+ */
166+ private void onReach (Thread target , String location , Runnable action ) {
167+ int index = location .lastIndexOf ('.' );
168+ String className = location .substring (0 , index );
169+ String methodName = location .substring (index + 1 );
170+ Thread .ofPlatform ().daemon (true ).start (() -> {
171+ try {
172+ boolean found = false ;
173+ while (!found ) {
174+ found = contains (target .getStackTrace (), className , methodName );
175+ if (!found ) {
176+ Thread .sleep (20 );
177+ }
178+ }
179+ action .run ();
180+ } catch (Exception e ) {
181+ e .printStackTrace ();
182+ }
183+ });
184+ }
185+
186+ /**
187+ * Returns true if the given stack trace contains an element for the given class
188+ * and method name.
189+ */
190+ private boolean contains (StackTraceElement [] stack , String className , String methodName ) {
191+ return Arrays .stream (stack )
192+ .anyMatch (e -> className .equals (e .getClassName ())
193+ && methodName .equals (e .getMethodName ()));
194+ }
195+ }
0 commit comments