// ----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// ----------------------------------------------------------------------------
///
// Microsoft Research, Cambridge
//
using NetStack.Common;
using System;
using System.Diagnostics;
using System.Net.IP;
using Drivers.Net;
using Microsoft.Singularity.Channels;
using Microsoft.Singularity;
#if !SINGULARITY
using System.Net;
using System.Net.Sockets;
#endif
namespace NetStack.Protocols
{
/// <summary>
/// Helpers for reading, writing and checksumming UDP headers and whole
/// IPv4/UDP datagrams (RFC 768).
/// </summary>
public class UdpFormat
{
    // size of UDP header, in octets
    public const int Size = 8;

    // a UDP header
    // (all header fields are given in host endianness)
    public struct UdpHeader
    {
        public ushort srcPort;
        public ushort dstPort;
        public ushort length;   // data+header
        public ushort checksum;
    }

    // Read a big-endian (network order) 16-bit value at buffer[offset].
    private static ushort GetNet16(byte[]! buffer, int offset)
    {
        return (ushort)(((int)buffer[offset] << 8) | (int)buffer[offset + 1]);
    }

    // Write value as a big-endian (network order) 16-bit value at
    // buffer[offset].
    private static void PutNet16(byte[]! buffer, int offset, ushort value)
    {
        buffer[offset]     = (byte)(value >> 8);
        buffer[offset + 1] = (byte)(value & 0xff);
    }

    /// <summary>
    /// Construct a UDP header from the Size octets of packet that start
    /// at udpOffset.
    /// </summary>
    /// <returns>
    /// The offset just past the header. If something is wrong (the packet
    /// is too short to hold a header at udpOffset), returns udpOffset
    /// unchanged and udpHeader is left zeroed.
    /// </returns>
    public static int ReadUdpHeader(byte[]! packet,
                                    int udpOffset,
                                    out UdpHeader udpHeader)
    {
        udpHeader = new UdpHeader();

        // Honour the "return the same start value" contract instead of
        // throwing IndexOutOfRangeException on a truncated packet.
        if (udpOffset < 0 || packet.Length - udpOffset < Size) {
            return udpOffset;
        }

        udpHeader.srcPort  = GetNet16(packet, udpOffset + 0);
        udpHeader.dstPort  = GetNet16(packet, udpOffset + 2);
        udpHeader.length   = GetNet16(packet, udpOffset + 4);
        udpHeader.checksum = GetNet16(packet, udpOffset + 6);
        return udpOffset + Size;
    }

    /// <summary>
    /// Read a UDP header from buf. Returns true only if all four 16-bit
    /// fields were read successfully. Reading stops at the first failure,
    /// so on false the remaining fields are left zero.
    /// </summary>
    public static bool ReadUdpHeader(IBuffer! buf,
                                     out UdpHeader udpHeader)
    {
        udpHeader = new UdpHeader();
        return buf.ReadNet16(out udpHeader.srcPort) &&
               buf.ReadNet16(out udpHeader.dstPort) &&
               buf.ReadNet16(out udpHeader.length) &&
               buf.ReadNet16(out udpHeader.checksum);
    }

    /// <summary>
    /// Write a UDP header to buffer at offset, with a zero checksum field.
    /// The checksum must be calculated later, over all the octets of the
    /// pseudo header, UDP header and data (see SetUdpChecksum).
    /// </summary>
    /// <returns>
    /// The next place to write to, or offset unchanged if there is not
    /// enough space in buffer (or offset is negative).
    /// </returns>
    public static int WriteUdpHeader(byte[]! buffer,
                                     int offset,
                                     ref UdpHeader header)
    {
        // check we have enough packet space
        if (offset < 0 || buffer.Length - offset < Size) {
            return offset;
        }

        PutNet16(buffer, offset + 0, header.srcPort);
        PutNet16(buffer, offset + 2, header.dstPort);
        PutNet16(buffer, offset + 4, header.length);
        // checksum placeholder; filled in by SetUdpChecksum
        PutNet16(buffer, offset + 6, 0);
        return offset + Size;
    }

