Skip to main content
deleted 101 characters in body; edited title
Source Link
toolic
  • 16.8k
  • 6
  • 30
  • 224

ECDH implementation in pythonPython (part 3)

After the first and second part, the primary feedback I got was to rename a lot of my variables as I used the same name for different local variables which is really confusing. I tried to hotfix that, as well as adding many comments to the new parts of my code. Notice that the code contains a lot of mathematic formulas which are not easy to understand,understand; don't blame me for that since I didn't invent these formulas. I think my changes on the code are big enough to open a new question instead of editing the previous. If If you don't understand what a specific code part does, please make sure if that's related rather to the code than to the math before you're forcing yourself to an answer which is as helpful as "The code is bad".

ECDH implementation in python (part 3)

After the first and second part, the primary feedback I got was to rename a lot of my variables as I used the same name for different local variables which is really confusing. I tried to hotfix that, as well as adding many comments to the new parts of my code. Notice that the code contains a lot of mathematic formulas which are not easy to understand, don't blame me for that since I didn't invent these formulas. I think my changes on the code are big enough to open a new question instead of editing the previous. If you don't understand what a specific code part does, please make sure if that's related rather to the code than to the math before you're forcing yourself to an answer which is as helpful as "The code is bad".

ECDH implementation in Python (part 3)

After the first and second part, the primary feedback I got was to rename a lot of my variables as I used the same name for different local variables which is really confusing. I tried to hotfix that, as well as adding many comments to the new parts of my code. Notice that the code contains a lot of mathematic formulas which are not easy to understand; don't blame me for that since I didn't invent these formulas. If you don't understand what a specific code part does, please make sure if that's related rather to the code than to the math before you're forcing yourself to an answer which is as helpful as "The code is bad".

Tweeted twitter.com/StackCodeReview/status/1058146787708538880
fixed typo, clearified comment
Source Link
Aemyl
  • 860
  • 1
  • 8
  • 18
# coding: utf-8

MERSENNE_EXPONENTS = [
    2, 3, 5, 7, 13, 17, 19, 31, 61, 89,
    107, 127, 521, 607, 1279, 2203, 2281,
    3217, 4253, 4423
]


def ext_euclidean(a, b):
    """extended euclidean algorithm.
    returns gcd(a, b) as well as two numbers
    u and v, such that a*u + b*v = gcd(a, b)
    if gcd(a, b) is 1, u is the
    multiplicative inverse of a (mod b)
    """
    t = u = 1
    s = v = 0
    while b:
        a, (q, b) = b, divmod(a, b)
        u, s = s, u - q*s
        v, t = t, v - q*s
    return a, u, v  # a is now the gcdgreatest common divisor


