How to compute IP and TCP checksum
Whenever I need to code something, the very first thing I do is google for a ready-to-use solution. Some time ago I had to calculate IP and TCP checksums. Every piece of code I found was unreadable junk. I wonder: is it really that hard to write good-looking, readable code that you can just copy and paste?!
#pragma pack(push, 1)

/// IPv4 header without options, packed to 20 bytes.
///
/// Multi-byte fields are stored exactly as they appear in the packet buffer;
/// the checksum routines below convert total_length with ntohs() before use.
///
/// NOTE(review): bit-field placement is implementation-defined. The
/// ihl/version order matches little-endian compilers; the
/// fragment_offset/flags split inside a 16-bit unit should be verified against
/// the wire layout before those fields are relied upon. The checksum code
/// below only reads ihl.
struct IPV4_HDR
{
    uint8_t ihl : 4;      ///< Header length in 32-bit words.
    uint8_t version : 4;
    uint8_t tos;
    uint16_t total_length; ///< Header + payload length in bytes (converted with ntohs before use).
    uint16_t identification;
    uint16_t fragment_offset : 13;
    uint16_t more_fragment : 1;
    uint16_t dont_fragment : 1;
    uint16_t reserved_zero : 1;
    uint8_t ttl;
    uint8_t proto;
    uint16_t header_checksum;
    uint32_t src_ip;
    uint32_t dst_ip;
};

/// TCP header without options, packed to 20 bytes.
struct TCP_HDR
{
    uint16_t src_port;
    uint16_t dst_port;
    uint32_t seqn;
    uint32_t ackn;
    uint8_t ns : 1;
    uint8_t reserved : 3;
    uint8_t data_offset : 4;
    uint8_t fin : 1;
    uint8_t syn : 1;
    uint8_t rst : 1;
    uint8_t psh : 1;
    uint8_t ack : 1;
    uint8_t urg : 1;
    uint8_t ecn : 1;
    uint8_t cwr : 1;
    uint16_t window;
    uint16_t checksum;
    uint16_t urgent_pointer;
};

#pragma pack(pop)

/// Adds `count` bytes starting at `data` to a running one's-complement sum.
///
/// Words are added exactly as they sit in memory (no byte swapping). A
/// trailing odd byte is treated as the first byte of a 16-bit word whose second
/// byte is zero. Only `count` bytes are ever read.
static uint64_t checksum_accumulate(uint16_t const *data, size_t count, uint64_t sum)
{
    while (count > 1)
    {
        sum += *data++;
        count -= sizeof(uint16_t);
    }

    if (count > 0)
    {
        // Read just the one remaining byte; loading a full uint16_t here would
        // touch memory past the end of the buffer.
        uint8_t const last = *reinterpret_cast<uint8_t const *>(data);
        // htons() yields the host value whose in-memory bytes are {last, 0}.
        sum += htons(static_cast<uint16_t>(last << 8));
    }

    return sum;
}

/// Folds the carries of a running sum back into 16 bits and complements it.
static uint16_t checksum_finish(uint64_t sum)
{
    while (sum >> 16)
        sum = (sum & 0xFFFF) + (sum >> 16);

    return static_cast<uint16_t>(~sum);
}

/// Computes the 16-bit one's-complement checksum of a buffer.
///
/// @param data   Start of the buffer; it is read as 16-bit words.
/// @param count  Buffer length in BYTES; may be odd or zero.
/// @return The complemented, folded sum. Because words are summed as stored,
///         the result can be written straight into a header field.
uint16_t compute_checksum(uint16_t const *data, size_t count)
{
    return checksum_finish(checksum_accumulate(data, count, 0));
}

/// Recomputes the IPv4 header checksum in place.
///
/// The header_checksum field is zeroed, the checksum is taken over
/// `ihl * 4` bytes starting at `ip`, and the result is stored back into
/// header_checksum.
///
/// @param ip  Header to update.
/// @return The value stored into ip->header_checksum.
uint16_t compute_ip_checksum(IPV4_HDR *ip)
{
    ip->header_checksum = 0U;
    return (ip->header_checksum = compute_checksum(reinterpret_cast<uint16_t const *>(ip), ip->ihl << 2));
}

/// Recomputes the TCP checksum in place.
///
/// The sum covers a pseudo-header (source address, destination address,
/// IPPROTO_TCP, TCP length) followed by the TCP segment. The segment length is
/// derived as ntohs(ip->total_length) - ip->ihl * 4.
///
/// @param ip       IPv4 header describing the segment.
/// @param payload  Start of the TCP header; its checksum field is overwritten.
/// @return The value stored into the TCP checksum field, or 0 with the buffer
///         left untouched when the IP lengths leave no room for a complete
///         TCP_HDR (the subtraction would otherwise wrap and the loop would
///         run past the packet).
uint16_t compute_tcp_checksum(IPV4_HDR const *ip, uint16_t *payload)
{
    size_t const ip_header_len = static_cast<size_t>(ip->ihl) << 2;
    size_t const total_len = ntohs(ip->total_length);

    if (total_len < ip_header_len + sizeof(TCP_HDR))
        return 0;

    uint16_t const tcp_len = static_cast<uint16_t>(total_len - ip_header_len);
    auto const tcp = reinterpret_cast<TCP_HDR *>(payload);

    // Pseudo-header. Each address is split into its two 16-bit halves as
    // stored; addition is commutative, so the order of the halves is irrelevant.
    uint64_t sum = 0;
    sum += (ip->src_ip >> 16) & 0xFFFF;
    sum += (ip->src_ip) & 0xFFFF;
    sum += (ip->dst_ip >> 16) & 0xFFFF;
    sum += (ip->dst_ip) & 0xFFFF;
    sum += htons(IPPROTO_TCP);
    sum += htons(tcp_len);

    // The checksum field must read as zero while the segment is being summed.
    tcp->checksum = 0;
    sum = checksum_accumulate(payload, tcp_len, sum);

    return (tcp->checksum = checksum_finish(sum));
}