    /// <summary>
    /// Compute the UDP checksum and store it both in udpHeader and in the
    /// packet. The sum covers the IP pseudo header, the UDP header and the
    /// UDP data. Must be called after the IP header, UDP header (with a
    /// zero checksum field) and data have been written.
    /// </summary>
    /// <param name="packet">Packet containing IP header, UDP header and data.</param>
    /// <param name="ipOffset">Offset of the beginning of the IP header
    /// (not the UDP header!).</param>
    /// <param name="udpHeader">Header whose checksum field is updated.</param>
    public static void SetUdpChecksum(byte[]! packet,
                                      int ipOffset,
                                      ref UdpHeader udpHeader)
    {
        // sum IP pseudo header; this also yields the IP payload size
        ushort ipPayloadSize = 0;
        ushort headerSum = IPFormat.SumPseudoHeader(packet, ipOffset,
                                                    ref ipPayloadSize);
        Debug.Assert(((ushort)udpHeader.length) == ipPayloadSize);

        // now add the udp header + data; IHL (low nibble) is in 32-bit words
        int ipHeaderSize = (packet[ipOffset] & 0xf) * 4;
        int udpOffset = ipOffset + ipHeaderSize;
        Debug.Assert(packet[udpOffset + 6] == 0);
        Debug.Assert(packet[udpOffset + 7] == 0);
        ushort payloadSum = IPFormat.SumShortValues(packet, udpOffset,
                                                    ipPayloadSize);

        udpHeader.checksum = IPFormat.ComplementAndFixZeroChecksum(
            IPFormat.SumShortValues(headerSum, payloadSum));
        PutNet16(packet, udpOffset + 6, udpHeader.checksum);
    }

    /// <summary>
    /// Compute the one's-complement sum of the UDP header fields,
    /// excluding the checksum field itself.
    /// </summary>
    public static ushort SumHeader(UdpHeader udpHeader)
    {
        // Do not include existing checksum.
        int sum = udpHeader.srcPort + udpHeader.dstPort + udpHeader.length;

        // End-around carry. Three 16-bit values can sum to 0x2FFFD, and a
        // single fold can itself carry (0x1FFFF -> 0x10000), so fold twice.
        sum = (sum & 0xffff) + (sum >> 16);
        sum = (sum & 0xffff) + (sum >> 16);
        return (ushort)sum;
    }

    /// <summary>
    /// Verify the UDP checksum of a received datagram.
    /// </summary>
    /// <param name="ipHeader">IP header used for the pseudo header sum.</param>
    /// <param name="udpHeader">Received UDP header (with its checksum).</param>
    /// <param name="payload">UDP data; the Available octets are summed.</param>
    /// <returns>
    /// true if the checksum matches, or if the sender transmitted an
    /// all-zero checksum (RFC 768: no checksum was generated).
    /// </returns>
    public static bool IsChecksumValid(IPFormat.IPHeader! ipHeader,
                                       UdpHeader udpHeader,
                                       NetPacket! payload)
    {
        // RFC 768: an all zero transmitted checksum means the transmitter
        // generated no checksum, so there is nothing to verify.
        if (udpHeader.checksum == 0) {
            return true;
        }

        // Compute partial checksums of headers
        ushort checksum = IPFormat.SumPseudoHeader(ipHeader);
        checksum = IPFormat.SumShortValues(checksum,
                                           UdpFormat.SumHeader(udpHeader));

        // Checksum payload as big-endian 16-bit words
        int length = payload.Available;
        int evenLength = length & ~1;
        for (int i = 0; i < evenLength; i += 2) {
            int word = ((((int)payload.PeekAvailable(i)) << 8) +
                        (int)payload.PeekAvailable(i + 1));
            checksum = IPFormat.SumShortValues(checksum, (ushort)word);
        }
        // An odd trailing octet is padded with a zero low-order byte.
        if (evenLength != length) {
            int word = (((int)payload.PeekAvailable(evenLength)) << 8);
            checksum = IPFormat.SumShortValues(checksum, (ushort)word);
        }

        checksum = IPFormat.ComplementAndFixZeroChecksum(checksum);
        if (udpHeader.checksum != checksum) {
            DebugStub.WriteLine("Bad UDP checksum {0:x4} != {1:x4}",
                                __arglist(udpHeader.checksum, checksum));
        }
        return udpHeader.checksum == checksum;
    }

    /// <summary>
    /// Write IP and UDP headers and payload data into a byte array, then
    /// compute and store the UDP checksum.
    /// </summary>
    /// <param name="pkt">Array of bytes representing packet to be sent.</param>
    /// <param name="offset">Offset of IP Header within packet.</param>
    /// <param name="ipHeader">IP header to be written to packet.</param>
    /// <param name="udpHeader">UDP header to be written to packet.</param>
    /// <param name="payload">Payload of UDP Packet. May be the same array
    /// as pkt if the data is already in place.</param>
    /// <param name="payloadOffset">The offset of start of the payload data
    /// within the payload array.</param>
    /// <param name="payloadLength">The size of the payload data.</param>
    public static void WriteUdpPacket(byte[]! pkt,
                                      int offset,
                                      IPFormat.IPHeader! ipHeader,
                                      ref UdpHeader udpHeader,
                                      byte[] payload,
                                      int payloadOffset,
                                      int payloadLength)
    {
        int udpStart = IPFormat.WriteIPHeader(pkt, offset, ipHeader);
        int udpEnd = WriteUdpHeader(pkt, udpStart, ref udpHeader);
        // Skip the copy when there is no data, or when the data already
        // sits right after the UDP header in pkt.
        if (payloadLength > 0 &&
            (pkt != payload || udpEnd != payloadOffset)) {
            Array.Copy(payload, payloadOffset,
                       pkt, udpEnd, payloadLength);
        }
        SetUdpChecksum(pkt, offset, ref udpHeader);
    }

