""" Module for helper classes to different data structures. """ __all__ = [ 'BytesReader', ] import io import struct from typing import Any, Optional, Tuple, Type, TypeVar, Union from .. import constants _T = TypeVar('_T') class BytesReader(io.BytesIO): """ Extension of io.BytesIO that allows you to read specific data types from the stream. """ def __init__(self, *args, littleEndian: bool = True, **kwargs): super().__init__(*args, **kwargs) self.__le = bool(littleEndian) if self.__le: self.__int8_t = constants.st.ST_LE_I8 self.__int16_t = constants.st.ST_LE_I16 self.__int32_t = constants.st.ST_LE_I32 self.__int64_t = constants.st.ST_LE_I64 self.__uint8_t = constants.st.ST_LE_UI8 self.__uint16_t = constants.st.ST_LE_UI16 self.__uint32_t = constants.st.ST_LE_UI32 self.__uint64_t = constants.st.ST_LE_UI64 self.__float_t = constants.st.ST_LE_F32 self.__double_t = constants.st.ST_LE_F64 else: self.__int8_t = constants.st.ST_BE_I8 self.__int16_t = constants.st.ST_BE_I16 self.__int32_t = constants.st.ST_BE_I32 self.__int64_t = constants.st.ST_BE_I64 self.__uint8_t = constants.st.ST_BE_UI8 self.__uint16_t = constants.st.ST_BE_UI16 self.__uint32_t = constants.st.ST_BE_UI32 self.__uint64_t = constants.st.ST_BE_UI64 self.__float_t = constants.st.ST_BE_F32 self.__double_t = constants.st.ST_BE_F64 def _readDecodedString(self, encoding: str, width: int = 1) -> str: """ Reads a null terminated string with the specified character width decoded using the specified encoding. If it cannot be read or cannot be decoded then the position of the read pointer will not be changed. """ position = self.tell() try: return self.readByteString(width).decode(encoding) except Exception: while self.tell() != position: self.seek(position) raise def assertNull(self, length: int, errorMsg: Optional[str] = None) -> bytes: """ Reads the number of bytes specified and ensures they are all null. Ensures the reader returns back to the spot before attempting to read if there are not enough bytes to read. :param length: The amount of bytes to read. :param errorMsg: Optional, the error message to use if the bytes are not all null. :returns: The bytes read, if you need them. :raise IOError: Not enough bytes left to read. :raises ValueError: Assertion failed. """ # Quick return for reading 0 bytes. if length == 0: return b'' valueRead = self.tryReadBytes(length) if valueRead: if sum(valueRead) != 0: errorMsg = errorMsg or 'Bytes read were not all null.' raise ValueError(errorMsg) else: raise IOError('Not enough bytes left in buffer.') return valueRead def assertRead(self, value: bytes, errorMsg: Optional[str] = None) -> bytes: """ Reads the number of bytes and compares them to the value provided. If it does not match, throws a value error. Ensures the reader returns back to the spot before attempting to read if there are not enough bytes to read. :param value: Value to compare read bytes to. :param errorMsg: Optional, an error message to emit on mismatch. Does not apply to the buffer being too small. Allows for a format string with the keyword values "expected" and "actual", representing the value given to the function and the actual value read, respectively. :returns: The bytes read, if you need them. :raises TypeError: The value given was not bytes. :raises ValueError: Assertion failed. """ # Quick return for a value being empty. if len(value) == 0: return b'' if not isinstance(value, bytes): # pyright: ignore raise TypeError(':param value: was not bytes.') valueRead = self.tryReadBytes(len(value)) if valueRead: if valueRead != value: errorMsg = errorMsg or 'Value did not match (expected {expected}, got {actual}).' raise ValueError(errorMsg.format(expected = value, actual = valueRead)) else: raise IOError('Not enough bytes left in buffer.') return valueRead def readAnsiString(self) -> str: """ Reads a null-terminated string in ANSI format. """ return self._readDecodedString('ansi') def readAsciiString(self) -> str: """ Reads a null-terminated string in ASCII format. """ return self._readDecodedString('ascii') def readByte(self) -> int: """ Reads a signed byte from the stream. """ value = self.tryReadBytes(1) if value: return self.__int8_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readByteString(self, width: int = 1) -> bytes: """ Reads a string of bytes until it finds the null character, returning everything before that and consuming the null. Unlike other string functions, this will not decode the data into a string. :param width: tells how big a character is (in bytes), so this function can be used for strings whose characters are multiple bytes. """ if width < 1: raise ValueError('Character width must be at least 1.') position = self.tell() string = b'' null = b'\x00' * width while True: nextChar = self.read(width) if nextChar == b'': # We reached the end of the buffer without finding the null. We # need to seek back to where we started and raise an exception. while self.tell() != position: self.seek(position) raise IOError('Could not find null character.') elif nextChar == null: # If we find the null character, return what we have read. return string else: # Otherwise add the character to our string. string += nextChar def readClass(self, _class: Type[_T]) -> _T: """ Takes anything with a __SIZE__ property and a call function that takes a single bytes argument and returns the result of that function. Generally, this is intended to take a fixed-size class and return an instance of the class created with that amount of bytes. However, there is little reason to truly limit it to only that. """ if not hasattr(_class, '__SIZE__'): raise TypeError('Argument to readClass MUST have a __SIZE__ attribute.') value = self.tryReadBytes(_class.__SIZE__) if value: return _class(value) else: raise IOError('Not enough bytes left in buffer.') def readDouble(self) -> float: """ Reads a double from the stream. """ value = self.tryReadBytes(8) if value: return self.__double_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readFloat(self) -> float: """ Reads a float from the stream. """ value = self.tryReadBytes(4) if value: return self.__float_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readInt(self) -> int: """ Reads a signed int from the stream. """ value = self.tryReadBytes(4) if value: return self.__int32_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readLong(self) -> int: """ Reads a signed byte from the stream. """ value = self.tryReadBytes(8) if value: return self.__int64_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readShort(self) -> int: """ Reads a signed short from the stream. """ value = self.tryReadBytes(2) if value: return self.__int16_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readStruct(self, _struct: Union[struct.Struct, Any]) -> Tuple[Any, ...]: """ Read enough bytes for a struct and unpack it, returning the tuple of values. :param _struct: A struct or struct-like object using duck-typing. Only requires the object have an unpack method that takes a single argument and a size property to tell how many bytes to read. :raises IOError: If there are not enough bytes left to read. """ value = self.tryReadBytes(_struct.size) if value: return _struct.unpack(value) else: raise IOError('Not enough bytes left in buffer.') def readUnsignedByte(self) -> int: """ Reads an unsigned byte from the stream. """ value = self.tryReadBytes(1) if value: return self.__uint8_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readUnsignedInt(self) -> int: """ Reads an unsigned int from the stream. """ value = self.tryReadBytes(4) if value: return self.__uint32_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readUnsignedLong(self) -> int: """ Reads an unsigned long from the stream. """ value = self.tryReadBytes(8) if value: return self.__uint64_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readUnsignedShort(self) -> int: """ Reads an unsigned short from the stream. """ value = self.tryReadBytes(2) if value: return self.__uint16_t.unpack(value)[0] else: raise IOError('Not enough bytes left in buffer.') def readUtf8String(self) -> str: """ Reads a null-terminated string in UTF-8 format. """ return self._readDecodedString('utf-8') def readUtf16String(self) -> str: """ Reads a null terminated string in UTF-16 format using the endienness of the reader to determine which one to use. """ return self._readDecodedString('utf-16-le' if self.__le else 'utf-16-be', 2) def readUtf32String(self) -> str: """ Reads a null terminated string in UTF-32 format using the endienness of the reader to determine which one to use. """ return self._readDecodedString('utf-32-le' if self.__le else 'utf-32-be', 4) def tryReadBytes(self, size: int) -> bytes: """ Tries to read the specified number of bytes, returning b'' if not possible. Will only change the position of the read pointer if reading was possible. """ if size < 1: raise ValueError(':param size: must be at least 1.') position = self.tell() value = self.read(size) if len(value) == size: return value # Ensure that we seek back to where we started if we could not read. while self.tell() != position: self.seek(position) return b''
Memory