lab_subrosa/encryption.py

172 lines
5.9 KiB
Python

# This module provides higher-level wrappers around cryptography,
# handling intermediate steps and selecting parameters. This module
# also allows users to work with strings instead of bytes, as users
# will not yet have learned about encoding and decoding strings.
# NOTE: DO NOT USE THESE CLASSES IN SERIOUS APPLICATIONS.
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.exceptions import InvalidSignature
from base64 import b64encode, b64decode
from pathlib import Path
# RSA key-generation parameters used by PrivateKey.generate.
# 0x10001 == 65537 (the Fermat prime F4), the conventional public exponent.
PUBLIC_EXPONENT = 0x10001
# Modulus size in bits.
KEY_SIZE = 2048
class PrivateKey:
@classmethod
def load(cls, pem):
"""Loads an existing private key.
Pem may be a bytes or a string representation of the PEM-formatted key,
or a path to a pem file.
"""
if isinstance(pem, bytes):
key = serialization.load_pem_private_key(pem, password=None)
elif isinstance(pem, str) and pem.startswith("-----BEGIN RSA PRIVATE KEY-----"):
key = serialization.load_pem_private_key(pem.encode('ascii'), password=None)
elif isinstance(pem, (str, Path)):
with open(pem, 'rb') as key_file:
key = serialization.load_pem_private_key(
key_file.read(),
password=None
)
else:
raise TypeError("PrivateKey.load requires pem bytes or a file path")
return PrivateKey(key)
@classmethod
def generate(cls):
"Generates a new private key."
key = rsa.generate_private_key(
public_exponent=PUBLIC_EXPONENT,
key_size=KEY_SIZE,
)
return PrivateKey(key)
def __init__(self, key):
if not isinstance(key, rsa.RSAPrivateKey):
err = (
"PrivateKey is initialized with a rsa.RSAPrivateKey. " +
"You probably want to use PrivateKey.load or PrivateKey.generate instead."
)
raise ValueError(err)
self.key = key
def __str__(self):
"Returns a string representation of the key in PEM format"
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode('utf8').strip()
def __repr__(self):
return "<PrivateKey>"
def save(self, filepath):
"Saves the key as a pem file"
Path(filepath).write_text(str(self))
def get_public_key(self):
"Gets the matching public key."
return PublicKey(self.key.public_key())
def sign(self, message):
"""Create an encrypted signature of the message.
Anyone with the public key can verify that the signer had the matching
private key.
"""
message_bytes = message.encode('utf8')
signature = self.key.sign(
message_bytes,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return b64encode(signature).decode('ascii')
def decrypt(self, ciphertext):
"""Decrypts a message encrypted with the matching PublicKey.
"""
ciphertext_bytes = b64decode(ciphertext.encode('ascii'))
plaintext = self.key.decrypt(
ciphertext_bytes,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
return plaintext.decode('utf8')
class PublicKey:
@classmethod
def load(self, pem):
"""Loads an existing public key.
Pem may be a bytes or a string representation of the PEM-formatted key,
or a path to a pem file.
"""
if isinstance(pem, bytes):
key = serialization.load_pem_public_key(pem)
elif isinstance(pem, str) and pem.startswith("-----BEGIN PUBLIC KEY-----"):
key = serialization.load_pem_public_key(pem.encode('ascii'))
elif isinstance(pem, (str, Path)):
with open(pem, 'rb') as key_file:
key = serialization.load_pem_public_key(key_file.read())
else:
raise TypeError("PublicKey.load requires pem bytes or a file path")
return PublicKey(key)
def __init__(self, key):
self.key = key
def __str__(self):
"Returns a string representation of the key in PEM format"
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode('utf8').strip()
def __repr__(self):
return "<PublicKey>"
def save(self, filepath):
"""Saves this key to a file in PEM format.
"""
Path(filepath).write_text(str(self))
def verify_signature(self, message, signature):
"""Verifies that `message` was signed using the matching private key.
"""
message_bytes = message.encode('utf8')
signature_bytes = b64decode(signature.encode('ascii'))
self.key.verify(
signature_bytes,
message_bytes,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
def encrypt(self, message):
"""Encrypts a message so it can be decrypted with the matching PrivateKey.
If encryption fails, your message is probably too long.
"""
message_bytes = message.encode('utf8')
ciphertext = self.key.encrypt(
message_bytes,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
return b64encode(ciphertext).decode('ascii')