InboundUtils.java
/*
* SPDX-FileCopyrightText: 2025 Lucimber UG
* SPDX-License-Identifier: Apache-2.0
*/
package com.lucimber.dbus.netty;
import com.lucimber.dbus.codec.decoder.DecoderException;
import com.lucimber.dbus.message.HeaderField;
import com.lucimber.dbus.message.MessageFlag;
import com.lucimber.dbus.message.MessageType;
import com.lucimber.dbus.type.DBusByte;
import com.lucimber.dbus.type.DBusStruct;
import com.lucimber.dbus.type.DBusType;
import com.lucimber.dbus.type.DBusVariant;
import io.netty.buffer.ByteBuf;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/** Utility class for common methods used for decoding and encoding messages. */
final class InboundUtils {
private static final int MAX_MSG_LENGTH = 0x08000000;
private static final int ZERO = 0x00000000;
private InboundUtils() {
// Utility class
}
static ByteOrder decodeByteOrder(final ByteBuf buffer) throws DecoderException {
Objects.requireNonNull(buffer, "buffer must not be null");
final byte B = 0x42;
final byte l = 0x6C;
final byte byteOrder = buffer.readByte();
if (byteOrder == B) {
return ByteOrder.BIG_ENDIAN;
} else if (byteOrder == l) {
return ByteOrder.LITTLE_ENDIAN;
} else {
throw new DecoderException("unknown byte order");
}
}
static MessageType decodeType(ByteBuf buffer) {
Objects.requireNonNull(buffer, "buffer must not be null");
byte ub = buffer.readByte();
return MessageType.fromCode(ub);
}
static Set<MessageFlag> decodeFlags(final ByteBuf buffer) {
Objects.requireNonNull(buffer, "buffer must not be null");
final byte flagsByte = buffer.readByte();
final Set<MessageFlag> flags = new HashSet<>();
final byte replyFlag = 0x01;
final byte startFlag = 0x02;
final byte authFlag = 0x04;
if ((flagsByte & replyFlag) == replyFlag) {
flags.add(MessageFlag.NO_REPLY_EXPECTED);
}
if ((flagsByte & startFlag) == startFlag) {
flags.add(MessageFlag.NO_AUTO_START);
}
if ((flagsByte & authFlag) == authFlag) {
flags.add(MessageFlag.ALLOW_INTERACTIVE_AUTHORIZATION);
}
return flags;
}
static Map<HeaderField, DBusVariant> mapHeaderFields(List<DBusStruct> headerFields)
throws DecoderException {
Objects.requireNonNull(headerFields, "headerFields must not be null");
Map<HeaderField, DBusVariant> map = new HashMap<>();
for (DBusStruct struct : headerFields) {
if (struct == null) {
throw new DecoderException("Header field struct cannot be null");
}
List<DBusType> structList = struct.getDelegate();
if (structList == null || structList.size() < 2) {
throw new DecoderException("Header field struct must contain at least 2 elements");
}
// Validate and extract header field code
DBusType firstElement = structList.get(0);
if (!(firstElement instanceof DBusByte)) {
throw new DecoderException(
"Header field code must be a DBusByte, got: "
+ (firstElement != null
? firstElement.getClass().getSimpleName()
: "null"));
}
DBusByte dbusByte = (DBusByte) firstElement;
HeaderField headerField = HeaderField.fromCode(dbusByte.getDelegate());
if (headerField == null) {
throw new DecoderException("Unknown header field code: " + dbusByte.getDelegate());
}
// Validate and extract header field variant
DBusType secondElement = structList.get(1);
if (!(secondElement instanceof DBusVariant)) {
throw new DecoderException(
"Header field value must be a DBusVariant, got: "
+ (secondElement != null
? secondElement.getClass().getSimpleName()
: "null"));
}
DBusVariant variant = (DBusVariant) secondElement;
map.put(headerField, variant);
}
return map;
}
static boolean isMessageTooLong(final int headerLength, final int bodyLength) {
// Validate input parameters to prevent integer overflow attacks
if (headerLength < 0) {
throw new IllegalArgumentException("Header length cannot be negative: " + headerLength);
}
if (bodyLength < 0) {
throw new IllegalArgumentException("Body length cannot be negative: " + bodyLength);
}
final int signature = 0x0C;
final int headerSignature = 0x08;
final int headerAlignment = 0x08;
final int headerRemainder = Integer.remainderUnsigned(headerLength, headerAlignment);
final int headerPadding = headerAlignment - headerRemainder;
// Check for potential integer overflow before performing arithmetic
int messageLength = signature + headerSignature;
if (Integer.MAX_VALUE - messageLength < headerLength) {
return true; // Would overflow, so message is too long
}
messageLength += headerLength;
if (Integer.compareUnsigned(headerRemainder, ZERO) > 0) {
if (Integer.MAX_VALUE - messageLength < headerPadding) {
return true; // Would overflow, so message is too long
}
messageLength += headerPadding;
}
if (Integer.MAX_VALUE - messageLength < bodyLength) {
return true; // Would overflow, so message is too long
}
messageLength += bodyLength;
return Integer.compareUnsigned(messageLength, MAX_MSG_LENGTH) > 0;
}
}