maintainers/scripts/sha-to-sri: format

This commit is contained in:
nicoo 2024-09-13 09:35:26 +00:00
parent cc9007b639
commit c425822e17

View File

@ -26,11 +26,12 @@ class Encoding(ABC):
assert len(digest) == self.n assert len(digest) == self.n
from base64 import b64encode from base64 import b64encode
return f"{self.hashName}-{b64encode(digest).decode()}" return f"{self.hashName}-{b64encode(digest).decode()}"
@classmethod @classmethod
def all(cls, h) -> 'List[Encoding]': def all(cls, h) -> "List[Encoding]":
return [ c(h) for c in cls.__subclasses__() ] return [c(h) for c in cls.__subclasses__()]
def __init__(self, h): def __init__(self, h):
self.n = h.digest_size self.n = h.digest_size
@ -38,54 +39,57 @@ class Encoding(ABC):
@property @property
@abstractmethod @abstractmethod
def length(self) -> int: def length(self) -> int: ...
...
@property @property
def regex(self) -> str: def regex(self) -> str:
return f"[{self.alphabet}]{{{self.length}}}" return f"[{self.alphabet}]{{{self.length}}}"
@abstractmethod @abstractmethod
def decode(self, s: str) -> bytes: def decode(self, s: str) -> bytes: ...
...
class Nix32(Encoding): class Nix32(Encoding):
alphabet = "0123456789abcdfghijklmnpqrsvwxyz" alphabet = "0123456789abcdfghijklmnpqrsvwxyz"
inverted = { c: i for i, c in enumerate(alphabet) } inverted = {c: i for i, c in enumerate(alphabet)}
@property @property
def length(self): def length(self):
return 1 + (8 * self.n) // 5 return 1 + (8 * self.n) // 5
def decode(self, s: str): def decode(self, s: str):
assert len(s) == self.length assert len(s) == self.length
out = [ 0 for _ in range(self.n) ] out = [0 for _ in range(self.n)]
# TODO: Do better than a list of byte-sized ints # TODO: Do better than a list of byte-sized ints
for n, c in enumerate(reversed(s)): for n, c in enumerate(reversed(s)):
digit = self.inverted[c] digit = self.inverted[c]
i, j = divmod(5 * n, 8) i, j = divmod(5 * n, 8)
out[i] = out[i] | (digit << j) & 0xff out[i] = out[i] | (digit << j) & 0xFF
rem = digit >> (8 - j) rem = digit >> (8 - j)
if rem == 0: if rem == 0:
continue continue
elif i < self.n: elif i < self.n:
out[i+1] = rem out[i + 1] = rem
else: else:
raise ValueError(f"Invalid nix32 hash: '{s}'") raise ValueError(f"Invalid nix32 hash: '{s}'")
return bytes(out) return bytes(out)
class Hex(Encoding): class Hex(Encoding):
alphabet = "0-9A-Fa-f" alphabet = "0-9A-Fa-f"
@property @property
def length(self): def length(self):
return 2 * self.n return 2 * self.n
def decode(self, s: str): def decode(self, s: str):
from binascii import unhexlify from binascii import unhexlify
return unhexlify(s) return unhexlify(s)
class Base64(Encoding): class Base64(Encoding):
alphabet = "A-Za-z0-9+/" alphabet = "A-Za-z0-9+/"
@ -94,36 +98,39 @@ class Base64(Encoding):
"""Number of characters in data and padding.""" """Number of characters in data and padding."""
i, k = divmod(self.n, 3) i, k = divmod(self.n, 3)
return 4 * i + (0 if k == 0 else k + 1), (3 - k) % 3 return 4 * i + (0 if k == 0 else k + 1), (3 - k) % 3
@property @property
def length(self): def length(self):
return sum(self.format) return sum(self.format)
@property @property
def regex(self): def regex(self):
data, padding = self.format data, padding = self.format
return f"[{self.alphabet}]{{{data}}}={{{padding}}}" return f"[{self.alphabet}]{{{data}}}={{{padding}}}"
def decode(self, s): def decode(self, s):
from base64 import b64decode from base64 import b64decode
return b64decode(s, validate = True) return b64decode(s, validate = True)
_HASHES = (hashlib.new(n) for n in ('SHA-256', 'SHA-512')) _HASHES = (hashlib.new(n) for n in ("SHA-256", "SHA-512"))
ENCODINGS = { ENCODINGS = {h.name: Encoding.all(h) for h in _HASHES}
h.name: Encoding.all(h)
for h in _HASHES
}
RE = { RE = {
h: "|".join( h: "|".join(
(f"({h}-)?" if e.name == 'base64' else '') + (f"({h}-)?" if e.name == "base64" else "") + f"(?P<{h}_{e.name}>{e.regex})"
f"(?P<{h}_{e.name}>{e.regex})"
for e in encodings for e in encodings
) for h, encodings in ENCODINGS.items() )
for h, encodings in ENCODINGS.items()
} }
_DEF_RE = re.compile("|".join( _DEF_RE = re.compile(
"|".join(
f"(?P<{h}>{h} = (?P<{h}_quote>['\"])({re})(?P={h}_quote);)" f"(?P<{h}>{h} = (?P<{h}_quote>['\"])({re})(?P={h}_quote);)"
for h, re in RE.items() for h, re in RE.items()
)) )
)
def defToSRI(s: str) -> str: def defToSRI(s: str) -> str:
@ -153,7 +160,7 @@ def defToSRI(s: str) -> str:
@contextmanager @contextmanager
def atomicFileUpdate(target: Path): def atomicFileUpdate(target: Path):
'''Atomically replace the contents of a file. """Atomically replace the contents of a file.
Guarantees that no temporary files are left behind, and `target` is either Guarantees that no temporary files are left behind, and `target` is either
left untouched, or overwritten with new content if no exception was raised. left untouched, or overwritten with new content if no exception was raised.
@ -164,9 +171,10 @@ def atomicFileUpdate(target: Path):
Upon exiting the context, the files are closed; if no exception was Upon exiting the context, the files are closed; if no exception was
raised, `new` (atomically) replaces the `target`, otherwise it is deleted. raised, `new` (atomically) replaces the `target`, otherwise it is deleted.
''' """
# That's mostly copied from noto-emoji.py, should DRY it out # That's mostly copied from noto-emoji.py, should DRY it out
from tempfile import mkstemp from tempfile import mkstemp
fd, _p = mkstemp( fd, _p = mkstemp(
dir = target.parent, dir = target.parent,
prefix = target.name, prefix = target.name,
@ -175,7 +183,7 @@ def atomicFileUpdate(target: Path):
try: try:
with target.open() as original: with target.open() as original:
with tmpPath.open('w') as new: with tmpPath.open("w") as new:
yield (original, new) yield (original, new)
tmpPath.replace(target) tmpPath.replace(target)
@ -188,22 +196,20 @@ def atomicFileUpdate(target: Path):
def fileToSRI(p: Path): def fileToSRI(p: Path):
with atomicFileUpdate(p) as (og, new): with atomicFileUpdate(p) as (og, new):
for i, line in enumerate(og): for i, line in enumerate(og):
with log_context(line=i): with log_context(line = i):
new.write(defToSRI(line)) new.write(defToSRI(line))
_SKIP_RE = re.compile( _SKIP_RE = re.compile("(generated by)|(do not edit)", re.IGNORECASE)
"(generated by)|(do not edit)",
re.IGNORECASE
)
if __name__ == "__main__": if __name__ == "__main__":
from sys import argv, stderr from sys import argv, stderr
logger.info("Starting!") logger.info("Starting!")
for arg in argv[1:]: for arg in argv[1:]:
p = Path(arg) p = Path(arg)
with log_context(path=str(p)): with log_context(path = str(p)):
try: try:
if p.name == "yarn.nix" or p.name.find("generated") != -1: if p.name == "yarn.nix" or p.name.find("generated") != -1:
logger.warning("File looks autogenerated, skipping!") logger.warning("File looks autogenerated, skipping!")