    /// <summary>
    /// Write IP and UDP headers and exchange-heap payload data into a byte
    /// array, then compute and store the UDP checksum.
    /// </summary>
    /// <param name="pkt">Array of bytes representing packet to be sent.</param>
    /// <param name="offset">Offset of IP Header within packet.</param>
    /// <param name="ipHeader">IP header to be written to packet.</param>
    /// <param name="udpHeader">UDP header to be written to packet.</param>
    /// <param name="payload">Payload of UDP Packet (in the ExHeap).</param>
    /// <param name="payloadOffset">The offset of start of the payload data
    /// within the payload array.</param>
    /// <param name="payloadLength">The size of the payload data.</param>
    public static void WriteUdpPacket(byte[]! pkt,
                                      int offset,
                                      IPFormat.IPHeader! ipHeader,
                                      ref UdpHeader udpHeader,
                                      byte[]! in ExHeap payload,
                                      int payloadOffset,
                                      int payloadLength)
    {
        int udpStart = IPFormat.WriteIPHeader(pkt, offset, ipHeader);
        int udpEnd = WriteUdpHeader(pkt, udpStart, ref udpHeader);
        Bitter.ToByteArray(payload, payloadOffset, payloadLength,
                           pkt, udpEnd);
        SetUdpChecksum(pkt, offset, ref udpHeader);
    }

    // Build default IP and UDP headers for a datagram carrying
    // payloadLength octets of data. totalLength is computed with
    // IPFormat.Size, i.e. assuming an IP header without options.
    private static IPFormat.IPHeader! BuildHeaders(IPv4 srcAddress,
                                                   ushort srcPort,
                                                   IPv4 dstAddress,
                                                   ushort dstPort,
                                                   int payloadLength,
                                                   out UdpHeader udpHeader)
    {
        // Both length fields are 16 bits wide; catch silent truncation.
        Debug.Assert(payloadLength >= 0 &&
                     IPFormat.Size + Size + payloadLength <= ushort.MaxValue);

        IPFormat.IPHeader! ipHeader = new IPFormat.IPHeader();
        ipHeader.SetDefaults(IPFormat.Protocol.UDP);
        ipHeader.Source = srcAddress;
        ipHeader.Destination = dstAddress;
        ipHeader.totalLength = (ushort)(IPFormat.Size + Size + payloadLength);

        udpHeader = new UdpHeader();
        udpHeader.srcPort = srcPort;
        udpHeader.dstPort = dstPort;
        udpHeader.length = (ushort)(Size + payloadLength);
        return ipHeader;
    }

    /// <summary>
    /// Build default IP/UDP headers for the given endpoints and write a
    /// complete UDP datagram (headers, data, checksum) into packet.
    /// </summary>
    public static void WriteUdpPacket(byte[]! packet,
                                      int ipOffset,
                                      IPv4 srcAddress,
                                      ushort srcPort,
                                      IPv4 dstAddress,
                                      ushort dstPort,
                                      byte[] payload,
                                      int payloadOffset,
                                      int payloadLength)
    {
        UdpHeader udpHeader;
        IPFormat.IPHeader! ipHeader = BuildHeaders(srcAddress, srcPort,
                                                   dstAddress, dstPort,
                                                   payloadLength,
                                                   out udpHeader);
        WriteUdpPacket(packet, ipOffset, ipHeader, ref udpHeader,
                       payload, payloadOffset, payloadLength);
    }

    /// <summary>
    /// Build default IP/UDP headers for the given endpoints and write a
    /// complete UDP datagram (headers, exchange-heap data, checksum) into
    /// packet.
    /// </summary>
    public static void WriteUdpPacket(byte[]! packet,
                                      int ipOffset,
                                      IPv4 srcAddress,
                                      ushort srcPort,
                                      IPv4 dstAddress,
                                      ushort dstPort,
                                      byte[]! in ExHeap payload,
                                      int payloadOffset,
                                      int payloadLength)
    {
        UdpHeader udpHeader;
        IPFormat.IPHeader! ipHeader = BuildHeaders(srcAddress, srcPort,
                                                   dstAddress, dstPort,
                                                   payloadLength,
                                                   out udpHeader);
        WriteUdpPacket(packet, ipOffset, ipHeader, ref udpHeader,
                       payload, payloadOffset, payloadLength);
    }
}
}