# generated from mwc/lab_subrosa (172 lines, 5.9 KiB, Python)

# This module provides higher-level wrappers around the `cryptography`
# library, handling intermediate steps and selecting parameters. It also
# lets users work with strings instead of bytes, since they will not yet
# have learned about encoding and decoding strings.
#
# NOTE: DO NOT USE THESE CLASSES IN SERIOUS APPLICATIONS.

import os
from base64 import b64encode, b64decode
from pathlib import Path

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding

# RSA parameters used by PrivateKey.generate.
# 65537 is the conventional RSA public exponent.
PUBLIC_EXPONENT = 65537
# Modulus size in bits for newly generated keys.
KEY_SIZE = 2048

class PrivateKey:
    """An RSA private key that works with strings instead of bytes.

    Wraps a ``cryptography`` ``rsa.RSAPrivateKey``. Messages are ``str``
    (UTF-8 encoded internally); signatures and ciphertexts are exchanged as
    base64 ``str``.
    """

    @classmethod
    def load(cls, pem, password=None):
        """Loads an existing private key.

        Pem may be bytes or a string containing the PEM-formatted key, or a
        path (a str or any os.PathLike, e.g. pathlib.Path) to a pem file. A
        string is treated as PEM text when, ignoring leading whitespace, it
        starts with "-----BEGIN". Both "-----BEGIN RSA PRIVATE KEY-----" and
        "-----BEGIN PRIVATE KEY-----" headers are therefore accepted. Any
        other string is treated as a file path.

        Password (str or bytes) is only needed for encrypted keys. Keys
        written by ``save`` are unencrypted, so leave it as None for those.

        Returns an instance of ``cls``. Raises TypeError if pem is not bytes,
        a str, or a path.
        """
        if isinstance(password, str):
            password = password.encode('utf8')
        if isinstance(pem, bytes):
            pem_bytes = pem
        elif isinstance(pem, str) and pem.lstrip().startswith("-----BEGIN"):
            pem_bytes = pem.lstrip().encode('ascii')
        elif isinstance(pem, (str, os.PathLike)):
            with open(pem, 'rb') as key_file:
                pem_bytes = key_file.read()
        else:
            raise TypeError("PrivateKey.load requires pem bytes or a file path")
        key = serialization.load_pem_private_key(pem_bytes, password=password)
        # cls (not PrivateKey) so subclasses get instances of themselves.
        return cls(key)

    @classmethod
    def generate(cls):
        "Generates a new private key using PUBLIC_EXPONENT and KEY_SIZE."
        key = rsa.generate_private_key(
            public_exponent=PUBLIC_EXPONENT,
            key_size=KEY_SIZE,
        )
        return cls(key)

    def __init__(self, key):
        """Wraps an existing rsa.RSAPrivateKey.

        Raises ValueError if key is not an rsa.RSAPrivateKey.
        """
        if not isinstance(key, rsa.RSAPrivateKey):
            err = (
                "PrivateKey is initialized with a rsa.RSAPrivateKey. " +
                "You probably want to use PrivateKey.load or PrivateKey.generate instead."
            )
            raise ValueError(err)
        self.key = key

    def __str__(self):
        """Returns the key as an unencrypted PEM string.

        The format is TraditionalOpenSSL ("-----BEGIN RSA PRIVATE KEY-----"),
        with surrounding whitespace stripped.
        """
        return self.key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption()
        ).decode('utf8').strip()

    def __repr__(self):
        # Deliberately omits key material.
        return "<PrivateKey>"

    def save(self, filepath):
        """Saves the key as a pem file (the same text as ``str(self)``)."""
        # NOTE(review): the key is written unencrypted with the process's
        # default file permissions.
        Path(filepath).write_text(str(self))

    def get_public_key(self):
        "Gets the matching public key."
        return PublicKey(self.key.public_key())

    def sign(self, message):
        """Creates a signature of the message (RSA-PSS with SHA-256).

        Anyone with the public key can verify that the signer had the matching
        private key. Message may be a str (UTF-8 encoded before signing) or
        bytes. Returns the signature as a base64 string.
        """
        if isinstance(message, bytes):
            message_bytes = message
        else:
            message_bytes = message.encode('utf8')
        signature = self.key.sign(
            message_bytes,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        return b64encode(signature).decode('ascii')

    def decrypt(self, ciphertext):
        """Decrypts a message encrypted with the matching PublicKey.

        Ciphertext is the base64 string returned by PublicKey.encrypt
        (RSA-OAEP with SHA-256). Returns the plaintext decoded as UTF-8.
        """
        ciphertext_bytes = b64decode(ciphertext.encode('ascii'))
        plaintext = self.key.decrypt(
            ciphertext_bytes,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        return plaintext.decode('utf8')


class PublicKey:
    """An RSA public key that works with strings instead of bytes.

    Wraps a ``cryptography`` ``rsa.RSAPublicKey``. Messages are ``str``
    (UTF-8 encoded internally); signatures and ciphertexts are exchanged as
    base64 ``str``.
    """

    @classmethod
    def load(cls, pem):
        """Loads an existing public key.

        Pem may be bytes or a string containing the PEM-formatted key, or a
        path (a str or any os.PathLike, e.g. pathlib.Path) to a pem file. A
        string is treated as PEM text when, ignoring leading whitespace, it
        starts with "-----BEGIN". Both "-----BEGIN PUBLIC KEY-----" and
        "-----BEGIN RSA PUBLIC KEY-----" headers are therefore accepted. Any
        other string is treated as a file path.

        Returns an instance of ``cls``. Raises TypeError if pem is not bytes,
        a str, or a path.
        """
        if isinstance(pem, bytes):
            pem_bytes = pem
        elif isinstance(pem, str) and pem.lstrip().startswith("-----BEGIN"):
            pem_bytes = pem.lstrip().encode('ascii')
        elif isinstance(pem, (str, os.PathLike)):
            with open(pem, 'rb') as key_file:
                pem_bytes = key_file.read()
        else:
            raise TypeError("PublicKey.load requires pem bytes or a file path")
        key = serialization.load_pem_public_key(pem_bytes)
        # cls (not PublicKey) so subclasses get instances of themselves.
        return cls(key)

    def __init__(self, key):
        """Wraps an existing rsa.RSAPublicKey.

        Raises ValueError if key is not an rsa.RSAPublicKey, mirroring
        PrivateKey.__init__.
        """
        if not isinstance(key, rsa.RSAPublicKey):
            err = (
                "PublicKey is initialized with a rsa.RSAPublicKey. " +
                "You probably want to use PublicKey.load or " +
                "PrivateKey.get_public_key instead."
            )
            raise ValueError(err)
        self.key = key

    def __str__(self):
        """Returns the key as a PEM string.

        The format is SubjectPublicKeyInfo ("-----BEGIN PUBLIC KEY-----"),
        with surrounding whitespace stripped.
        """
        return self.key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        ).decode('utf8').strip()

    def __repr__(self):
        return "<PublicKey>"

    def save(self, filepath):
        """Saves this key to a file in PEM format (the text of ``str(self)``)."""
        Path(filepath).write_text(str(self))

    def verify_signature(self, message, signature):
        """Verifies that `message` was signed using the matching private key.

        Message may be a str (UTF-8 encoded) or bytes. Signature is the base64
        string returned by PrivateKey.sign (RSA-PSS with SHA-256).

        Returns True when the signature is valid. Raises
        cryptography.exceptions.InvalidSignature when it is not. A malformed
        base64 signature may instead raise binascii.Error from b64decode.
        """
        if isinstance(message, bytes):
            message_bytes = message
        else:
            message_bytes = message.encode('utf8')
        signature_bytes = b64decode(signature.encode('ascii'))
        # key.verify returns None on success and raises InvalidSignature on
        # failure, so reaching the return below means the signature is valid.
        self.key.verify(
            signature_bytes,
            message_bytes,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        return True

    def encrypt(self, message):
        """Encrypts a message so it can be decrypted with the matching PrivateKey.

        Message may be a str (UTF-8 encoded) or bytes. Uses RSA-OAEP with
        SHA-256 and returns the ciphertext as a base64 string. If encryption
        fails, your message is probably too long; RSA can only encrypt short
        messages directly.
        """
        if isinstance(message, bytes):
            message_bytes = message
        else:
            message_bytes = message.encode('utf8')
        ciphertext = self.key.encrypt(
            message_bytes,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        return b64encode(ciphertext).decode('ascii')