# Copyright 2018-2019 VMware, Inc.
# All rights reserved. -- VMware Confidential

"""GUID Partition Table (GPT) Disk Layout.

Set of classes and functions to set up a GPT partition table on a physical
block device or in a raw (DD-style) file.

TODO:
  - Support partition attributes (see UEFI spec, Table 20)
  - Add proper support for 4k-block devices
  - Check partition alignment
  - Support passing partition size instead of (start, end) LBA
"""
from binascii import crc32
from random import getrandbits
from struct import pack, unpack_from
from uuid import UUID, uuid4

from systemStorage import *

# The UEFI spec requires at least 16,384 bytes to be reserved for the GPT
# Partition Entry Array, i.e. 128 entries of 128 bytes each.
NUM_GPT_ENTRIES = 128
SIZEOF_GPT_ENTRY = 128

# 72 matches the width of the partition name field ('72s') in
# GptPartition._FMT.
PARTITION_NAME_MAXLEN = 72

# Size in bytes of the boot code block placed at the start of the Protective
# MBR (see Gpt.mbrBootCode).
SIZEOF_MBR_BOOT_CODE = 424

# Partition type GUID (lowercase string form) -> FS_TYPE_* constant.
GPT_FS_TYPES = {
   "ebd0a0a2-b9e5-4433-87c0-68b6b72699c7": FS_TYPE_VFAT,
   "c12a7328-f81f-11d2-ba4b-00a0c93ec93b": FS_TYPE_UEFI_SYSTEM,
   "aa31e02a-400f-11db-9590-000c2911d1b8": FS_TYPE_VMFS,
   "4eb2ea39-7855-4790-a79e-fae495e21f8d": FS_TYPE_VMFS_L,
   "9d275380-40ad-11db-bf97-000c2911d1b8": FS_TYPE_VMKCORE,
}


