Skip to content

Commit

Permalink
1) moving read of byte buff and release to helper method in TLSPSKHan…
Browse files Browse the repository at this point in the history
…dler 2) adding comment in Http2OrHttpHandler and use readBytes instead of readSlice 3) adding SslCloseCompletionEvent on close_notify alert 4) handling null value TLS_HANDSHAKE_USING_EXTERNAL_PSK
  • Loading branch information
deeptiv1991 committed Oct 16, 2024
1 parent 3cc8293 commit a3c5d39
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AttributeKey;

import java.util.function.Consumer;

/**
* Http2 Or Http Handler
*
* <p>
* Author: Arthur Gonigberg
* Date: December 15, 2017
*/
Expand Down Expand Up @@ -75,6 +74,9 @@ public Http2OrHttpHandler(
this.addHttpHandlerFn = addHttpHandlerFn;
}

/**
* this method is inspired by ApplicationProtocolNegotiationHandler.userEventTriggered
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof SslHandshakeCompletionEvent handshakeEvent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
import io.netty.channel.ChannelPromise;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import org.bouncycastle.tls.CipherSuite;
import org.bouncycastle.tls.ProtocolName;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider;

import javax.net.ssl.SSLSession;
import java.security.SecureRandom;
import java.util.Map;
import java.util.Set;
import javax.net.ssl.SSLSession;
import org.bouncycastle.tls.CipherSuite;
import org.bouncycastle.tls.ProtocolName;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider;

public class TlsPskHandler extends ChannelDuplexHandler {

Expand Down Expand Up @@ -67,8 +66,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
promise.setFailure(new IllegalStateException("Failed to write message on the channel. Message is not a ByteBuf"));
return;
}
byte[] appDataBytes = byteBufMsg.hasArray() ? byteBufMsg.array() : TlsPskUtils.readDirect(byteBufMsg);
ReferenceCountUtil.safeRelease(byteBufMsg);
byte[] appDataBytes = TlsPskUtils.getAppDataBytesAndRelease(byteBufMsg);
tlsPskServerProtocol.writeApplicationData(appDataBytes, 0, appDataBytes.length);
int availableOutputBytes = tlsPskServerProtocol.getAvailableOutputBytes();
if (availableOutputBytes != 0) {
Expand Down Expand Up @@ -98,11 +96,11 @@ public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
* the protocol name or null if application-level protocol has not been negotiated
*/
public String getApplicationProtocol() {
return tlsPskServer!=null ? tlsPskServer.getApplicationProtocol() : null;
return tlsPskServer != null ? tlsPskServer.getApplicationProtocol() : null;
}

