Hungry Mind , Blog about everything in IT - C#, Java, C++, .NET, Windows, WinAPI, ...

How to compute IP and TCP checksum

Whenever I need to code something, the very first thing I do is google for a ready to use solution. Some time ago I had to calculate IP and TCP checksums. Every peace of code I googled was unreadable junk. I wonder is it really hard to write good looking readable code, so you can just copy and paste it?!

#pragma pack(push, 1)
 
struct IPV4_HDR
{
    uint8_t ihl : 4;
    uint8_t version : 4;
    uint8_t tos;
    uint16_t total_length;
    uint16_t identification;
 
    uint16_t fragment_offset : 13;
 
    uint16_t more_fragment : 1;
    uint16_t dont_fragment : 1;
    uint16_t reserved_zero : 1;
 
    uint8_t ttl;
    uint8_t proto;
    uint16_t header_checksum;
    uint32_t src_ip;
    uint32_t dst_ip;
};
 
struct TCP_HDR
{
    uint16_t src_port;
    uint16_t dst_port;
    uint32_t seqn;
    uint32_t ackn;
 
    uint8_t ns : 1;
    uint8_t reserved : 3;
    uint8_t data_offset : 4;
 
    uint8_t fin : 1;
    uint8_t syn : 1;
    uint8_t rst : 1;
    uint8_t psh : 1;
    uint8_t ack : 1;
    uint8_t urg : 1;
 
    uint8_t ecn : 1;
    uint8_t cwr : 1;
 
    uint16_t window;
    uint16_t checksum;
    uint16_t urgent_pointer;
};
 
#pragma pack(pop)
 
uint16_t compute_checksum(uint16_t const *data, size_t count)
{
    uint32_t sum = 0;
    while (count > 1) {
        sum += *data++;
        count -= sizeof(uint16_t);
    }
    if (count > 0) sum += ((*data) & htons(0xFF00));
    while (sum >> 16) sum = (sum & 0xFFFF) + (sum >> 16);
    sum = ~sum;
    return static_cast<uint16_t>(sum);
}
 
uint16_t compute_ip_checksum(IPV4_HDR *ip)
{
    ip->header_checksum = 0U;
    return (ip->header_checksum = compute_checksum(reinterpret_cast<uint16_t const *>(ip), ip->ihl << 2));
}
 
uint16_t compute_tcp_checksum(IPV4_HDR const *ip, uint16_t *payload)
{
    uint32_t sum = 0;
    uint16_t tcp_len = ntohs(ip->total_length) - (ip->ihl << 2);
    auto const tcp = reinterpret_cast<TCP_HDR *>(payload);
    sum += (ip->src_ip >> 16) & 0xFFFF;
    sum += (ip->src_ip) & 0xFFFF;
    sum += (ip->dst_ip >> 16) & 0xFFFF;
    sum += (ip->dst_ip) & 0xFFFF;
    sum += htons(IPPROTO_TCP);
    sum += htons(tcp_len);
 
    tcp->checksum = 0;
    while (tcp_len > 1) {
        sum += *payload++;
        tcp_len -= sizeof(uint16_t);
    }
    if (tcp_len > 0) sum += ((*payload) & htons(0xFF00));
    while (sum >> 16) sum = (sum & 0xffff) + (sum >> 16);
    sum = ~sum;
    return (tcp->checksum = static_cast<uint16_t>(sum));
}
Copyright 2007-2011 Chabster