Skip to content

Commit

Permalink
Avoid NPE in ConnectPlan (#8514)
Browse files Browse the repository at this point in the history
* Avoid NPE in ConnectPlan

* Avoid NPE in ConnectPlan

* Avoid NPE in ConnectPlan

* cleanup
  • Loading branch information
yschimke authored Nov 19, 2024
1 parent 8da7440 commit aac6c70
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ class ConnectPlan(
internal var socket: Socket? = null
private var handshake: Handshake? = null
private var protocol: Protocol? = null
private var source: BufferedSource? = null
private var sink: BufferedSink? = null
private lateinit var source: BufferedSource
private lateinit var sink: BufferedSink
private var connection: RealConnection? = null

/** True if this connection is ready for use, including TCP, tunnels, and TLS. */
Expand Down Expand Up @@ -152,7 +152,7 @@ class ConnectPlan(
}

override fun connectTlsEtc(): ConnectResult {
check(rawSocket != null) { "TCP not connected" }
val rawSocket = requireNotNull(rawSocket) { "TCP not connected" }
check(!isReady) { "already connected" }

val connectionSpecs = route.address.connectionSpecs
Expand All @@ -176,7 +176,7 @@ class ConnectPlan(
// that happens, then we will have buffered bytes that are needed by the SSLSocket!
// This check is imperfect: it doesn't tell us whether a handshake will succeed, just
// that it will almost certainly fail because the proxy has sent unexpected data.
if (source?.buffer?.exhausted() == false || sink?.buffer?.exhausted() == false) {
if (!source.buffer.exhausted() || !sink.buffer.exhausted()) {
throw IOException("TLS tunnel buffered too many bytes!")
}

Expand Down Expand Up @@ -216,9 +216,9 @@ class ConnectPlan(
connectionPool = connectionPool,
route = route,
rawSocket = rawSocket,
socket = socket,
socket = socket!!,
handshake = handshake,
protocol = protocol,
protocol = protocol!!,
source = source,
sink = sink,
pingIntervalMillis = pingIntervalMillis,
Expand Down Expand Up @@ -247,7 +247,7 @@ class ConnectPlan(
user.removePlanToCancel(this)
if (!success) {
socket?.closeQuietly()
rawSocket?.closeQuietly()
rawSocket.closeQuietly()
}
}
}
Expand Down Expand Up @@ -420,8 +420,6 @@ class ConnectPlan(
val url = route.address.url
val requestLine = "CONNECT ${url.toHostHeader(includeDefaultPort = true)} HTTP/1.1"
while (true) {
val source = this.source!!
val sink = this.sink!!
val tunnelCodec =
Http1ExchangeCodec(
// No client for CONNECT tunnels:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ import okhttp3.internal.http2.StreamResetException
import okhttp3.internal.isHealthy
import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.internal.ws.RealWebSocket
import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.Sink
import okio.Source
import okio.Timeout
import okio.buffer

/**
* A connection to a remote web server capable of carrying 1 or more concurrent streams.
Expand All @@ -67,16 +72,16 @@ class RealConnection(
val connectionPool: RealConnectionPool,
override val route: Route,
/** The low-level TCP socket. */
private var rawSocket: Socket?,
private val rawSocket: Socket,
/**
* The application layer socket. Either an [SSLSocket] layered over [rawSocket], or [rawSocket]
* itself if this connection does not use SSL.
*/
private var socket: Socket?,
private var handshake: Handshake?,
private var protocol: Protocol?,
private var source: BufferedSource?,
private var sink: BufferedSink?,
private val socket: Socket,
private val handshake: Handshake?,
private val protocol: Protocol,
private val source: BufferedSource,
private val sink: BufferedSink,
private val pingIntervalMillis: Int,
internal val connectionListener: ConnectionListener,
) : Http2Connection.Listener(), Connection, ExchangeCodec.Carrier {
Expand Down Expand Up @@ -162,9 +167,6 @@ class RealConnection(

@Throws(IOException::class)
private fun startHttp2() {
val socket = this.socket!!
val source = this.source!!
val sink = this.sink!!
socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream.
val flowControlListener = connectionListener as? FlowControlListener ?: FlowControlListener.None
val http2Connection =
Expand Down Expand Up @@ -253,7 +255,7 @@ class RealConnection(
}

// We have a host mismatch. But if the certificate matches, we're still good.
return !noCoalescedConnections && handshake != null && certificateSupportHost(url, handshake!!)
return !noCoalescedConnections && handshake != null && certificateSupportHost(url, handshake)
}

private fun certificateSupportHost(
Expand All @@ -271,9 +273,9 @@ class RealConnection(
client: OkHttpClient,
chain: RealInterceptorChain,
): ExchangeCodec {
val socket = this.socket!!
val source = this.source!!
val sink = this.sink!!
val socket = this.socket
val source = this.source
val sink = this.sink
val http2Connection = this.http2Connection

return if (http2Connection != null) {
Expand All @@ -288,10 +290,6 @@ class RealConnection(

@Throws(SocketException::class)
internal fun newWebSocketStreams(exchange: Exchange): RealWebSocket.Streams {
val socket = this.socket!!
val source = this.source!!
val sink = this.sink!!

socket.soTimeout = 0
noNewExchanges()
return object : RealWebSocket.Streams(true, source, sink) {
Expand All @@ -309,20 +307,17 @@ class RealConnection(

override fun cancel() {
// Close the raw socket so we don't end up doing synchronous I/O.
rawSocket?.closeQuietly()
rawSocket.closeQuietly()
}

override fun socket(): Socket = socket!!
override fun socket(): Socket = socket

/** Returns true if this connection is ready to host new streams. */
fun isHealthy(doExtensiveChecks: Boolean): Boolean {
lock.assertNotHeld()

val nowNs = System.nanoTime()

val rawSocket = this.rawSocket!!
val socket = this.socket!!
val source = this.source!!
if (rawSocket.isClosed || socket.isClosed || socket.isInputShutdown ||
socket.isOutputShutdown
) {
Expand Down Expand Up @@ -442,7 +437,7 @@ class RealConnection(
}
}

override fun protocol(): Protocol = protocol!!
override fun protocol(): Protocol = protocol

override fun toString(): String {
return "Connection{${route.address.url.host}:${route.address.url.port}," +
Expand All @@ -467,12 +462,38 @@ class RealConnection(
taskRunner = taskRunner,
connectionPool = connectionPool,
route = route,
rawSocket = null,
rawSocket = Socket(),
socket = socket,
handshake = null,
protocol = null,
source = null,
sink = null,
protocol = Protocol.HTTP_2,
source =
object : Source {
override fun close() = Unit

override fun read(
sink: Buffer,
byteCount: Long,
): Long {
throw UnsupportedOperationException()
}

override fun timeout(): Timeout = Timeout.NONE
}.buffer(),
sink =
object : Sink {
override fun close() = Unit

override fun flush() = Unit

override fun timeout(): Timeout = Timeout.NONE

override fun write(
source: Buffer,
byteCount: Long,
) {
throw UnsupportedOperationException()
}
}.buffer(),
pingIntervalMillis = 0,
ConnectionListener.NONE,
)
Expand Down

0 comments on commit aac6c70

Please sign in to comment.