class GptPartition(object):
   """A single entry of the GPT Partition Entry Array.

   Attributes:
      num:    1-based partition number (index in the entry array + 1).
      fsType: filesystem type; one of the values of GPT_FS_TYPES, or
              FS_TYPE_UNKNOWN when unpacked from an unrecognized type GUID.
      start:  starting LBA of the partition.
      end:    ending LBA of the partition.
      label:  human-readable partition name.
      guid:   unique partition GUID (uuid.UUID).
   """
   _FMT = ('<'    # Little endian
           '16s'  # Partition type GUID
           '16s'  # Partition GUID
           'Q'    # Starting LBA
           'Q'    # Ending LBA
           '8x'   # Attributes, UEFI-reserved
           '72s') # Human-readable partition name

   def __init__(self, partNum, fsType, start, end, label, guid=None):
      """Create a partition entry.

      A random partition GUID is generated when @guid is None.

      Raises ValueError if @partNum is out of range, if @start is greater than
      @end, or if @label does not fit in the on-disk name field once encoded
      as UTF-16LE.
      """
      super().__init__()

      if not (0 < partNum <= NUM_GPT_ENTRIES):
         raise ValueError("%u: invalid partition number (min: 1, max: %u)" %
                          (partNum, NUM_GPT_ENTRIES))

      if start > end:
         raise ValueError("invalid partition boundaries (start: %u > end: %u)" %
                          (start, end))

      # The on-disk name field holds PARTITION_NAME_MAXLEN *bytes* of UTF-16LE
      # text, so the limit must be checked against the encoded size and not
      # the number of characters. Otherwise struct.pack() silently truncates
      # the name when the entry is packed.
      if len(label.encode("utf_16_le")) > PARTITION_NAME_MAXLEN:
         raise ValueError("%s: partition name is too long (maxLen: %u)" %
                          (label, PARTITION_NAME_MAXLEN // 2))

      self.num = partNum
      self.fsType = fsType
      self.start = start
      self.end = end
      self.label = label
      self.guid = uuid4() if guid is None else guid

   @classmethod
   def unpack(cls, partNum, buf):
      """Unpack a GPT Partition Array Entry from memory.

      This method returns a GptPartition instance, initialized with the content
      of the partition entry found at the beginning of @buf. @partNum is the
      1-based partition number to assign to the new instance.

      A partition type GUID that is not listed in GPT_FS_TYPES is reported as
      FS_TYPE_UNKNOWN.
      """
      fsGuid, guid, startLba, endLba, label = unpack_from(cls._FMT, buf)

      fsGuid = UUID(bytes_le=fsGuid)
      fsType = GPT_FS_TYPES.get(str(fsGuid).lower(), FS_TYPE_UNKNOWN)

      # The name is a NUL-terminated string: keep what precedes the first NUL
      # and ignore whatever follows the terminator in the fixed-size field.
      name = label.decode("utf_16_le").split('\0', 1)[0].strip('\t\n\r')

      return cls(partNum, fsType, startLba, endLba, name,
                 guid=UUID(bytes_le=guid))

   def pack(self):
      """Return a GPT partition array entry for this partition.

      The entry is SIZEOF_GPT_ENTRY bytes long. Raises NotImplementedError when
      the partition's filesystem type has no GUID in GPT_FS_TYPES.
      """
      # GPT_FS_TYPES maps GUID -> type, so search it backwards for our type.
      fsGuid = next((typeGuid for typeGuid, fsType in GPT_FS_TYPES.items()
                     if fsType == self.fsType), None)
      if fsGuid is None:
         raise NotImplementedError("%s: filesystem is not supported" %
                                   self.fsType)

      entry = pack(self._FMT, UUID(fsGuid).bytes_le, self.guid.bytes_le,
                   self.start, self.end, self.label.encode("utf_16_le"))
      assert len(entry) == SIZEOF_GPT_ENTRY, \
         ("bad GPT partition array entry (size=%u != %u)" %
          (len(entry), SIZEOF_GPT_ENTRY))
      return entry


class GptHeader(object):
   """In-memory form of a GPT header, convertible to and from its on-disk
   binary layout.
   """
   GPT_HEADER_SIGNATURE = b'EFI PART' # Per UEFI spec
   GPT_HEADER_REVISION = 0x00010000   # Per UEFI spec
   GPT_HEADER_SIZE = 92               # Per UEFI spec

   _FMT = ('<'    # all fields are little endian
           '8s'   # signature (8 ASCII bytes)
           'I'    # revision
           'I'    # header size
           'I'    # CRC32 of the header
           '4x'   # reserved by UEFI, must be 0
           'QQ'   # MyLBA, AlternateLBA
           'QQ'   # FirstUsableLBA, LastUsableLBA
           '16s'  # disk GUID
           'Q'    # LBA of the partition array
           'I'    # number of partition entries
           'I'    # size of one partition entry
           'I')   # CRC32 of the partition array

   def __init__(self, myLba, altLba, firstUsableLba, lastUsableLba, diskGuid,
                partitionArrayLba, numPartitions, sizeOfPartitionEntry,
                partitionArrayCrc):
      """Store the header fields; @diskGuid is expected to be a uuid.UUID.
      """
      self.myLba, self.altLba = myLba, altLba
      self.firstUsableLba, self.lastUsableLba = firstUsableLba, lastUsableLba
      self.diskGuid = diskGuid
      self.partitionArrayLba = partitionArrayLba
      self.numPartitions = numPartitions
      self.sizeOfPartitionEntry = sizeOfPartitionEntry
      self.partitionArrayCrc = partitionArrayCrc

   @classmethod
   def unpack(cls, buf):
      """Build a GptHeader from the raw header found at the start of @buf.

      Raises ValueError when the signature is wrong, and RuntimeError when the
      header size is unexpected or the header CRC does not match.
      """
      (signature, revision, headerSize, storedCrc, myLba, altLba,
       firstUsableLba, lastUsableLba, rawDiskGuid, partitionArrayLba,
       numPartitions, sizeOfPartitionEntry,
       partitionArrayCrc) = unpack_from(cls._FMT, buf)

      if signature != GptHeader.GPT_HEADER_SIGNATURE:
         raise ValueError("invalid GPT header signature")

      if headerSize != GptHeader.GPT_HEADER_SIZE:
         raise RuntimeError("invalid GPT header size (%u)" % headerSize)

      # The CRC is computed over the header with its own CRC field zeroed.
      zeroedHeader = pack(cls._FMT, signature, revision, headerSize, 0,
                          myLba, altLba, firstUsableLba, lastUsableLba,
                          rawDiskGuid, partitionArrayLba, numPartitions,
                          sizeOfPartitionEntry, partitionArrayCrc)
      if crc32(zeroedHeader) != storedCrc:
         raise RuntimeError("GPT header is corrupted (CRC mismatch)")

      return cls(myLba, altLba, firstUsableLba, lastUsableLba,
                 UUID(bytes_le=rawDiskGuid), partitionArrayLba, numPartitions,
                 sizeOfPartitionEntry, partitionArrayCrc)

   def pack(self):
      """Serialize this header (primary or backup) to its on-disk form.

      The returned bytes object carries the CRC32 of the header in its CRC
      field.
      """
      fields = (GptHeader.GPT_HEADER_SIGNATURE,
                GptHeader.GPT_HEADER_REVISION,
                GptHeader.GPT_HEADER_SIZE,
                0,  # CRC placeholder, patched below
                self.myLba, self.altLba,
                self.firstUsableLba, self.lastUsableLba,
                self.diskGuid.bytes_le,
                self.partitionArrayLba, self.numPartitions,
                self.sizeOfPartitionEntry, self.partitionArrayCrc)
      raw = pack(self._FMT, *fields)

      # The CRC field occupies bytes 16..19 of the header.
      checksum = pack("<I", crc32(raw))
      return raw[:16] + checksum + raw[20:]


class Gpt(object):
   """Object that represents the GPT metadata for an entire disk.

   The object holds the disk geometry, the disk GUID, the partition entries
   (a dict indexed by partition number) and the optional Protective MBR boot
   code. It can be filled from a device with scan() and written back with
   sync(); actual I/O is delegated to caller-provided callbacks.
   """

   def __init__(self, diskSizeInLba, sizeOfLba=512):
      """Create an empty GPT for a disk of @diskSizeInLba blocks of @sizeOfLba
      bytes each. A random disk GUID is generated.
      """
      self.diskSizeInLba = diskSizeInLba
      self.sizeOfLba = sizeOfLba
      self._sizeOfPartitionEntryArray = self.bytesToLba(NUM_GPT_ENTRIES *
                                                        SIZEOF_GPT_ENTRY)
      self.guid = uuid4()
      self.partitions = {}
      self._bootPart = None
      self._mbrBootCode = None

   def bytesToLba(self, byteOffset):
      """Convert a size/offset from bytes to LBA.

      @byteOffset must be a multiple of the sector size.
      """
      assert byteOffset % self.sizeOfLba == 0, \
         ("%u: invalid sector offset cannot be converted to LBA "
          "(not a multiple of sector size %u)" % (byteOffset, self.sizeOfLba))
      return byteOffset // self.sizeOfLba

   @property
   def primaryHeaderLba(self):
      """LBA of the primary GPT header (LBA 1 per UEFI spec).
      """
      return 1

   @property
   def primaryPartitionArrayLba(self):
      """LBA of the GPT partition entry array (follows the primary GPT header).
      """
      return self.primaryHeaderLba + 1

   @property
   def secondaryHeaderLba(self):
      """LBA of the secondary GPT header (last LBA of the device per UEFI spec).
      """
      return self.diskSizeInLba - 1

   @property
   def secondaryPartitionArrayLba(self):
      """LBA of the GPT partition entry array (precedes the secondary GPT
      header).
      """
      return self.secondaryHeaderLba - self._sizeOfPartitionEntryArray

   @property
   def firstUsableLba(self):
      """LBA of the first usable LBA.
      """
      assert self.sizeOfLba in (512, 4096), ("%u: invalid sector size (must be "
                                             "512 or 4096)" % self.sizeOfLba)

      if self.sizeOfLba == 512:
         # Per UEFI spec, if the block size is 512, the First Usable LBA must be
         # greater than or equal to 34 (allowing 1 block for the Protective MBR,
         # 1 block for the Partition Table Header, and 32 blocks for the GPT
         # Partition Entry Array)
         return 34
      else:
         # Per UEFI spec, if the logical block size is 4096, the First Usable
         # LBA must be greater than or equal to 6 (allowing 1 block for the
         # Protective MBR, 1 block for the GPT Header, and 4 blocks for the GPT
         # Partition Entry Array).
         return 6

   @property
   def lastUsableLba(self):
      """LBA of the last usable LBA (precedes the secondary partition array).
      """
      return self.secondaryPartitionArrayLba - 1

   def _packHeader(self, backup=False):
      """Construct a GPT header (primary or backup).

      The header is returned as raw bytes, zero-padded to one full sector.
      """
      if backup:
         myLba = self.secondaryHeaderLba
         altLba = self.primaryHeaderLba
         partitionArrayLba = self.secondaryPartitionArrayLba
      else:
         myLba = self.primaryHeaderLba
         altLba = self.secondaryHeaderLba
         partitionArrayLba = self.primaryPartitionArrayLba

      partitionArrayCrc = crc32(self.partitionArray)

      hdr = GptHeader(myLba, altLba, self.firstUsableLba, self.lastUsableLba,
                      self.guid, partitionArrayLba, NUM_GPT_ENTRIES,
                      SIZEOF_GPT_ENTRY, partitionArrayCrc)

      header = hdr.pack()
      header += bytes(self.sizeOfLba - len(header))
      # Both values must be passed as a single tuple to the '%' operator.
      assert len(header) == self.sizeOfLba, ("bad GPT header (size=%u != %u)"
                                             % (len(header), self.sizeOfLba))
      return header

   @property
   def primaryGptHeader(self):
      """Primary GPT header as raw bytes.
      """
      return self._packHeader()

   @property
   def secondaryGptHeader(self):
      """Secondary (backup) GPT header as raw bytes.
      """
      return self._packHeader(backup=True)

   @property
   def partitionArray(self):
      """GPT partition entry array as raw bytes.

      Unused entries are left zeroed; each partition is stored at the slot
      matching its partition number.
      """
      size = NUM_GPT_ENTRIES * SIZEOF_GPT_ENTRY
      array = bytearray(size)

      for part in self.partitions.values():
         idx = SIZEOF_GPT_ENTRY * (part.num - 1)
         array[idx:idx + SIZEOF_GPT_ENTRY] = part.pack()

      assert len(array) == size, \
         ("bad GPT partition array (size=%u != %u)" % (len(array), size))
      return bytes(array)

   def _setPartition(self, part):
      """Internal helper to safely add/modify a partition entry in the GPT.

      The entry is installed, then the whole GPT is re-checked. If the check
      fails, the previous state is restored and the error is re-raised.
      """
      backup = self.partitions.get(part.num)
      self.partitions[part.num] = part
      try:
         self.check()
      except Exception:
         if backup is None:
            del self.partitions[part.num]
         else:
            self.partitions[part.num] = backup
         raise

   def setPartition(self, partNum, fsType, start, end, label, guid=None):
      """Add/modify a partition entry in the GPT.

      Raises ValueError if the entry is invalid or makes the GPT inconsistent;
      in that case the GPT is left unchanged.
      """
      part = GptPartition(partNum, fsType, start, end, label, guid=guid)
      self._setPartition(part)

   def scan(self, readLbaFn):
      """Read a block device's GPT using the given @readLbaFn callback.

      @readLbaFn is called as readLbaFn(lba, numLba) and must return the raw
      content of @numLba blocks starting at @lba.

      Raises ValueError/RuntimeError if the primary GPT header or the partition
      array is invalid or corrupted.

      NOTE(review): the disk GUID found on disk (header.diskGuid) is not copied
      into self.guid, so a later sync() writes the GUID generated in
      __init__() - confirm this is intended.
      """
      lba = readLbaFn(self.primaryHeaderLba, 1)
      header = GptHeader.unpack(lba)
      if header.myLba != self.primaryHeaderLba:
         raise RuntimeError("invalid primary GPT header (myLba: %u)" %
                            header.myLba)

      # Honor the entry size advertised by the header: entries may be larger
      # than SIZEOF_GPT_ENTRY, in which case only the leading SIZEOF_GPT_ENTRY
      # bytes of each entry are parsed.
      entrySize = header.sizeOfPartitionEntry
      if entrySize < SIZEOF_GPT_ENTRY:
         raise RuntimeError("invalid GPT partition entry size (%u)" %
                            entrySize)

      entriesSize = header.numPartitions * entrySize
      # Round up: the array does not have to end on a sector boundary.
      sizeOfPartitionArray = ((entriesSize + self.sizeOfLba - 1) //
                              self.sizeOfLba)
      partitionArray = readLbaFn(header.partitionArrayLba, sizeOfPartitionArray)
      if header.partitionArrayCrc != crc32(partitionArray[0:entriesSize]):
         raise RuntimeError("GPT partition array is corrupted (CRC mismatch)")

      for i in range(header.numPartitions):
         offset = i * entrySize
         entry = partitionArray[offset:offset + SIZEOF_GPT_ENTRY]
         if entry[:16] == bytes(16):
            # skip unused partition (all-zero partition type GUID)
            continue

         part = GptPartition.unpack(i + 1, entry)
         self._setPartition(part)

   def sync(self, writeLbaFn, skipBackupGpt=False):
      """Write the GPT to disk.

      @writeLbaFn is called as writeLbaFn(lba, data) to write the raw @data
      starting at block @lba. The backup GPT is neither checked nor written
      when @skipBackupGpt is True.
      """
      self.check(skipBackupGpt=skipBackupGpt)

      # The kernel expects the request to write to the GPT area as a single
      # scatter gather chunk, otherwise it returns a read-only error.
      # Combine the GPT blocks to be a single unit before calling writeLbaFn
      # so that the kernel has the full context of the write request.
      gptBlocks = self.protectiveMbr + self.primaryGptHeader + self.partitionArray
      size = self.sizeOfLba * 2 + NUM_GPT_ENTRIES * SIZEOF_GPT_ENTRY
      assert len(gptBlocks) == size, ("invalid GPT sync request (size=%u != %u)"
                                      % (len(gptBlocks), size))
      writeLbaFn(0, gptBlocks)

      if not skipBackupGpt:
         writeLbaFn(self.secondaryPartitionArrayLba,
                    self.partitionArray + self.secondaryGptHeader)

   def setBootPartition(self, partNum):
      """Mark the given partition as 'bootable'.

      Raises ValueError if there is no partition numbered @partNum.
      """
      if partNum not in self.partitions:
         raise ValueError("%u: invalid boot partition number "
                          "(no such partition)" % partNum)

      self._bootPart = partNum

   @property
   def mbrBootCode(self):
      """Protective MBR boot code as raw bytes.

      Defaults to SIZEOF_MBR_BOOT_CODE zero bytes when no boot code was set.
      """
      if self._mbrBootCode is None:
         self._mbrBootCode = bytes(SIZEOF_MBR_BOOT_CODE)
      return self._mbrBootCode

   @mbrBootCode.setter
   def mbrBootCode(self, bootCode):
      """Helper function to set the Protective MBR boot code.

      @bootCode must be exactly SIZEOF_MBR_BOOT_CODE bytes long.
      """
      assert len(bootCode) == SIZEOF_MBR_BOOT_CODE, \
         ("bad GPT boot code block (size=%u != %u)" %
          (len(bootCode), SIZEOF_MBR_BOOT_CODE))
      self._mbrBootCode = bootCode

   @property
   def protectiveMbr(self):
      """The Protective MBR as an array of bytes.

      The Protective MBR precedes the GUID Partition Table Header to maintain
      compatibility with existing tools that do not understand GPT partition
      structures.

      The boot code is followed by the GUID of the boot partition (16 zero
      bytes when no boot partition is set), then by the MBR disk signature and
      partition records. A new random disk signature is generated each time
      this property is read.
      """
      SIZEOF_MBR = 512

      # NOTE(review): the UEFI spec defines the protective partition's size as
      # the disk size minus one (capped at 0xFFFFFFFF); this uses the Last
      # Usable LBA instead - confirm this is intended.
      diskSizeInLba = min(self.lastUsableLba, 0xFFFFFFFF)
      mbrGuid = getrandbits(32)

      # Work on an immutable copy: if the boot code was provided as a
      # bytearray, '+=' would otherwise grow the stored boot code in place.
      mbr = bytes(self.mbrBootCode)
      if self._bootPart is None:
         mbr += bytes(16)
      else:
         mbr += self.partitions[self._bootPart].guid.bytes_le

      # NOTE(review): the two bytes following the disk signature are packed as
      # 0x1d, 0x9a although the format comment describes them as zero.
      mbr += pack('<'    # little endian
                  'I'    # Unique MBR disk Signature
                  '2B'   # unused, set to zero
                  '8BII' # partition Record protecting the entire disk
                  '48x'  # three partition records each set to zero.
                  'BB',  # MBR signature
                  mbrGuid, 0x1d, 0x9a,
                  0x00, 0x00, 0x02, 0x00, 0xee, 0xfe, 0xff, 0xff,
                  0x00000001, diskSizeInLba, 0x55, 0xaa)

      assert len(mbr) == SIZEOF_MBR, ("bad protective MBR block (size=%u != %u)"
                                      % (len(mbr), SIZEOF_MBR))

      # Pad MBR to fill up to disk sector size
      mbr += bytes(self.sizeOfLba - len(mbr))
      return mbr

   def _checkBackupGpt(self):
      """Sanity check of the backup GPT.

      Per UEFI spec, the backup GPT Partition Entry Array must be located
      after the Last Usable LBA and end before the backup GPT Header.

      Raises ValueError if the backup GPT is misplaced, or if the disk is too
      small for the backup array to fit after the primary GPT.
      """
      start = self.secondaryPartitionArrayLba
      end = start + self._sizeOfPartitionEntryArray - 1
      if start < self.firstUsableLba:
         # The array location is derived from the disk size: on a disk that is
         # too small it would overlap the primary GPT (or even start at a
         # negative LBA).
         raise ValueError("disk is too small to hold the backup GPT "
                          "(size: %u LBA)" % self.diskSizeInLba)
      if start <= self.lastUsableLba:
         raise ValueError("misplaced secondary partition array "
                          "(lower than Last Usable LBA)")
      if end >= self.secondaryHeaderLba:
         raise ValueError("misplaced secondary partition array "
                          "(higher than secondary GPT header)")

   def check(self, skipBackupGpt=False):
      """Sanity check the GPT metadata.

      This function ensures that the GPT metadata is consistent and can be
      safely written to disk. The backup GPT location is not verified when
      @skipBackupGpt is True.

      Raises ValueError on the first inconsistency found.
      """
      # Per UEFI spec, the primary GPT Partition Entry Array must be located
      # after the primary GPT Header and end before the First Usable LBA.
      start = self.primaryPartitionArrayLba
      end = start + self._sizeOfPartitionEntryArray - 1
      if start <= self.primaryHeaderLba:
         raise ValueError("misplaced primary partition array "
                          "(lower than primary GPT header)")
      if end >= self.firstUsableLba:
         raise ValueError("misplaced primary partition array "
                          "(higher than First Usable LBA)")

      if not skipBackupGpt:
         self._checkBackupGpt()

      partitions = sorted(self.partitions.values(), key=lambda p: p.start)
      if partitions:
         # Per UEFI spec, all partitions must be contained within the First and
         # Last Usable LBA.
         if partitions[0].start < self.firstUsableLba:
            raise ValueError("invalid partition starting LBA "
                             "(start: %u < firstUsableLba: %u)" %
                             (partitions[0].start, self.firstUsableLba))
         if partitions[-1].end > self.lastUsableLba:
            raise ValueError("invalid partition ending LBA "
                             "(end: %u > lastUsableLba: %u)" %
                             (partitions[-1].end, self.lastUsableLba))

         # Per UEFI spec, partitions may not overlap. Partitions are visited by
         # increasing starting LBA; each must start after the previous one ends.
         end = self.firstUsableLba - 1
         for partition in partitions:
            if partition.start <= end:
               raise ValueError("invalid GPT (partition overlap)")
            end = partition.end

   def isBackupGptHealthy(self, readLbaFn):
      """Check if the backup GPT is present and its metadata is correct.

      @readLbaFn is called as readLbaFn(lba, numLba). Returns False when the
      backup header cannot be unpacked or does not match the primary header,
      True otherwise. An invalid primary header is reported by raising.
      """
      lba = readLbaFn(self.primaryHeaderLba, 1)
      primaryHeader = GptHeader.unpack(lba)

      lba = readLbaFn(self.secondaryHeaderLba, 1)
      try:
         secondaryHeader = GptHeader.unpack(lba)
      except (ValueError, RuntimeError):
         return False

      return (secondaryHeader.diskGuid == primaryHeader.diskGuid           and
              secondaryHeader.numPartitions == primaryHeader.numPartitions and
              secondaryHeader.altLba == self.primaryHeaderLba              and
              secondaryHeader.myLba == self.secondaryHeaderLba)
