diff --git a/TerrariaServerAPI/TerrariaApi.Server/HandlerCollection.cs b/TerrariaServerAPI/TerrariaApi.Server/HandlerCollection.cs index 35e4b3a5..e9a5fd1a 100644 --- a/TerrariaServerAPI/TerrariaApi.Server/HandlerCollection.cs +++ b/TerrariaServerAPI/TerrariaApi.Server/HandlerCollection.cs @@ -22,6 +22,8 @@ internal HandlerCollection(string hookName) this.hookName = hookName; } + public int Count => this.registrations.Count; + public void Register(TerrariaPlugin registrator, HookHandler handler, int priority) { if (registrator == null) diff --git a/TerrariaServerAPI/TerrariaApi.Server/HookManager.cs b/TerrariaServerAPI/TerrariaApi.Server/HookManager.cs index ba14ebae..2bc1d544 100644 --- a/TerrariaServerAPI/TerrariaApi.Server/HookManager.cs +++ b/TerrariaServerAPI/TerrariaApi.Server/HookManager.cs @@ -1,5 +1,6 @@ -using Microsoft.Xna.Framework; +using Microsoft.Xna.Framework; using System; +using System.Buffers.Binary; using System.ComponentModel; using System.Diagnostics; using System.IO; @@ -16,6 +17,8 @@ namespace TerrariaApi.Server public class HookManager { + private static int netTextModuleId = -1; + public static void InitialiseAPI() { try @@ -408,7 +411,6 @@ internal bool InvokeNetGetData(ref byte msgId, MessageBuffer buffer, ref int ind // Ideally this check should occur in an OTAPI modification. if (length < 1) { - RemoteClient currentClient = Netplay.Clients[buffer.whoAmI]; Netplay.Clients[buffer.whoAmI].PendingTermination = true; return true; } @@ -419,7 +421,6 @@ internal bool InvokeNetGetData(ref byte msgId, MessageBuffer buffer, ref int ind // The length 1000 was chosen as an arbitrarily large number for all packets. It may need to be tuned later. if (length > 1000) { - RemoteClient currentClient = Netplay.Clients[buffer.whoAmI]; Netplay.Clients[buffer.whoAmI].PendingTermination = true; return true; } @@ -443,23 +444,26 @@ internal bool InvokeNetGetData(ref byte msgId, MessageBuffer buffer, ref int ind break; case PacketTypes.LoadNetModule: - using (var stream = new MemoryStream(buffer.readBuffer)) + if (!TryGetPacketSpan(buffer.readBuffer, index, length, out ReadOnlySpan modulePacket) || + modulePacket.Length < sizeof(ushort)) + { + return true; + } + + ushort moduleId = BinaryPrimitives.ReadUInt16LittleEndian(modulePacket); + // LoadNetModule is now used for sending chat text. + // Read the module ID to determine if this is the text module. + if (moduleId == GetNetTextModuleId()) { - stream.Position = index; + using (var stream = new MemoryStream(buffer.readBuffer, index, length, writable: false)) using (var reader = new BinaryReader(stream)) { - ushort moduleId = reader.ReadUInt16(); - //LoadNetModule is now used for sending chat text. - //Read the module ID to determine if this is in fact the text module - if (moduleId == Terraria.Net.NetManager.Instance.GetId()) - { - //Then deserialize the message from the reader - Terraria.Chat.ChatMessage msg = Terraria.Chat.ChatMessage.Deserialize(reader); + reader.ReadUInt16(); + Terraria.Chat.ChatMessage msg = Terraria.Chat.ChatMessage.Deserialize(reader); - if (InvokeServerChat(buffer, buffer.whoAmI, @msg.Text, msg.CommandId)) - { - return true; - } + if (InvokeServerChat(buffer, buffer.whoAmI, msg.Text, msg.CommandId)) + { + return true; } } } @@ -471,24 +475,25 @@ internal bool InvokeNetGetData(ref byte msgId, MessageBuffer buffer, ref int ind //Then the bytes get hashed, and set as ClientUUID (and gets written in DB for auto-login) //length minus 2 = 36, the length of a UUID. case PacketTypes.ClientUUID: - if (length == 38) + if (length == 38 && TryGetPacketSpan(buffer.readBuffer, index + 1, length - 2, out ReadOnlySpan uuidBytes)) { - byte[] uuid = new byte[length - 2]; - Buffer.BlockCopy(buffer.readBuffer, index + 1, uuid, 0, length - 2); - Guid guid = new Guid(); - if (Guid.TryParse(Encoding.Default.GetString(uuid, 0, uuid.Length), out guid)) + Span uuidChars = stackalloc char[36]; + int charsWritten = Encoding.ASCII.GetChars(uuidBytes, uuidChars); + if (charsWritten == 36 && Guid.TryParse(uuidChars, out _)) { - SHA512 shaM = new SHA512Managed(); - var result = shaM.ComputeHash(uuid); - Netplay.Clients[buffer.whoAmI].ClientUUID = result.Aggregate("", (s, b) => s + b.ToString("X2")); + Netplay.Clients[buffer.whoAmI].ClientUUID = Convert.ToHexString(SHA512.HashData(uuidBytes)); return true; } } - Netplay.Clients[buffer.whoAmI].ClientUUID = ""; + + Netplay.Clients[buffer.whoAmI].ClientUUID = string.Empty; return true; } } + if (this.netGetData.Count == 0) + return false; + GetDataEventArgs args = new GetDataEventArgs { MsgID = (PacketTypes)msgId, @@ -526,6 +531,27 @@ internal bool InvokeNetGreetPlayer(int who) return args.Handled; } + + private static bool TryGetPacketSpan(byte[] buffer, int offset, int length, out ReadOnlySpan span) + { + if (buffer == null || offset < 0 || length < 0 || offset > buffer.Length - length) + { + span = ReadOnlySpan.Empty; + return false; + } + + span = new ReadOnlySpan(buffer, offset, length); + return true; + } + + private static int GetNetTextModuleId() + { + if (netTextModuleId >= 0) + return netTextModuleId; + + netTextModuleId = Terraria.Net.NetManager.Instance.GetId(); + return netTextModuleId; + } #endregion #region NetSendBytes diff --git a/TerrariaServerAPI/TerrariaApi.Server/Hooking/NetHooks.cs b/TerrariaServerAPI/TerrariaApi.Server/Hooking/NetHooks.cs index a638047e..ee598591 100644 --- a/TerrariaServerAPI/TerrariaApi.Server/Hooking/NetHooks.cs +++ b/TerrariaServerAPI/TerrariaApi.Server/Hooking/NetHooks.cs @@ -1,5 +1,6 @@ -using OTAPI; +using OTAPI; using System; +using System.Collections; using Terraria; using Terraria.Net; @@ -8,6 +9,7 @@ namespace TerrariaApi.Server.Hooking; internal class NetHooks { private static HookManager _hookManager; + private static readonly BitArray knownPacketIds = BuildKnownPacketIds(); public static readonly object syncRoot = new(); @@ -128,7 +130,7 @@ static void OnReceiveData(object sender, Hooks.MessageBuffer.GetDataEventArgs e) { return; } - if (!Enum.IsDefined(typeof(PacketTypes), (int)e.PacketId)) + if ((uint)e.PacketId >= knownPacketIds.Length || !knownPacketIds.Get(e.PacketId)) { e.Result = HookResult.Cancel; } @@ -210,4 +212,25 @@ static int FindNextOpenClientSlot() } return -1; } + + private static BitArray BuildKnownPacketIds() + { + int maxPacketId = 0; + Array values = Enum.GetValues(typeof(PacketTypes)); + for (int i = 0; i < values.Length; i++) + { + int packetId = (int)values.GetValue(i); + if (packetId > maxPacketId) + maxPacketId = packetId; + } + + var map = new BitArray(maxPacketId + 1); + for (int i = 0; i < values.Length; i++) + { + int packetId = (int)values.GetValue(i); + map.Set(packetId, true); + } + + return map; + } }