Browse Source

refactor to use Context class. It's not finished.

Marcin Krzyżanowski 11 years ago
parent
commit
23a5c3a787
2 changed files with 95 additions and 93 deletions
  1. 94 92
      CryptoSwift/Poly1305.swift
  2. 1 1
      CryptoSwiftTests/CipherTests.swift

+ 94 - 92
CryptoSwift/Poly1305.swift

@@ -15,87 +15,89 @@ import Foundation
 public class Poly1305 {
 public class Poly1305 {
     let blockSize = 16
     let blockSize = 16
 
 
-    lazy var buffer:[Byte] = {
-        [unowned self] in return [Byte](count: self.blockSize, repeatedValue: 0)
-    }()
+    private var ctx = Context()
     
     
-    var r            = [Byte](count: 17, repeatedValue: 0)
-    var h            = [Byte](count: 17, repeatedValue: 0)
-    var pad          = [Byte](count: 17, repeatedValue: 0)
-    var final:Byte   = 0
-    var leftover:Int = 0
+    private class Context {
+        var r            = [Byte](count: 17, repeatedValue: 0)
+        var h            = [Byte](count: 17, repeatedValue: 0)
+        var pad          = [Byte](count: 17, repeatedValue: 0)
+        var buffer       = [Byte](count: 16, repeatedValue: 0)
+        
+        var final:Byte   = 0
+        var leftover:Int = 0
+    }
     
     
     public init (key: [Byte]) {
     public init (key: [Byte]) {
-        setupKey(key)
+        setupKey(ctx, key: key)
     }
     }
     
     
     deinit {
     deinit {
-        for i in 0..<buffer.count {
-            buffer[i] = 0
+        for i in 0..<ctx.buffer.count {
+            ctx.buffer[i] = 0
         }
         }
         
         
-        for i in 0..<(r.count) {
-            r[i] = 0
-            h[i] = 0
-            pad[i] = 0
-            final = 0
-            leftover = 0
+        for i in 0..<ctx.r.count {
+            ctx.r[i] = 0
+            ctx.h[i] = 0
+            ctx.pad[i] = 0
+            ctx.final = 0
+            ctx.leftover = 0
         }
         }
     }
     }
     
     
-    func setupKey(key:[Byte]) {
+    private func setupKey(context:Context, key:[Byte]) {
         assert(key.count == 32,"Invalid key length");
         assert(key.count == 32,"Invalid key length");
         if (key.count != 32) {
         if (key.count != 32) {
             return;
             return;
         }
         }
         
         
         for i in 0..<17 {
         for i in 0..<17 {
-            h[i] = 0
+            context.h[i] = 0
         }
         }
         
         
-        r[0] = key[0] & 0xff;
-        r[1] = key[1] & 0xff;
-        r[2] = key[2] & 0xff;
-        r[3] = key[3] & 0x0f;
-        r[4] = key[4] & 0xfc;
-        r[5] = key[5] & 0xff;
-        r[6] = key[6] & 0xff;
-        r[7] = key[7] & 0x0f;
-        r[8] = key[8] & 0xfc;
-        r[9] = key[9] & 0xff;
-        r[10] = key[10] & 0xff;
-        r[11] = key[11] & 0x0f;
-        r[12] = key[12] & 0xfc;
-        r[13] = key[13] & 0xff;
-        r[14] = key[14] & 0xff;
-        r[15] = key[15] & 0x0f;
-        r[16] = 0
+        context.r[0] = key[0] & 0xff;
+        context.r[1] = key[1] & 0xff;
+        context.r[2] = key[2] & 0xff;
+        context.r[3] = key[3] & 0x0f;
+        context.r[4] = key[4] & 0xfc;
+        context.r[5] = key[5] & 0xff;
+        context.r[6] = key[6] & 0xff;
+        context.r[7] = key[7] & 0x0f;
+        context.r[8] = key[8] & 0xfc;
+        context.r[9] = key[9] & 0xff;
+        context.r[10] = key[10] & 0xff;
+        context.r[11] = key[11] & 0x0f;
+        context.r[12] = key[12] & 0xfc;
+        context.r[13] = key[13] & 0xff;
+        context.r[14] = key[14] & 0xff;
+        context.r[15] = key[15] & 0x0f;
+        context.r[16] = 0
         
         
         for i in 0..<16 {
         for i in 0..<16 {
-            pad[i] = key[i + 16]
+            context.pad[i] = key[i + 16]
         }
         }
-        pad[16] = 0
+        context.pad[16] = 0
         
         
-        leftover = 0
-        final = 0
+        context.leftover = 0
+        context.final = 0
     }
     }
     
     
-    private func add(inout h:[Byte], c:[Byte]) -> Bool {
-        if (h.count != 17 && c.count != 17) {
+    private func add(context:Context, c:[Byte]) -> Bool {
+        if (context.h.count != 17 && c.count != 17) {
             return false
             return false
         }
         }
         
         
         var u:UInt16 = 0
         var u:UInt16 = 0
         for i in 0..<17 {
         for i in 0..<17 {
-            u += UInt16(h[i]) + UInt16(c[i])
-            h[i] = Byte.withValue(u)
+            u += UInt16(context.h[i]) + UInt16(c[i])
+            context.h[i] = Byte.withValue(u)
             u = u >> 8
             u = u >> 8
         }
         }
         return true
         return true
     }
     }
     
     
-    private func squeeze(inout h:[Byte], hr:[UInt32]) -> Bool {
-        if (h.count != 17 && hr.count != 17) {
+    private func squeeze(context:Context, hr:[UInt32]) -> Bool {
+        if (context.h.count != 17 && hr.count != 17) {
             return false
             return false
         }
         }
 
 
@@ -103,27 +105,27 @@ public class Poly1305 {
 
 
         for i in 0..<16 {
         for i in 0..<16 {
             u += hr[i];
             u += hr[i];
-            h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
+            context.h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
             u >>= 8;
             u >>= 8;
         }
         }
         
         
         u += hr[16]
         u += hr[16]
-        h[16] = Byte.withValue(u) & 0x03
+        context.h[16] = Byte.withValue(u) & 0x03
         u >>= 2
         u >>= 2
         u += (u << 2); /* u *= 5; */
         u += (u << 2); /* u *= 5; */
         for i in 0..<16 {
         for i in 0..<16 {
-            u += UInt32(h[i])
-            h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
+            u += UInt32(context.h[i])
+            context.h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
             u >>= 8
             u >>= 8
         }
         }
-        h[16] += Byte.withValue(u);
+        context.h[16] += Byte.withValue(u);
         
         
         return true
         return true
     }
     }
     
     
-    private func freeze(inout h:[Byte]) -> Bool {
-        assert(h.count == 17,"Invalid length")
-        if (h.count != 17) {
+    private func freeze(context:Context) -> Bool {
+        assert(context.h.count == 17,"Invalid length")
+        if (context.h.count != 17) {
             return false
             return false
         }
         }
         
         
@@ -132,28 +134,28 @@ public class Poly1305 {
         
         
         /* compute h + -p */
         /* compute h + -p */
         for i in 0..<17 {
         for i in 0..<17 {
-            horig[i] = h[i]
+            horig[i] = context.h[i]
         }
         }
         
         
-        add(&h, c: minusp)
+        add(context, c: minusp)
         
         
         /* select h if h < p, or h + -p if h >= p */
         /* select h if h < p, or h + -p if h >= p */
-        let bits:[Bit] = (h[16] >> 7).bits()
+        let bits:[Bit] = (context.h[16] >> 7).bits()
         let invertedBits = bits.map({ (bit) -> Bit in
         let invertedBits = bits.map({ (bit) -> Bit in
             return bit.inverted()
             return bit.inverted()
         })
         })
         
         
         let negative = Byte(bits: invertedBits)
         let negative = Byte(bits: invertedBits)
         for i in 0..<17 {
         for i in 0..<17 {
-            h[i] ^= negative & (horig[i] ^ h[i]);
+            context.h[i] ^= negative & (horig[i] ^ context.h[i]);
         }
         }
         
         
         return true;
         return true;
     }
     }
     
     
-    private func blocks(m:[Byte], startPos:Int = 0) -> Int {
+    private func blocks(context:Context, m:[Byte], startPos:Int = 0) -> Int {
         var bytes = m.count
         var bytes = m.count
-        let hibit = final ^ 1 // 1 <<128
+        let hibit = context.final ^ 1 // 1 <<128
         var mPos = startPos
         var mPos = startPos
         
         
         while (bytes >= Int(blockSize)) {
         while (bytes >= Int(blockSize)) {
@@ -167,23 +169,23 @@ public class Poly1305 {
             }
             }
             c[16] = hibit
             c[16] = hibit
 
 
-            add(&h,c: c)
+            add(context, c: c)
 
 
             /* h *= r */
             /* h *= r */
             for i in 0..<17 {
             for i in 0..<17 {
                 u = 0
                 u = 0
                 for j in 0...i {
                 for j in 0...i {
-                    u = u + UInt32(UInt16(h[j])) * UInt32(r[i - j]) // u += (unsigned short)st->h[j] * st->r[i - j];
+                    u = u + UInt32(UInt16(context.h[j])) * UInt32(context.r[i - j]) // u += (unsigned short)st->h[j] * st->r[i - j];
                 }
                 }
                 for j in (i+1)..<17 {
                 for j in (i+1)..<17 {
-                    var v:UInt32 = UInt32(UInt16(h[j])) * UInt32(r[i + 17 - j])  // unsigned long v = (unsigned short)st->h[j] * st->r[i + 17 - j];
+                    var v:UInt32 = UInt32(UInt16(context.h[j])) * UInt32(context.r[i + 17 - j])  // unsigned long v = (unsigned short)st->h[j] * st->r[i + 17 - j];
                     v = ((v &<< 8) &+ (v &<< 6))
                     v = ((v &<< 8) &+ (v &<< 6))
                     u = u &+ v
                     u = u &+ v
                 }
                 }
                 hr[i] = u
                 hr[i] = u
             }
             }
             
             
-            squeeze(&h, hr: hr)
+            squeeze(context, hr: hr)
 
 
             mPos += blockSize
             mPos += blockSize
             bytes -= blockSize
             bytes -= blockSize
@@ -191,69 +193,71 @@ public class Poly1305 {
         return mPos
         return mPos
     }
     }
     
     
-    private func finish(inout mac:[Byte]) -> Bool {
+    private func finish(context:Context, mac:[Byte]) -> [Byte]? {
         assert(mac.count == 16, "Invalid mac length")
         assert(mac.count == 16, "Invalid mac length")
         if (mac.count != 16) {
         if (mac.count != 16) {
-            return false
+            return nil
         }
         }
         
         
+        var resultMAC = mac;
+        
         /* process the remaining block */
         /* process the remaining block */
-        if (leftover > 0) {
+        if (context.leftover > 0) {
             
             
-            var i = leftover
-            buffer[i++] = 1
+            var i = context.leftover
+            context.buffer[i++] = 1
             for (; i < blockSize; i++) {
             for (; i < blockSize; i++) {
-                buffer[i] = 0
+                context.buffer[i] = 0
             }
             }
-            final = 1
+            context.final = 1
             
             
-            blocks(buffer)
+            blocks(context, m: context.buffer)
         }
         }
         
         
         
         
         /* fully reduce h */
         /* fully reduce h */
-        freeze(&h)
+        freeze(context)
         
         
         /* h = (h + pad) % (1 << 128) */
         /* h = (h + pad) % (1 << 128) */
-        add(&h, c: pad)
+        add(context, c: context.pad)
         for i in 0..<16 {
         for i in 0..<16 {
-            mac[i] = h[i]
+            resultMAC[i] = context.h[i]
         }
         }
         
         
-        return true
+        return resultMAC
     }
     }
     
     
-    private func update(m:[Byte]) {
+    private func update(context:Context, m:[Byte]) {
         var bytes = m.count
         var bytes = m.count
         var mPos = 0
         var mPos = 0
         
         
         /* handle leftover */
         /* handle leftover */
-        if (leftover > 0) {
-            var want = blockSize - leftover
+        if (context.leftover > 0) {
+            var want = blockSize - context.leftover
             if (want > bytes) {
             if (want > bytes) {
                 want = bytes
                 want = bytes
             }
             }
             
             
             for i in 0..<want {
             for i in 0..<want {
-                buffer[leftover + i] = m[mPos + i]
+                context.buffer[context.leftover + i] = m[mPos + i]
             }
             }
             
             
             bytes -= want
             bytes -= want
             mPos += want
             mPos += want
-            leftover += want
+            context.leftover += want
             
             
-            if (leftover < blockSize) {
+            if (context.leftover < blockSize) {
                 return
                 return
             }
             }
             
             
-            blocks(buffer)
-            leftover = 0
+            blocks(context, m: context.buffer)
+            context.leftover = 0
         }
         }
         
         
         /* process full blocks */
         /* process full blocks */
         if (bytes >= blockSize) {
         if (bytes >= blockSize) {
             var want = bytes & ~(blockSize - 1)
             var want = bytes & ~(blockSize - 1)
-            blocks(m, startPos: mPos)
+            blocks(context, m: m, startPos: mPos)
             mPos += want
             mPos += want
             bytes -= want;
             bytes -= want;
         }
         }
@@ -261,18 +265,16 @@ public class Poly1305 {
         /* store leftover */
         /* store leftover */
         if (bytes > 0) {
         if (bytes > 0) {
             for i in 0..<bytes {
             for i in 0..<bytes {
-                buffer[leftover + i] = m[mPos + i]
+                context.buffer[context.leftover + i] = m[mPos + i]
             }
             }
             
             
-            leftover += bytes
+            context.leftover += bytes
         }
         }
         
         
     }
     }
     
     
-    public func auth(mac:[Byte], m:[Byte]) -> [Byte] {
-        update(m)
-        var macWork = mac
-        finish(&macWork)
-        return macWork
+    public func auth(mac:[Byte], m:[Byte]) -> [Byte]? {
+        update(ctx, m: m)
+        return finish(ctx, mac: mac)
     }
     }
 }
 }

+ 1 - 1
CryptoSwiftTests/CipherTests.swift

@@ -28,7 +28,7 @@ class CipherTests: XCTestCase {
         
         
         let poly = Poly1305(key: key);
         let poly = Poly1305(key: key);
         var resultMac = poly.auth(mac, m: msg)
         var resultMac = poly.auth(mac, m: msg)
-        XCTAssertEqual(resultMac, expectedMac, "Invalid auth mac")
+        XCTAssertEqual(resultMac!, expectedMac, "Invalid auth mac")
     }
     }
 
 
     func testChaCha20() {
     func testChaCha20() {