Poly1305.swift 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. //
  2. // Poly1305.swift
  3. // CryptoSwift
  4. //
  5. // Created by Marcin Krzyzanowski on 30/08/14.
  6. // Copyright (c) 2014 Marcin Krzyzanowski. All rights reserved.
  7. //
  8. // http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04#section-4
  9. //
  10. // Poly1305 takes a 32-byte, one-time key and a message and produces a 16-byte tag that authenticates the
  11. // message such that an attacker has a negligible chance of producing a valid tag for an inauthentic message.
  12. import Foundation
  13. public class Poly1305 {
  14. let blockSize = 16
  15. private var ctx:Context?
  16. private class Context {
  17. var r = [Byte](count: 17, repeatedValue: 0)
  18. var h = [Byte](count: 17, repeatedValue: 0)
  19. var pad = [Byte](count: 17, repeatedValue: 0)
  20. var buffer = [Byte](count: 16, repeatedValue: 0)
  21. var final:Byte = 0
  22. var leftover:Int = 0
  23. init (_ key: [Byte]) {
  24. assert(key.count == 32,"Invalid key length");
  25. if (key.count != 32) {
  26. return;
  27. }
  28. for i in 0..<17 {
  29. h[i] = 0
  30. }
  31. r[0] = key[0] & 0xff;
  32. r[1] = key[1] & 0xff;
  33. r[2] = key[2] & 0xff;
  34. r[3] = key[3] & 0x0f;
  35. r[4] = key[4] & 0xfc;
  36. r[5] = key[5] & 0xff;
  37. r[6] = key[6] & 0xff;
  38. r[7] = key[7] & 0x0f;
  39. r[8] = key[8] & 0xfc;
  40. r[9] = key[9] & 0xff;
  41. r[10] = key[10] & 0xff;
  42. r[11] = key[11] & 0x0f;
  43. r[12] = key[12] & 0xfc;
  44. r[13] = key[13] & 0xff;
  45. r[14] = key[14] & 0xff;
  46. r[15] = key[15] & 0x0f;
  47. r[16] = 0
  48. for i in 0..<16 {
  49. pad[i] = key[i + 16]
  50. }
  51. pad[16] = 0
  52. leftover = 0
  53. final = 0
  54. }
  55. deinit {
  56. for i in 0..<buffer.count {
  57. buffer[i] = 0
  58. }
  59. for i in 0..<r.count {
  60. r[i] = 0
  61. h[i] = 0
  62. pad[i] = 0
  63. final = 0
  64. leftover = 0
  65. }
  66. }
  67. }
  68. // MARK: - Internal
  69. /**
  70. Calculate Message Authentication Code (MAC) for message.
  71. Calculation context is discarder on instance deallocation.
  72. :param: key 256-bit key
  73. :param: message Message
  74. :returns: Message Authentication Code
  75. */
  76. class internal func authenticate(# key: NSData, message: NSData) -> NSData? {
  77. if let mac = Poly1305.authenticate(key: key.bytes(), message: message.bytes()) {
  78. return NSData(bytes: mac, length: mac.count)
  79. }
  80. return nil
  81. }
  82. class internal func authenticate(# key: [Byte], message: [Byte]) -> [Byte]? {
  83. return Poly1305(key).authenticate(message: message)
  84. }
  85. // MARK: - Private
  86. private init (_ key: [Byte]) {
  87. ctx = Context(key)
  88. }
  89. private func authenticate(# message:[Byte]) -> [Byte]? {
  90. if let ctx = self.ctx {
  91. update(ctx, message: message)
  92. return finish(ctx)
  93. }
  94. return nil
  95. }
  96. /**
  97. Add message to be processed
  98. :param: context Context
  99. :param: message message
  100. :param: bytes length of the message fragment to be processed
  101. */
  102. private func update(context:Context, message:[Byte], bytes:Int? = nil) {
  103. var bytes = bytes ?? message.count
  104. var mPos = 0
  105. /* handle leftover */
  106. if (context.leftover > 0) {
  107. var want = blockSize - context.leftover
  108. if (want > bytes) {
  109. want = bytes
  110. }
  111. for i in 0..<want {
  112. context.buffer[context.leftover + i] = message[mPos + i]
  113. }
  114. bytes -= want
  115. mPos += want
  116. context.leftover += want
  117. if (context.leftover < blockSize) {
  118. return
  119. }
  120. blocks(context, m: context.buffer)
  121. context.leftover = 0
  122. }
  123. /* process full blocks */
  124. if (bytes >= blockSize) {
  125. var want = bytes & ~(blockSize - 1)
  126. blocks(context, m: message, startPos: mPos)
  127. mPos += want
  128. bytes -= want;
  129. }
  130. /* store leftover */
  131. if (bytes > 0) {
  132. for i in 0..<bytes {
  133. context.buffer[context.leftover + i] = message[mPos + i]
  134. }
  135. context.leftover += bytes
  136. }
  137. }
  138. private func finish(context:Context) -> [Byte]? {
  139. var mac = [Byte](count: 16, repeatedValue: 0);
  140. /* process the remaining block */
  141. if (context.leftover > 0) {
  142. var i = context.leftover
  143. context.buffer[i++] = 1
  144. for (; i < blockSize; i++) {
  145. context.buffer[i] = 0
  146. }
  147. context.final = 1
  148. blocks(context, m: context.buffer)
  149. }
  150. /* fully reduce h */
  151. freeze(context)
  152. /* h = (h + pad) % (1 << 128) */
  153. add(context, c: context.pad)
  154. for i in 0..<mac.count {
  155. mac[i] = context.h[i]
  156. }
  157. return mac
  158. }
  159. // MARK: - Utils
  160. private func add(context:Context, c:[Byte]) -> Bool {
  161. if (context.h.count != 17 && c.count != 17) {
  162. return false
  163. }
  164. var u:UInt16 = 0
  165. for i in 0..<17 {
  166. u += UInt16(context.h[i]) + UInt16(c[i])
  167. context.h[i] = Byte.withValue(u)
  168. u = u >> 8
  169. }
  170. return true
  171. }
  172. private func squeeze(context:Context, hr:[UInt32]) -> Bool {
  173. if (context.h.count != 17 && hr.count != 17) {
  174. return false
  175. }
  176. var u:UInt32 = 0
  177. for i in 0..<16 {
  178. u += hr[i];
  179. context.h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
  180. u >>= 8;
  181. }
  182. u += hr[16]
  183. context.h[16] = Byte.withValue(u) & 0x03
  184. u >>= 2
  185. u += (u << 2); /* u *= 5; */
  186. for i in 0..<16 {
  187. u += UInt32(context.h[i])
  188. context.h[i] = Byte.withValue(u) // crash! h[i] = UInt8(u) & 0xff
  189. u >>= 8
  190. }
  191. context.h[16] += Byte.withValue(u);
  192. return true
  193. }
  194. private func freeze(context:Context) -> Bool {
  195. assert(context.h.count == 17,"Invalid length")
  196. if (context.h.count != 17) {
  197. return false
  198. }
  199. let minusp:[Byte] = [0x05,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfc]
  200. var horig:[Byte] = [Byte](count: 17, repeatedValue: 0)
  201. /* compute h + -p */
  202. for i in 0..<17 {
  203. horig[i] = context.h[i]
  204. }
  205. add(context, c: minusp)
  206. /* select h if h < p, or h + -p if h >= p */
  207. let bits:[Bit] = (context.h[16] >> 7).bits()
  208. let invertedBits = bits.map({ (bit) -> Bit in
  209. return bit.inverted()
  210. })
  211. let negative = Byte(bits: invertedBits)
  212. for i in 0..<17 {
  213. context.h[i] ^= negative & (horig[i] ^ context.h[i]);
  214. }
  215. return true;
  216. }
  217. private func blocks(context:Context, m:[Byte], startPos:Int = 0) -> Int {
  218. var bytes = m.count
  219. let hibit = context.final ^ 1 // 1 <<128
  220. var mPos = startPos
  221. while (bytes >= Int(blockSize)) {
  222. var hr:[UInt32] = [UInt32](count: 17, repeatedValue: 0)
  223. var u:UInt32 = 0
  224. var c:[Byte] = [Byte](count: 17, repeatedValue: 0)
  225. /* h += m */
  226. for i in 0..<16 {
  227. c[i] = m[mPos + i]
  228. }
  229. c[16] = hibit
  230. add(context, c: c)
  231. /* h *= r */
  232. for i in 0..<17 {
  233. u = 0
  234. for j in 0...i {
  235. u = u + UInt32(UInt16(context.h[j])) * UInt32(context.r[i - j]) // u += (unsigned short)st->h[j] * st->r[i - j];
  236. }
  237. for j in (i+1)..<17 {
  238. var v:UInt32 = UInt32(UInt16(context.h[j])) * UInt32(context.r[i + 17 - j]) // unsigned long v = (unsigned short)st->h[j] * st->r[i + 17 - j];
  239. v = ((v &<< 8) &+ (v &<< 6))
  240. u = u &+ v
  241. }
  242. hr[i] = u
  243. }
  244. squeeze(context, hr: hr)
  245. mPos += blockSize
  246. bytes -= blockSize
  247. }
  248. return mPos
  249. }
  250. }