SaslAuthenticationHandler.java
/*
* SPDX-FileCopyrightText: 2023-2025 Lucimber UG
* SPDX-License-Identifier: Apache-2.0
*/
package com.lucimber.dbus.netty.sasl;
import com.lucimber.dbus.connection.sasl.SaslCommandName;
import com.lucimber.dbus.connection.sasl.SaslMessage;
import com.lucimber.dbus.netty.DBusChannelEvent;
import com.lucimber.dbus.netty.WriteOperationListener;
import com.lucimber.dbus.util.LoggerUtils;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public final class SaslAuthenticationHandler extends ChannelDuplexHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(SaslAuthenticationHandler.class);
private final List<SaslMechanism> clientMechanismsPreference;
private SaslState currentState = SaslState.IDLE;
private SaslMechanism currentMechanism;
private List<String> serverSupportedMechanisms;
private int currentMechanismAttemptIndex = 0;
public SaslAuthenticationHandler(List<SaslMechanism> preferredClientMechanisms) {
Objects.requireNonNull(preferredClientMechanisms, "Client mechanisms list cannot be null.");
this.clientMechanismsPreference =
preferredClientMechanisms.isEmpty()
? List.of(
new ExternalSaslMechanism(),
new CookieSaslMechanism(),
new AnonymousSaslMechanism())
: new ArrayList<>(preferredClientMechanisms);
}
public SaslAuthenticationHandler() {
this(Collections.emptyList());
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
// Handle reconnection events
if (evt == DBusChannelEvent.RECONNECTION_STARTING) {
reset();
ctx.fireUserEventTriggered(evt);
return;
}
if (evt == DBusChannelEvent.SASL_NUL_BYTE_SENT && currentState == SaslState.IDLE) {
LOGGER.debug(LoggerUtils.HANDLER_LIFECYCLE, "SASL_NUL_BYTE_SENT event received.");
SaslMessage authMsg = new SaslMessage(SaslCommandName.AUTH, null);
ctx.writeAndFlush(authMsg).addListener(new WriteOperationListener<>(LOGGER));
currentState = SaslState.AWAITING_SERVER_MECHS;
} else {
ctx.fireUserEventTriggered(evt);
}
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof SaslMessage saslMessage) {
switch (currentState) {
case AWAITING_SERVER_MECHS, NEGOTIATING ->
handleSaslServerResponse(ctx, saslMessage);
default -> {
LOGGER.warn(
LoggerUtils.SASL,
"Received command '{}' in unexpected state: {}.",
saslMessage.getCommandName(),
currentState);
if (EnumSet.of(SaslState.AUTHENTICATED, SaslState.FAILED)
.contains(currentState)) {
LOGGER.warn(
LoggerUtils.SASL,
"Ignoring command '{}' as state is already {}.",
saslMessage.getCommandName(),
currentState);
} else {
failAuthentication(
ctx,
"Unexpected command '"
+ saslMessage.getCommandName()
+ "' in state "
+ currentState);
}
}
}
} else {
LOGGER.warn(
LoggerUtils.SASL,
"Received unexpected non-SASL message type during SASL: {}",
msg.getClass().getName());
if (currentState == SaslState.AUTHENTICATED) {
LOGGER.debug(
LoggerUtils.SASL,
"Passing non-string message up, assuming post-SASL and pre-DBus-pipeline data.");
ctx.fireChannelRead(msg);
} else {
failAuthentication(ctx, "Received non-string data during active SASL exchange.");
}
}
}
private void handleSaslServerResponse(ChannelHandlerContext ctx, SaslMessage msg) {
SaslCommandName command = msg.getCommandName();
String args = msg.getCommandArgs().orElse("");
switch (command) {
case OK -> {
LOGGER.info(LoggerUtils.SASL, "Server send OK. Sending BEGIN.");
LOGGER.debug(LoggerUtils.SASL, "Server GUID: {}", args);
SaslMessage beginMsg = new SaslMessage(SaslCommandName.BEGIN, null);
ctx.writeAndFlush(beginMsg)
.addListener(
new WriteOperationListener<>(
LOGGER,
future -> {
if (future.isSuccess()) {
LOGGER.debug(
LoggerUtils.SASL,
"BEGIN command sent successfully.");
currentState = SaslState.AUTHENTICATED;
cleanupAndSignalCompletion(ctx, true);
} else {
LOGGER.error(
LoggerUtils.SASL,
"Failed to send BEGIN: {}",
future.cause().getMessage());
failAuthentication(
ctx,
"Failed to send BEGIN: "
+ future.cause().getMessage());
}
}));
}
case REJECTED -> {
serverSupportedMechanisms = Arrays.asList(args.split(" "));
LOGGER.warn(
LoggerUtils.SASL,
"Mechanism rejected. Server supports: {}",
serverSupportedMechanisms);
disposeCurrentMechanism();
tryNextMechanism(ctx);
}
case DATA -> {
if (currentMechanism == null || currentState != SaslState.NEGOTIATING) {
failAuthentication(
ctx, "Unexpected DATA command without mechanism or wrong state.");
return;
}
currentMechanism
.processChallengeAsync(ctx, args)
.addListener(
future -> {
if (!ctx.channel().isActive()) {
return;
}
if (future.isSuccess()) {
var responseHex = (String) future.getNow();
if (responseHex != null) {
SaslMessage dataMsg =
new SaslMessage(
SaslCommandName.DATA, responseHex);
ctx.writeAndFlush(dataMsg)
.addListener(
new WriteOperationListener<>(LOGGER));
} else {
LOGGER.debug(
LoggerUtils.SASL,
"Mechanism {} complete, awaiting server response.",
currentMechanism.getName());
}
} else {
LOGGER.error(
LoggerUtils.SASL,
"Failed to process challenge with mechanism {}",
currentMechanism.getName(),
future.cause());
SaslMessage cancelMsg =
new SaslMessage(SaslCommandName.CANCEL, null);
ctx.writeAndFlush(cancelMsg)
.addListener(new WriteOperationListener<>(LOGGER));
}
});
}
case ERROR -> {
LOGGER.error(LoggerUtils.SASL, "Server send ERROR: {}", args);
SaslMessage cancelMsg = new SaslMessage(SaslCommandName.CANCEL, null);
ctx.writeAndFlush(cancelMsg).addListener(new WriteOperationListener<>(LOGGER));
}
case AGREE_UNIX_FD ->
LOGGER.info(LoggerUtils.SASL, "Server agreed to UNIX FD passing.");
default -> {
if (currentState == SaslState.AWAITING_SERVER_MECHS
&& command.name().matches("[A-Z0-9_]+([-A-Z0-9_]*[A-Z0-9_]+)?")) {
serverSupportedMechanisms = Arrays.asList(msg.toString().split(" "));
LOGGER.debug(
LoggerUtils.SASL, "Server mechanisms: {}", serverSupportedMechanisms);
currentMechanismAttemptIndex = 0;
tryNextMechanism(ctx);
} else {
failAuthentication(
ctx, "Unexpected command: " + command + " with args: " + args);
}
}
}
}
private void tryNextMechanism(ChannelHandlerContext ctx) {
if (serverSupportedMechanisms == null) {
failAuthentication(ctx, "No server mechanisms provided.");
return;
}
disposeCurrentMechanism();
if (currentMechanismAttemptIndex < clientMechanismsPreference.size()) {
var candidate = clientMechanismsPreference.get(currentMechanismAttemptIndex++);
if (serverSupportedMechanisms.contains(candidate.getName())) {
currentMechanism = candidate;
LOGGER.info(LoggerUtils.SASL, "Trying mechanism: {}", candidate.getName());
try {
currentMechanism.init(ctx);
currentMechanism
.getInitialResponseAsync(ctx)
.addListener(
future -> {
if (!ctx.channel().isActive()) {
return;
}
if (future.isSuccess()) {
String initialResponse = (String) future.getNow();
String value =
initialResponse != null
&& !initialResponse.isEmpty()
? initialResponse
: null;
String commandArgs = candidate.getName();
if (value != null) {
commandArgs += " " + value;
}
SaslMessage authMsg =
new SaslMessage(
SaslCommandName.AUTH, commandArgs);
ctx.writeAndFlush(authMsg)
.addListener(
new WriteOperationListener<>(LOGGER));
currentState = SaslState.NEGOTIATING;
} else {
LOGGER.error(
LoggerUtils.SASL,
"Failed to get initial response for {}: {}",
candidate.getName(),
future.cause().getMessage());
tryNextMechanism(ctx);
}
});
} catch (SaslMechanismException e) {
LOGGER.warn(
LoggerUtils.SASL,
"Initialization failed for {}: {}",
candidate.getName(),
e.getMessage());
tryNextMechanism(ctx);
}
}
} else {
failAuthentication(ctx, "No compatible SASL mechanism found.");
}
}
private void failAuthentication(ChannelHandlerContext ctx, String reason) {
if (currentState == SaslState.FAILED) {
return;
}
LOGGER.error(LoggerUtils.SASL, "Authentication failed: {}", reason);
currentState = SaslState.FAILED;
if (currentMechanism != null) {
SaslMessage responseMsg = new SaslMessage(SaslCommandName.CANCEL, null);
ctx.writeAndFlush(responseMsg)
.addListener(
new WriteOperationListener<>(
LOGGER, f -> cleanupAndSignalCompletion(ctx, false)));
} else {
cleanupAndSignalCompletion(ctx, false);
}
}
private void cleanupAndSignalCompletion(ChannelHandlerContext ctx, boolean success) {
disposeCurrentMechanism();
// Fire the event to notify other handlers before removing self
if (success) {
LOGGER.info(LoggerUtils.SASL, "Authentication completed successfully.");
ctx.pipeline().fireUserEventTriggered(DBusChannelEvent.SASL_AUTH_COMPLETE);
} else {
LOGGER.error(LoggerUtils.SASL, "Authentication failed.");
ctx.pipeline().fireUserEventTriggered(DBusChannelEvent.SASL_AUTH_FAILED);
}
// Remove this handler from the pipeline as SASL phase is complete
ctx.pipeline().remove(this);
LOGGER.debug(
LoggerUtils.HANDLER_LIFECYCLE,
"Removed SASL authentication handler from pipeline as SASL phase is complete.");
}
private void disposeCurrentMechanism() {
if (currentMechanism != null) {
try {
currentMechanism.dispose();
} catch (Exception e) {
LOGGER.warn(
LoggerUtils.SASL,
"Error disposing mechanism {}: {}",
currentMechanism.getName(),
e.getMessage());
}
currentMechanism = null;
}
}
/**
* Resets the SASL handler to its initial state for reconnection. This method is called when the
* connection needs to be re-established.
*/
public void reset() {
LOGGER.debug(LoggerUtils.SASL, "Resetting SASL handler for reconnection");
// Reset state
currentState = SaslState.IDLE;
currentMechanismAttemptIndex = 0;
serverSupportedMechanisms = null;
// Dispose current mechanism
disposeCurrentMechanism();
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
LOGGER.warn(
LoggerUtils.SASL,
"Channel became inactive during authentication. Current state: {}",
currentState);
if (!EnumSet.of(SaslState.AUTHENTICATED, SaslState.FAILED).contains(currentState)) {
failAuthentication(ctx, "Channel became inactive.");
}
disposeCurrentMechanism();
ctx.fireChannelInactive();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (currentState == SaslState.FAILED) {
LOGGER.debug(
LoggerUtils.SASL,
"Ignoring exception as SASL already in FAILED state: ",
cause);
return;
}
LOGGER.error(
LoggerUtils.SASL,
"Exception in SaslAuthenticationHandler. State: {}",
currentState,
cause);
failAuthentication(ctx, "Exception caught: " + cause.getMessage());
}
private enum SaslState {
IDLE,
AWAITING_SERVER_MECHS,
NEGOTIATING,
AWAITING_BEGIN_CONFIRMATION,
AUTHENTICATED,
FAILED
}
}