Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

CFLAGS := -Wall -Wextra -Wshadow -Wimplicit-fallthrough=0 -ansi -fshort-enums -fpic
CFLAGS := -Wall -Wextra -Wshadow -ansi -fshort-enums -fpic

all: build tests example

Expand Down
106 changes: 76 additions & 30 deletions src/proxy_protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ static uint8_t parse_port(const char *value, uint16_t *usport)
{
char *endptr = NULL;
uint64_t port = strtoul(value, &endptr, 10);
if (endptr == value || port > UINT16_MAX)
if (endptr == value || *endptr != '\0' || *value == '+' || port > UINT16_MAX)
{
return 0;
}
Expand Down Expand Up @@ -286,7 +286,12 @@ static uint8_t tlv_array_append_tlv_new(tlv_array_t *tlv_array, uint8_t type, ui

static uint8_t tlv_array_append_tlv_new_usascii(tlv_array_t *tlv_array, uint8_t type, uint16_t length, const void *value)
{
pp2_tlv_t *tlv = tlv_new(type, length + 1, value);
pp2_tlv_t *tlv;
if (length == UINT16_MAX)
{
return 0;
}
tlv = tlv_new(type, length + 1, value);
if (!tlv)
{
return 0;
Expand Down Expand Up @@ -409,20 +414,42 @@ uint8_t pp_info_add_ssl(pp_info_t *pp_info, const char *version, const char *cip

uint8_t pp_info_add_netns(pp_info_t *pp_info, const char *netns)
{
return tlv_array_append_tlv_new(&pp_info->pp2_info.tlv_array, PP2_TYPE_NETNS, (uint16_t) strlen(netns), netns);
size_t netns_len;
if (!netns)
{
return 0;
}
netns_len = strlen(netns);
if (netns_len > UINT16_MAX)
{
return 0;
}
return tlv_array_append_tlv_new(&pp_info->pp2_info.tlv_array, PP2_TYPE_NETNS, (uint16_t) netns_len, netns);
}

uint8_t pp_info_add_aws_vpce_id(pp_info_t *pp_info, const char *vpce_id)
{
uint16_t length = sizeof_pp2_tlv_aws_t + (uint16_t) strlen(vpce_id);
pp2_tlv_aws_t *pp2_tlv_aws = malloc(length);
size_t vpce_id_len;
uint16_t length;
pp2_tlv_aws_t *pp2_tlv_aws;
uint8_t rc;
if (!vpce_id)
{
return 0;
}
vpce_id_len = strlen(vpce_id);
if (vpce_id_len > UINT16_MAX - sizeof_pp2_tlv_aws_t)
{
return 0;
}
length = (uint16_t)(sizeof_pp2_tlv_aws_t + vpce_id_len);
pp2_tlv_aws = malloc(length);
if (!pp2_tlv_aws)
{
return 0;
}
pp2_tlv_aws->type = PP2_SUBTYPE_AWS_VPCE_ID;
memcpy(pp2_tlv_aws->value, vpce_id, strlen(vpce_id));
memcpy(pp2_tlv_aws->value, vpce_id, vpce_id_len);
rc = tlv_array_append_tlv_new(&pp_info->pp2_info.tlv_array, PP2_TYPE_AWS, length, pp2_tlv_aws);
free(pp2_tlv_aws);
return rc;
Expand Down Expand Up @@ -674,7 +701,8 @@ static uint32_t crc32c(const uint8_t* buf, uint32_t len)
static uint8_t *pp2_create_hdr(const pp_info_t *pp_info, uint16_t *pp2_hdr_len, int32_t *error)
{
proxy_hdr_v2_t proxy_hdr_v2 = { PP2_SIG, '\x21', 0, 0 };
uint16_t proxy_addr_len, len, padding_bytes, index;
uint16_t proxy_addr_len, padding_bytes, index;
uint32_t len;
proxy_addr_t proxy_addr;
const tlv_array_t *tlv_array;
uint32_t i;
Expand Down Expand Up @@ -755,7 +783,12 @@ static uint8_t *pp2_create_hdr(const pp_info_t *pp_info, uint16_t *pp2_hdr_len,
{
len += sizeof_pp2_tlv_t + sizeof(uint32_t);
}
*pp2_hdr_len = sizeof(proxy_hdr_v2_t) + len;
if (sizeof(proxy_hdr_v2_t) + len > UINT16_MAX)
{
*error = -ERR_PP2_LENGTH;
return NULL;
}
*pp2_hdr_len = (uint16_t)(sizeof(proxy_hdr_v2_t) + len);
if (pp_info->pp2_info.alignment_power > 1)
{
uint16_t alignment = 1 << pp_info->pp2_info.alignment_power;
Expand All @@ -767,13 +800,13 @@ static uint8_t *pp2_create_hdr(const pp_info_t *pp_info, uint16_t *pp2_hdr_len,
{
pp2_hdr_len_padded += alignment;
}
padding_bytes = pp2_hdr_len_padded - sizeof(proxy_hdr_v2_t) - len - sizeof_pp2_tlv_t;
padding_bytes = pp2_hdr_len_padded - (uint16_t)sizeof(proxy_hdr_v2_t) - (uint16_t)len - sizeof_pp2_tlv_t;

*pp2_hdr_len = pp2_hdr_len_padded;
len = pp2_hdr_len_padded - sizeof(proxy_hdr_v2_t);
len = pp2_hdr_len_padded - (uint32_t)sizeof(proxy_hdr_v2_t);
}
}
proxy_hdr_v2.len = htons(len);
proxy_hdr_v2.len = htons((uint16_t)len);

/* Create the PROXY protocol header */
pp2_hdr = malloc(*pp2_hdr_len);
Expand All @@ -795,7 +828,7 @@ static uint8_t *pp2_create_hdr(const pp_info_t *pp_info, uint16_t *pp2_hdr_len,
memcpy(pp2_hdr + index, tlv_array->tlvs[i], tlv_len);
index += tlv_len;
}
if (pp_info->pp2_info.alignment_power > 1)
if (pp_info->pp2_info.alignment_power > 1 && padding_bytes > 0)
{
pp2_tlv_t tlv = { 0 };
tlv.type = PP2_TYPE_NOOP;
Expand Down Expand Up @@ -884,7 +917,7 @@ static uint8_t *pp1_create_hdr(const pp_info_t *pp_info, uint16_t *pp1_hdr_len,
}
memcpy(src_addr, pp_info->src_addr, sizeof(src_addr));
memcpy(dst_addr, pp_info->dst_addr, sizeof(dst_addr));
*pp1_hdr_len = _sprintf(block, "PROXY %s %s %s %hu %hu"CRLF, fam, src_addr, dst_addr, pp_info->src_port, pp_info->dst_port);
*pp1_hdr_len = (uint16_t)_sprintf(block, "PROXY %s %s %s %hu %hu"CRLF, fam, src_addr, dst_addr, pp_info->src_port, pp_info->dst_port);
}
else
{
Expand Down Expand Up @@ -1055,12 +1088,12 @@ static int32_t pp2_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
}

/* TLVs */
/* Any TLV vector must be at least 3 bytes */
while (tlv_vectors_len > sizeof_pp2_tlv_t)
/* Any TLV vector must be at least 3 bytes (sizeof_pp2_tlv_t) */
while (tlv_vectors_len >= sizeof_pp2_tlv_t)
{
const pp2_tlv_t *pp2_tlv = (const pp2_tlv_t*) buffer;
uint16_t pp2_tlv_len = pp2_tlv->length_hi << 8 | pp2_tlv->length_lo;
uint16_t pp2_tlv_offset = sizeof_pp2_tlv_t + pp2_tlv_len;
uint32_t pp2_tlv_offset = (uint32_t) sizeof_pp2_tlv_t + pp2_tlv_len;
if (pp2_tlv_offset > tlv_vectors_len)
{
return -ERR_PP2_TLV_LENGTH;
Expand Down Expand Up @@ -1094,7 +1127,7 @@ static int32_t pp2_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
* Instead of zeroing the field in the buffer, compute CRC in 3 segments:
* before the checksum value, 4 zero bytes, after the checksum value. */
total_hdr_len = sizeof(proxy_hdr_v2_t) + len;
offset_to_chksum_value = (const uint8_t*)pp2_tlv->value - pp2_hdr;
offset_to_chksum_value = (uint32_t)((const uint8_t*)pp2_tlv->value - pp2_hdr);
after_chksum_offset = offset_to_chksum_value + sizeof(uint32_t);

crc = crc32c_continue(0xffffffff, pp2_hdr, offset_to_chksum_value);
Expand Down Expand Up @@ -1129,25 +1162,39 @@ static int32_t pp2_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
break;
case PP2_TYPE_SSL:
{
const pp2_tlv_ssl_t *pp2_tlv_ssl = (const pp2_tlv_ssl_t*) pp2_tlv->value;
const pp2_tlv_ssl_t *pp2_tlv_ssl;
uint16_t pp2_tlvs_ssl_len = 0, pp2_sub_tlv_offset = 0;
uint8_t tlv_ssl_version_found = 0;

if (pp2_tlv_len < sizeof(((pp2_tlv_ssl_t*)0)->client) + sizeof(((pp2_tlv_ssl_t*)0)->verify))
{
return -ERR_PP2_TYPE_SSL;
}

pp2_tlv_ssl = (const pp2_tlv_ssl_t*) pp2_tlv->value;

/* Set the pp2_ssl_info */
pp_info->pp2_info.pp2_ssl_info.ssl = !!(pp2_tlv_ssl->client & PP2_CLIENT_SSL);
pp_info->pp2_info.pp2_ssl_info.cert_in_connection = !!(pp2_tlv_ssl->client & PP2_CLIENT_CERT_CONN);
pp_info->pp2_info.pp2_ssl_info.cert_in_session = !!(pp2_tlv_ssl->client & PP2_CLIENT_CERT_SESS);
pp_info->pp2_info.pp2_ssl_info.cert_verified = !pp2_tlv_ssl->verify;

pp2_tlvs_ssl_len = pp2_tlv_len - sizeof(pp2_tlv_ssl->client) - sizeof(pp2_tlv_ssl->verify);
while (pp2_sub_tlv_offset < pp2_tlvs_ssl_len)
while (pp2_sub_tlv_offset + sizeof_pp2_tlv_t <= pp2_tlvs_ssl_len)
{
const pp2_tlv_t *pp2_sub_tlv_ssl = (const pp2_tlv_t*) ((const uint8_t*) pp2_tlv_ssl->sub_tlv + pp2_sub_tlv_offset);
uint16_t pp2_sub_tlv_ssl_len = pp2_sub_tlv_ssl->length_hi << 8 | pp2_sub_tlv_ssl->length_lo;
if ((uint32_t) sizeof_pp2_tlv_t + pp2_sub_tlv_ssl_len > (uint32_t)(pp2_tlvs_ssl_len - pp2_sub_tlv_offset))
{
return -ERR_PP2_TYPE_SSL;
}
if (pp2_sub_tlv_ssl->type == PP2_SUBTYPE_SSL_VERSION)
{
tlv_ssl_version_found = 1;
}
switch (pp2_sub_tlv_ssl->type)
{
case PP2_SUBTYPE_SSL_VERSION: /* US-ASCII */
tlv_ssl_version_found = 1;
case PP2_SUBTYPE_SSL_CIPHER: /* US-ASCII */
case PP2_SUBTYPE_SSL_SIG_ALG: /* US-ASCII */
case PP2_SUBTYPE_SSL_KEY_ALG: /* US-ASCII */
Expand All @@ -1171,7 +1218,7 @@ static int32_t pp2_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i

pp2_sub_tlv_offset += sizeof_pp2_tlv_t + pp2_sub_tlv_ssl_len;
}
if (pp2_sub_tlv_offset > pp2_tlvs_ssl_len || (pp_info->pp2_info.pp2_ssl_info.ssl && !tlv_ssl_version_found))
if (pp2_sub_tlv_offset != pp2_tlvs_ssl_len || (pp_info->pp2_info.pp2_ssl_info.ssl && !tlv_ssl_version_found))
{
return -ERR_PP2_TYPE_SSL;
}
Expand Down Expand Up @@ -1213,8 +1260,7 @@ static int32_t pp2_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
/* Connection is done through Private Link service */
if (pp2_tlv_azure->type == PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID) /* 32-bit number */
{
pp2_tlv_t *tlv = tlv_new(pp2_tlv->type, pp2_tlv_len, pp2_tlv->value);
if (!tlv || !tlv_array_append_tlv(&pp_info->pp2_info.tlv_array, tlv))
if (!tlv_array_append_tlv_new(&pp_info->pp2_info.tlv_array, pp2_tlv->type, pp2_tlv_len, pp2_tlv->value))
{
return -ERR_HEAP_ALLOC;
}
Expand Down Expand Up @@ -1248,15 +1294,15 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
char src_port_str[6] = { 0 };
char dst_port_str[6] = { 0 };

memcpy(block, buffer, buffer_length < PP1_MAX_LENGTH ? buffer_length : PP1_MAX_LENGTH);
memcpy(block, buffer, buffer_length < PP1_MAX_LENGTH - 1 ? buffer_length : PP1_MAX_LENGTH - 1);

block_end = strstr(block, CRLF);
if (!block_end)
{
return -ERR_PP1_CRLF;
}
block_end += strlen(CRLF);
pp1_hdr_len = block_end - block;
pp1_hdr_len = (int32_t)(block_end - block);

/* PROXY */
if (memcmp(block, "PROXY", 5))
Expand All @@ -1277,7 +1323,7 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
if (!inet_family)
{
/* Unknown connection (short form) */
if (pp1_hdr_len == 15 || !memcmp(ptr, "UNKNOWN", 7))
if (pp1_hdr_len == 15 && !memcmp(ptr, "UNKNOWN", 7))
{
pp_info->address_family = ADDR_FAMILY_UNSPEC;
pp_info->transport_protocol = TRANSPORT_PROTOCOL_UNSPEC;
Expand Down Expand Up @@ -1324,7 +1370,7 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
{
return sa_family == AF_INET ? -ERR_PP1_IPV4_SRC_IP : -ERR_PP1_IPV6_SRC_IP;
}
src_address_length = src_address_end - ptr;
src_address_length = (uint16_t)(src_address_end - ptr);
memcpy(pp_info->src_addr, ptr, src_address_length);
if (inet_pton(sa_family, pp_info->src_addr, &src_sin_addr) != 1)
{
Expand All @@ -1345,7 +1391,7 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
{
return sa_family == AF_INET ? -ERR_PP1_IPV4_DST_IP : -ERR_PP1_IPV6_DST_IP;
}
dst_address_length = dst_address_end - ptr;
dst_address_length = (uint16_t)(dst_address_end - ptr);
memcpy(pp_info->dst_addr, ptr, dst_address_length);
if (inet_pton(sa_family, pp_info->dst_addr, &dst_sin_addr) != 1)
{
Expand All @@ -1366,7 +1412,7 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
{
return -ERR_PP1_SRC_PORT;
}
src_port_length = src_port_end - ptr;
src_port_length = (uint16_t)(src_port_end - ptr);
if (src_port_length == 0 || src_port_length > 5)
{
return -ERR_PP1_SRC_PORT;
Expand All @@ -1391,7 +1437,7 @@ static int32_t pp1_parse_hdr(const uint8_t *buffer, uint32_t buffer_length, pp_i
{
return -ERR_PP1_DST_PORT;
}
dst_port_length = dst_port_end - ptr;
dst_port_length = (uint16_t)(dst_port_end - ptr);
if (dst_port_length == 0 || dst_port_length > 5)
{
return -ERR_PP1_DST_PORT;
Expand Down
Loading
Loading