public SSLSession getSession() {
return tlsPskServerProtocol!=null ? tlsPskServerProtocol.getSSLSession() : null;
return tlsPskServerProtocol != null ? tlsPskServerProtocol.getSSLSession() : null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
package com.netflix.zuul.netty.server.psk;

import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCountUtil;

class TlsPskUtils {
static byte[] readDirect(ByteBuf byteBufMsg) {
protected static byte[] readDirect(ByteBuf byteBufMsg) {
int length = byteBufMsg.readableBytes();
byte[] dest = new byte[length];
byteBufMsg.readSlice(length).getBytes(0, dest);
byteBufMsg.readBytes(dest);
return dest;
}

protected static byte[] getAppDataBytesAndRelease(ByteBuf byteBufMsg) {
byte[] appDataBytes = byteBufMsg.hasArray() ? byteBufMsg.array() : TlsPskUtils.readDirect(byteBufMsg);
ReferenceCountUtil.safeRelease(byteBufMsg);
return appDataBytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,15 @@
import com.netflix.spectator.api.Registry;
import com.netflix.spectator.api.Timer;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SslCloseCompletionEvent;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AttributeKey;
import java.io.IOException;
import java.util.Hashtable;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.SneakyThrows;
import org.bouncycastle.tls.AbstractTlsServer;
import org.bouncycastle.tls.AlertDescription;
Expand All @@ -40,13 +47,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Hashtable;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

public class ZuulPskServer extends AbstractTlsServer {

private static final Logger LOGGER = LoggerFactory.getLogger(ZuulPskServer.class);
Expand Down Expand Up @@ -101,7 +101,7 @@ public TlsCredentials getCredentials() {
@Override
protected Vector getProtocolNames() {
Vector protocolNames = new Vector();
if (supportedApplicationProtocols!=null) {
if (supportedApplicationProtocols != null) {
supportedApplicationProtocols.forEach(protocolNames::addElement);
}
return protocolNames;
Expand Down Expand Up @@ -145,16 +145,19 @@ public ProtocolVersion getServerVersion() throws IOException {
@Override
@SneakyThrows
// TODO: Ask BC folks to see if getExternalPSK can throw a checked exception
// https://github.com/bcgit/bc-java/issues/1673
public TlsPSKExternal getExternalPSK(Vector clientPskIdentities) {
byte[] clientPskIdentity = ((PskIdentity)clientPskIdentities.get(0)).getIdentity();
byte[] clientPskIdentity = ((PskIdentity) clientPskIdentities.get(0)).getIdentity();
byte[] psk;
try{
try {
this.ctx.channel().attr(TlsPskHandler.CLIENT_PSK_IDENTITY_ATTRIBUTE_KEY).set(new ClientPSKIdentityInfo(clientPskIdentity));
psk = externalTlsPskProvider.provide(clientPskIdentity, this.context.getSecurityParametersHandshake().getClientRandom());
}catch (PskCreationFailureException e) {
} catch (PskCreationFailureException e) {
throw switch (e.getTlsAlertMessage()) {
case unknown_psk_identity -> new TlsFatalAlert(AlertDescription.unknown_psk_identity, "Unknown or null client PSk identity");
case decrypt_error -> new TlsFatalAlert(AlertDescription.decrypt_error, "Invalid or expired client PSk identity");
case unknown_psk_identity ->
new TlsFatalAlert(AlertDescription.unknown_psk_identity, "Unknown or null client PSk identity");
case decrypt_error ->
new TlsFatalAlert(AlertDescription.decrypt_error, "Invalid or expired client PSk identity");
};
}
TlsSecret pskTlsSecret = getCrypto().createSecret(psk);
Expand All @@ -174,6 +177,10 @@ public void notifyAlertRaised(short alertLevel, short alertDescription, String m
if (cause != null) {
LOGGER.error("TLS/PSK alert stacktrace", cause);
}

if (alertDescription == AlertDescription.close_notify) {
ctx.fireUserEventTriggered(SslCloseCompletionEvent.SUCCESS);
}
}

@Override
Expand Down Expand Up @@ -210,7 +217,7 @@ public void getServerExtensionsForConnection(Hashtable serverExtensions) throws
public String getApplicationProtocol() {
ProtocolName protocolName =
context.getSecurityParametersConnection().getApplicationProtocol();
if (protocolName!=null) {
if (protocolName != null) {
return protocolName.getUtf8Decoding();
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AttributeKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.nio.channels.ClosedChannelException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Stores info about the client and server's SSL certificates in the context, after a successful handshake.
Expand Down Expand Up @@ -105,9 +104,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
serverCert = session.getLocalCertificates()[0];
}

Boolean tlsHandshakeUsingExternalPSK = ctx.channel()
//if attribute is true, then true. If null or false then false
boolean tlsHandshakeUsingExternalPSK = Boolean.TRUE.equals(ctx.channel()
.attr(ZuulPskServer.TLS_HANDSHAKE_USING_EXTERNAL_PSK)
.get();
.get());

ClientPSKIdentityInfo clientPSKIdentityInfo = ctx.channel()
.attr(TlsPskHandler.CLIENT_PSK_IDENTITY_ATTRIBUTE_KEY)
Expand Down Expand Up @@ -138,7 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
CurrentPassport.fromChannel(ctx.channel()).getState();
if (cause instanceof ClosedChannelException
&& (PassportState.SERVER_CH_INACTIVE.equals(passportState)
|| PassportState.SERVER_CH_IDLE_TIMEOUT.equals(passportState))) {
|| PassportState.SERVER_CH_IDLE_TIMEOUT.equals(passportState))) {
// Either client closed the connection without/before having completed a handshake, or
// the connection idle timed-out before handshake.
// NOTE: we were seeing a lot of these in prod and can repro by just telnetting to port and then
Expand Down

0 comments on commit a3c5d39

Please sign in to comment.