def legendre(x, p):
    """calculates the legendre symbol.
    p has to be an odd prime.
    returns 1 if x is a quadratic residue (mod p)
    returns -1 if x is quadratic non-residue (mod p)
    returns 0 if x = 0 (mod p)
    """
    return pow(x, (p-1) // 2, p)


def W(n, r, x, modulus):
    """Calculates recursive defined numbers
    which are needed to calculate the modular
    square root of x modulo modulus if modulus = 1 (mod 4)
    """
    if n == 1:
        inv = ext_euclidean(x, modulus)[1]
        return (r*r*inv - 2) % modulus
    if n % 2 == 0:
        w0 = W((n-1) // 2, r, x, modulus)
        w1 = W((n+1) // 2, r, x, modulus)
        return (w0*w1 - W(1, r, x, modulus))
    if n % 2 == 0:
        return (W(n // 2, r, x, modulus)**2 - 2) % modulus


class Point:
    
    def __init__(self, x, y):
        
        self.x = x
        self.y = y
    
    def __str__(self):
        
        return '(' + str(self.x) + ', ' + str(self.y) + ')'
    
    def __eq__(self, P):
        
        if type(P) != type(self):
            return False
        return self.x == P.x and self.y == P.y


class EllipticCurve:
    """Provides functions for
    calculations on finite elliptic
    curves.
    """
    def __init__(self, a, b, modulus, warning=True):
        """Constructs the curve.
        a and b are parameters of the
        short Weierstraß equation:
        y^2 = x^3 + ax + b
        
        modulus is the order of the finite field,
        so the actual equation is
        y^2 = x^3 + ax + b (mod modulus)
        """
        self.a = a
        self.b = b
        self.modulus = modulus
        if warning:
            if modulus % 4 == 3 and b == 0:
                raise Warning
            if modulus % 6 == 5 and a == 0:
                raise Warning
    
    def mod_sqrt(self, v):
        """Calculates the modular square root
        of a given value v.
        """
        # check if there is a solution
        l = legendre(v, self.modulus)
        if l == (-1) % self.modulus:
            return None  # no solution
        if l == 0:
            return 0
        if l == 1:
            if self.modulus % 4 == 1:
                r = 0
                while legendre(r*r - 4*v, self.modulus) != (-1) % self.modulus:
                    r += 1
                w1 = W((self.modulus-1) // 4, r, v, self.modulus)
                w3 = W((self.modulus+3) // 4, r, v, self.modulus)
                inv_r = ext_euclidean(r, self.modulus)[1]
                inv_2 = (self.modulus + 1) // 2
                return (v * (w1 + w3) * inv_2 * inv_r) % self.modulus
            if self.modulus % 4 == 3:
                return pow(v, (self.modulus + 1) // 4, self.modulus)
            raise ValueError
        raise ValueError
    
    def generate(self, x):
        """generate Point with given x coordinate.
        """
        x %= self.modulus
        v = (x**3 + self.a*x + self.b) % self.modulus  # the curve equation
        y = self.mod_sqrt(v)
        if y is None:
            return None  # no solution
        return Point(x, y)
    
    def add(self, P, Q):
        """point addition on this curve.
        None is the neutral element.
        """
        if P is None:
            return Q
        if Q is None:
            return P
        numerator = (Q.y - P.y) % self.modulus
        denominator = (Q.x - P.x) % self.modulus
        if denominator == 0:
            if P == Q:
                # doubling the point
                if P.y == 0:
                    return None
                inv = ext_euclidean(2 * P.y, self.modulus)[1]
                slope = inv * (3 * P.x**2 + self.a) % self.modulus
            else:
                return None
        else:
            # normal point addition
            inv = ext_euclidean(denominator, self.modulus)[1]
            slope = inv * numerator % self.modulus
        Rx = (slope**2 - (P.x + Q.x)) % self.modulus
        Ry = (slope * (P.x - Rx) - P.y) % self.modulus
        return Point(Rx, Ry)
    
    def mul(self, P, n):
        """binary multiplication.
        double and add instead of square and multiply.
        """
        if P is None:
            return None
        if n < 0:
            P = Point(P.x, self.modulus - P.y)
            n = -n
        R = None
        for bit in bin(n)[2:]:
            R = self.add(R, R)
            if bit == '1':
                R = self.add(R, P)
        return R


class MersenneCurve(EllipticCurve):
    """Elliptic curve where the
    curve order is a Mersenne prime.
    """
    def __init__(self, a, b, exponent, warning=True):
        
        if exponent not in MERSENNE_EXPONENTS:
            raise ValueError
        if b == 0 and warning:
            raise Warning
        self.a = a
        self.b = b
        self.exponent = exponent
        self.modulus = 2**exponent - 1
# coding: utf-8

MERSENNE_EXPONENTS = [
    2, 3, 5, 7, 13, 17, 19, 31, 61, 89,
    107, 127, 521, 607, 1279, 2203, 2281,
    3217, 4253, 4423
]


def ext_euclidean(a, b):
    """extended euclidean algorithm.
    returns gcd(a, b) as well as two numbers
    u and v, such that a*u + b*v = gcd(a, b)
    if gcd(a, b) is 1, u is the
    multiplicative inverse of a (mod b)
    """
    t = u = 1
    s = v = 0
    while b:
        a, (q, b) = b, divmod(a, b)
        u, s = s, u - q*s
        v, t = t, v - q*s
    return a, u, v  # a is now the gcd


def legendre(x, p):
    """calculates the legendre symbol.
    p has to be an odd prime.
    returns 1 if x is a quadratic residue (mod p)
    returns -1 if x is quadratic non-residue (mod p)
    returns 0 if x = 0 (mod p)
    """
    return pow(x, (p-1) // 2, p)


def W(n, r, x, modulus):
    """Calculates recursive defined numbers
    which are needed to calculate the modular
    square root of x modulo modulus if modulus = 1 (mod 4)
    """
    if n == 1:
        inv = ext_euclidean(x, modulus)[1]
        return (r*r*inv - 2) % modulus
    if n % 2 == 0:
        w0 = W((n-1) // 2, r, x, modulus)
        w1 = W((n+1) // 2, r, x, modulus)
        return (w0*w1 - W(1, r, x, modulus))
    if n % 2 == 0:
        return (W(n // 2, r, x, modulus)**2 - 2) % modulus


class Point:
    
    def __init__(self, x, y):
        
        self.x = x
        self.y = y
    
    def __str__(self):
        
        return '(' + str(self.x) + ', ' + str(self.y) + ')'
    
    def __eq__(self, P):
        
        if type(P) != type(self):
            return False
        return self.x == P.x and self.y == P.y


class EllipticCurve:
    """Provides functions for
    calculations on finite elliptic
    curves.
    """
    def __init__(self, a, b, modulus, warning=True):
        """Constructs the curve.
        a and b are parameters of the
        short Weierstraß equation:
        y^2 = x^3 + ax + b
        
        modulus is the order of the finite field,
        so the actual equation is
        y^2 = x^3 + ax + b (mod modulus)
        """
        self.a = a
        self.b = b
        self.modulus = modulus
        if warning:
            if modulus % 4 == 3 and b == 0:
                raise Warning
            if modulus % 6 == 5 and a == 0:
                raise Warning
    
    def mod_sqrt(self, v):
        """Calculates the modular square root
        of a given value v.
        """
        # check if there is a solution
        l = legendre(v, self.modulus)
        if l == (-1) % self.modulus:
            return None  # no solution
        if l == 0:
            return 0
        if l == 1:
            if self.modulus % 4 == 1:
                r = 0
                while legendre(r*r - 4*v, self.modulus) != (-1) % self.modulus:
                    r += 1
                w1 = W((self.modulus-1) // 4, r, v, self.modulus)
                w3 = W((self.modulus+3) // 4, r, v, self.modulus)
                inv_r = ext_euclidean(r, self.modulus)[1]
                inv_2 = (self.modulus + 1) // 2
                return (v * (w1 + w3) * inv_2 * inv_r) % self.modulus
            if self.modulus % 4 == 3:
                return pow(v, (self.modulus + 1) // 4, self.modulus)
            raise ValueError
        raise ValueError
    
    def generate(self, x):
        """generate Point with given x coordinate.
        """
        x %= self.modulus
        v = (x**3 + self.a*x + self.b) % self.modulus  # the curve equation
        y = self.mod_sqrt(v)
        if y is None:
            return None  # no solution
        return Point(x, y)
    
    def add(self, P, Q):
        """point addition on this curve.
        None is the neutral element.
        """
        if P is None:
            return Q
        if Q is None:
            return P
        numerator = (Q.y - P.y) % self.modulus
        denominator = (Q.x - P.x) % self.modulus
        if denominator == 0:
            if P == Q:
                # doubling the point
                if P.y == 0:
                    return None
                inv = ext_euclidean(2 * P.y, self.modulus)[1]
                slope = inv * (3 * P.x**2 + self.a) % self.modulus
            else:
                return None
        else:
            # normal point addition
            inv = ext_euclidean(denominator, self.modulus)[1]
            slope = inv * numerator % self.modulus
        Rx = (slope**2 - (P.x + Q.x)) % self.modulus
        Ry = (slope * (P.x - Rx) - P.y) % self.modulus
        return Point(Rx, Ry)
    
    def mul(self, P, n):
        """binary multiplication.
        double and add instead of square and multiply.
        """
        if P is None:
            return None
        if n < 0:
            P = Point(P.x, self.modulus - P.y)
            n = -n
        R = None
        for bit in bin(n)[2:]:
            R = self.add(R, R)
            if bit == '1':
                R = self.add(R, P)
        return R


class MersenneCurve(EllipticCurve):
    """Elliptic curve where the
    curve order is a Mersenne prime.
    """
    def __init__(self, a, b, exponent, warning=True)
        
        if exponent not in MERSENNE_EXPONENTS:
            raise ValueError
        if b == 0 and warning:
            raise Warning
        self.a = a
        self.b = b
        self.exponent = exponent
        self.modulus = 2**exponent - 1
# coding: utf-8

MERSENNE_EXPONENTS = [
    2, 3, 5, 7, 13, 17, 19, 31, 61, 89,
    107, 127, 521, 607, 1279, 2203, 2281,
    3217, 4253, 4423
]


def ext_euclidean(a, b):
    """extended euclidean algorithm.
    returns gcd(a, b) as well as two numbers
    u and v, such that a*u + b*v = gcd(a, b)
    if gcd(a, b) is 1, u is the
    multiplicative inverse of a (mod b)
    """
    t = u = 1
    s = v = 0
    while b:
        a, (q, b) = b, divmod(a, b)
        u, s = s, u - q*s
        v, t = t, v - q*s
    return a, u, v  # a is now the greatest common divisor


def legendre(x, p):
    """calculates the legendre symbol.
    p has to be an odd prime.
    returns 1 if x is a quadratic residue (mod p)
    returns -1 if x is quadratic non-residue (mod p)
    returns 0 if x = 0 (mod p)
    """
    return pow(x, (p-1) // 2, p)


def W(n, r, x, modulus):
    """Calculates recursive defined numbers
    which are needed to calculate the modular
    square root of x modulo modulus if modulus = 1 (mod 4)
    """
    if n == 1:
        inv = ext_euclidean(x, modulus)[1]
        return (r*r*inv - 2) % modulus
    if n % 2 == 0:
        w0 = W((n-1) // 2, r, x, modulus)
        w1 = W((n+1) // 2, r, x, modulus)
        return (w0*w1 - W(1, r, x, modulus))
    if n % 2 == 0:
        return (W(n // 2, r, x, modulus)**2 - 2) % modulus


class Point:
    
    def __init__(self, x, y):
        
        self.x = x
        self.y = y
    
    def __str__(self):
        
        return '(' + str(self.x) + ', ' + str(self.y) + ')'
    
    def __eq__(self, P):
        
        if type(P) != type(self):
            return False
        return self.x == P.x and self.y == P.y


class EllipticCurve:
    """Provides functions for
    calculations on finite elliptic
    curves.
    """
    def __init__(self, a, b, modulus, warning=True):
        """Constructs the curve.
        a and b are parameters of the
        short Weierstraß equation:
        y^2 = x^3 + ax + b
        
        modulus is the order of the finite field,
        so the actual equation is
        y^2 = x^3 + ax + b (mod modulus)
        """
        self.a = a
        self.b = b
        self.modulus = modulus
        if warning:
            if modulus % 4 == 3 and b == 0:
                raise Warning
            if modulus % 6 == 5 and a == 0:
                raise Warning
    
    def mod_sqrt(self, v):
        """Calculates the modular square root
        of a given value v.
        """
        # check if there is a solution
        l = legendre(v, self.modulus)
        if l == (-1) % self.modulus:
            return None  # no solution
        if l == 0:
            return 0
        if l == 1:
            if self.modulus % 4 == 1:
                r = 0
                while legendre(r*r - 4*v, self.modulus) != (-1) % self.modulus:
                    r += 1
                w1 = W((self.modulus-1) // 4, r, v, self.modulus)
                w3 = W((self.modulus+3) // 4, r, v, self.modulus)
                inv_r = ext_euclidean(r, self.modulus)[1]
                inv_2 = (self.modulus + 1) // 2
                return (v * (w1 + w3) * inv_2 * inv_r) % self.modulus
            if self.modulus % 4 == 3:
                return pow(v, (self.modulus + 1) // 4, self.modulus)
            raise ValueError
        raise ValueError
    
    def generate(self, x):
        """generate Point with given x coordinate.
        """
        x %= self.modulus
        v = (x**3 + self.a*x + self.b) % self.modulus  # the curve equation
        y = self.mod_sqrt(v)
        if y is None:
            return None  # no solution
        return Point(x, y)
    
    def add(self, P, Q):
        """point addition on this curve.
        None is the neutral element.
        """
        if P is None:
            return Q
        if Q is None:
            return P
        numerator = (Q.y - P.y) % self.modulus
        denominator = (Q.x - P.x) % self.modulus
        if denominator == 0:
            if P == Q:
                # doubling the point
                if P.y == 0:
                    return None
                inv = ext_euclidean(2 * P.y, self.modulus)[1]
                slope = inv * (3 * P.x**2 + self.a) % self.modulus
            else:
                return None
        else:
            # normal point addition
            inv = ext_euclidean(denominator, self.modulus)[1]
            slope = inv * numerator % self.modulus
        Rx = (slope**2 - (P.x + Q.x)) % self.modulus
        Ry = (slope * (P.x - Rx) - P.y) % self.modulus
        return Point(Rx, Ry)
    
    def mul(self, P, n):
        """binary multiplication.
        double and add instead of square and multiply.
        """
        if P is None:
            return None
        if n < 0:
            P = Point(P.x, self.modulus - P.y)
            n = -n
        R = None
        for bit in bin(n)[2:]:
            R = self.add(R, R)
            if bit == '1':
                R = self.add(R, P)
        return R


class MersenneCurve(EllipticCurve):
    """Elliptic curve where the
    curve order is a Mersenne prime.
    """
    def __init__(self, a, b, exponent, warning=True):
        
        if exponent not in MERSENNE_EXPONENTS:
            raise ValueError
        if b == 0 and warning:
            raise Warning
        self.a = a
        self.b = b
        self.exponent = exponent
        self.modulus = 2**exponent - 1
Source Link
Aemyl
  • 860
  • 1
  • 8
  • 18

ECDH implementation in python (part 3)

After the first and second part, the primary feedback I got was to rename a lot of my variables as I used the same name for different local variables which is really confusing. I tried to hotfix that, as well as adding many comments to the new parts of my code. Notice that the code contains a lot of mathematic formulas which are not easy to understand, don't blame me for that since I didn't invent these formulas. I think my changes on the code are big enough to open a new question instead of editing the previous. If you don't understand what a specific code part does, please make sure if that's related rather to the code than to the math before you're forcing yourself to an answer which is as helpful as "The code is bad".

# coding: utf-8

MERSENNE_EXPONENTS = [
    2, 3, 5, 7, 13, 17, 19, 31, 61, 89,
    107, 127, 521, 607, 1279, 2203, 2281,
    3217, 4253, 4423
]


def ext_euclidean(a, b):
    """extended euclidean algorithm.
    returns gcd(a, b) as well as two numbers
    u and v, such that a*u + b*v = gcd(a, b)
    if gcd(a, b) is 1, u is the
    multiplicative inverse of a (mod b)
    """
    t = u = 1
    s = v = 0
    while b:
        a, (q, b) = b, divmod(a, b)
        u, s = s, u - q*s
        v, t = t, v - q*s
    return a, u, v  # a is now the gcd


def legendre(x, p):
    """calculates the legendre symbol.
    p has to be an odd prime.
    returns 1 if x is a quadratic residue (mod p)
    returns -1 if x is quadratic non-residue (mod p)
    returns 0 if x = 0 (mod p)
    """
    return pow(x, (p-1) // 2, p)


def W(n, r, x, modulus):
    """Calculates recursive defined numbers
    which are needed to calculate the modular
    square root of x modulo modulus if modulus = 1 (mod 4)
    """
    if n == 1:
        inv = ext_euclidean(x, modulus)[1]
        return (r*r*inv - 2) % modulus
    if n % 2 == 0:
        w0 = W((n-1) // 2, r, x, modulus)
        w1 = W((n+1) // 2, r, x, modulus)
        return (w0*w1 - W(1, r, x, modulus))
    if n % 2 == 0:
        return (W(n // 2, r, x, modulus)**2 - 2) % modulus


class Point:
    
    def __init__(self, x, y):
        
        self.x = x
        self.y = y
    
    def __str__(self):
        
        return '(' + str(self.x) + ', ' + str(self.y) + ')'
    
    def __eq__(self, P):
        
        if type(P) != type(self):
            return False
        return self.x == P.x and self.y == P.y


class EllipticCurve:
    """Provides functions for
    calculations on finite elliptic
    curves.
    """
    def __init__(self, a, b, modulus, warning=True):
        """Constructs the curve.
        a and b are parameters of the
        short Weierstraß equation:
        y^2 = x^3 + ax + b
        
        modulus is the order of the finite field,
        so the actual equation is
        y^2 = x^3 + ax + b (mod modulus)
        """
        self.a = a
        self.b = b
        self.modulus = modulus
        if warning:
            if modulus % 4 == 3 and b == 0:
                raise Warning
            if modulus % 6 == 5 and a == 0:
                raise Warning
    
    def mod_sqrt(self, v):
        """Calculates the modular square root
        of a given value v.
        """
        # check if there is a solution
        l = legendre(v, self.modulus)
        if l == (-1) % self.modulus:
            return None  # no solution
        if l == 0:
            return 0
        if l == 1:
            if self.modulus % 4 == 1:
                r = 0
                while legendre(r*r - 4*v, self.modulus) != (-1) % self.modulus:
                    r += 1
                w1 = W((self.modulus-1) // 4, r, v, self.modulus)
                w3 = W((self.modulus+3) // 4, r, v, self.modulus)
                inv_r = ext_euclidean(r, self.modulus)[1]
                inv_2 = (self.modulus + 1) // 2
                return (v * (w1 + w3) * inv_2 * inv_r) % self.modulus
            if self.modulus % 4 == 3:
                return pow(v, (self.modulus + 1) // 4, self.modulus)
            raise ValueError
        raise ValueError
    
    def generate(self, x):
        """generate Point with given x coordinate.
        """
        x %= self.modulus
        v = (x**3 + self.a*x + self.b) % self.modulus  # the curve equation
        y = self.mod_sqrt(v)
        if y is None:
            return None  # no solution
        return Point(x, y)
    
    def add(self, P, Q):
        """point addition on this curve.
        None is the neutral element.
        """
        if P is None:
            return Q
        if Q is None:
            return P
        numerator = (Q.y - P.y) % self.modulus
        denominator = (Q.x - P.x) % self.modulus
        if denominator == 0:
            if P == Q:
                # doubling the point
                if P.y == 0:
                    return None
                inv = ext_euclidean(2 * P.y, self.modulus)[1]
                slope = inv * (3 * P.x**2 + self.a) % self.modulus
            else:
                return None
        else:
            # normal point addition
            inv = ext_euclidean(denominator, self.modulus)[1]
            slope = inv * numerator % self.modulus
        Rx = (slope**2 - (P.x + Q.x)) % self.modulus
        Ry = (slope * (P.x - Rx) - P.y) % self.modulus
        return Point(Rx, Ry)
    
    def mul(self, P, n):
        """binary multiplication.
        double and add instead of square and multiply.
        """
        if P is None:
            return None
        if n < 0:
            P = Point(P.x, self.modulus - P.y)
            n = -n
        R = None
        for bit in bin(n)[2:]:
            R = self.add(R, R)
            if bit == '1':
                R = self.add(R, P)
        return R


class MersenneCurve(EllipticCurve):
    """Elliptic curve where the
    curve order is a Mersenne prime.
    """
    def __init__(self, a, b, exponent, warning=True)
        
        if exponent not in MERSENNE_EXPONENTS:
            raise ValueError
        if b == 0 and warning:
            raise Warning
        self.a = a
        self.b = b
        self.exponent = exponent
        self.modulus = 2**exponent - 1

I'm currently working on a class for Montgomery curves which have a different curve equation.