-
Notifications
You must be signed in to change notification settings - Fork 0
/
root.zok
112 lines (103 loc) · 5.66 KB
/
root.zok
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// TODO: exp should be checked just like aud, since plaintext numbers cannot be compared easily in base64
// we must make sure it doesn't reveal too much sensitive information though
import "hashes/sha256/sha256" as sha256;
from "EMBED" import u8_from_bits as u8FromBits;
from "EMBED" import u32_to_bits as u32ToBits;
// Copies S consecutive bytes out of `fullString`, beginning at position `idx`.
// @param {fullString} source byte array of length F
// @param {idx} offset of the first byte to copy; bytes fullString[idx] .. fullString[idx + S - 1] are read
// @returns the S copied bytes, in their original order
def substringAt<F,S>(u8[F] fullString, u32 idx) -> u8[S] {
    u8[S] mut out = [0; S];
    for u32 offset in 0..S {
        out[offset] = fullString[idx + offset];
    }
    return out;
}
// Tests whether the S bytes of `fullString` beginning at `idx` are identical to `subString`.
// @param {fullString} byte array that is searched
// @param {subString} the S bytes expected at position idx
// @param {idx} offset in fullString where the comparison window starts
// @returns true when the window and subString are equal
def hasSubstringAt<F,S>(u8[F] fullString, u8[S] subString, u32 idx) -> bool {
    u8[S] window = substringAt(fullString, idx);
    bool matches = window == subString;
    return matches;
}
// Masks `str` with `mask`: element i of the result is str[i] & mask[i] (bitwise AND).
// @param {str} bytes to be masked
// @param {mask} one mask byte per byte of str
// @returns the masked copy of str; the inputs themselves are not modified
def maskString<L>(u8[L] str, u8[L] mask) -> u8[L] {
    // start from a copy of the input and AND each position with its mask byte
    u8[L] mut out = str;
    for u32 pos in 0..L {
        out[pos] = out[pos] & mask[pos];
    }
    return out;
}
// Validates a byte-granular mask: bytes are the unit of information, so every byte must be
// either 255 (all bits 1) or 0 (all bits 0).
// In addition, the value may switch between neighbouring bytes at most twice, which limits the
// mask to at most three runs of equal bytes (e.g. 0..0 255..255 0..0).
// @param {mask} the mask bytes to validate; mask[0] is read unconditionally, so M must be at least 1
// @returns true when every byte is 0 or 255 and there are at most two value changes
def checkMask<M>(u8[M] mask) -> bool {
    u32 mut nChanges = 0;
    // A byte is well-formed only when it EQUALS 255 or 0. The earlier test
    // `(b != 255) || (b != 0)` was true for every possible byte and therefore never rejected anything.
    bool mut valid = (mask[0] == 255) || (mask[0] == 0);
    for u32 i in 1..M {
        // count positions where the mask value differs from its predecessor
        nChanges = (mask[i-1] != mask[i]) ? nChanges + 1 : nChanges;
        valid = valid && ((mask[i] == 255) || (mask[i] == 0));
    }
    return valid && (nChanges <= 2);
}
// Checks that the A bytes of `flattened` starting at `audIdx` equal `aud`
// (aud is intended to be the public aud claim).
// @param {flattened} byte array that is searched
// @param {aud} the expected aud bytes
// @param {audIdx} offset in flattened where aud is expected to start
// @returns true when the bytes at audIdx match aud
def checkAud<F,A>(u8[F] flattened, u8[A] aud, u32 audIdx) -> bool {
    bool audMatches = hasSubstringAt(flattened, aud, audIdx);
    return audMatches;
}
// Checks a JWT's exp claim at expIdx. expIdx should be private but expGreaterThan should be public
// def checkExp<N>(u32[N][16] paddedJwt, u32 expGreaterThan, u32[2] expIdx) -> bool {
// return paddedJwt[expIdx[0]][expIdx[1]] > expGreaterThan;
// }
// Checks that the S bytes of `flattened` starting at `subIdx`, once masked with `mask`, equal `masked`,
// and that `mask` itself passes checkMask.
// {extendedSub} (local) the sub claim, plus the following extra characters up to length S for padding, as inputs must be of fixed length
// @param {flattened} byte array that contains the claim
// @param {mask} since we don't want to reveal anything after sub, the mask is provided to mask sub with 1s and other bits with 0s
// @param {masked} this can be made public -- this is the sub claim followed by 0s until length S
// @param {subIdx} this is where the sub starts in flattened
// @returns true when both the masked comparison and the mask check succeed
def hasMaskedString<F,S>(u8[F] flattened, u8[S] mask, u8[S] masked, u32 subIdx) -> bool {
    u8[S] extendedSub = substringAt(flattened, subIdx);
    bool maskIsWellFormed = checkMask(mask);
    bool revealMatches = maskString(extendedSub, mask) == masked;
    return revealMatches && maskIsWellFormed;
}
// TODO : PR this to stdlib/utils/casts
// Splits every u32 of `u32s` into four bytes and stores them consecutively in the result:
// word w fills positions 4*w .. 4*w + 3, taken from bit slices 0..8, 8..16, 16..24 and 24..32
// of u32ToBits(word), in that order.
// @param {u32s} the N words to convert
// @returns a byte array of length E; positions at or beyond 4*N keep their initial value 0
def u32sToU8s<N,E>(u32[N] u32s) -> u8[E] {
    u8[E] mut bytes = [0; E];
    for u32 w in 0..N {
        bool[32] bits = u32ToBits(u32s[w]);
        u32 base = w * 4;
        bytes[base] = u8FromBits(bits[0..8]);
        bytes[base + 1] = u8FromBits(bits[8..16]);
        bytes[base + 2] = u8FromBits(bits[16..24]);
        bytes[base + 3] = u8FromBits(bits[24..32]);
    }
    return bytes;
}
// Flatten a N*16 array of u32 to a 1-dimensional array of bytes.
// Word [i][j] is written to positions 64*i + 4*j .. 64*i + 4*j + 3, using bit slices
// 0..8, 8..16, 16..24 and 24..32 of u32ToBits(word), in that order. The words therefore occupy
// the first N*64 positions; any positions beyond that keep their initial value 0.
// The buffer is sized by the declared return length F rather than a hard-coded N*128, so the body
// always agrees with the signature. Callers that request F = N*128 get exactly the same result
// as before; F must be at least N*64 for all writes to be in bounds.
// @param {nBy16Arr} N rows of 16 u32 words each
// @returns the F-byte flattened array
def flattenedToBytes<N,F>(u32[N][16] nBy16Arr) -> u8[F] {
    u8[F] mut flattened = [0; F];
    for u32 i in 0..N {
        for u32 j in 0..16 {
            // each row holds 16 words * 4 bytes = 64 bytes; each word holds 4 bytes
            u32 offset = 64 * i + 4 * j;
            bool[32] newBits = u32ToBits(nBy16Arr[i][j]);
            flattened[offset] = u8FromBits(newBits[0..8]);
            flattened[offset + 1] = u8FromBits(newBits[8..16]);
            flattened[offset + 2] = u8FromBits(newBits[16..24]);
            flattened[offset + 3] = u8FromBits(newBits[24..32]);
        }
    }
    return flattened;
}
// @param {string} subMasked
// @param {string} serverKey a key deterministically generated by the server (e.g., by signing sub).
// serverKey is centralized *but* does not mean the server can steal funds -- a valid JWT and serverKey are required for recovery
// having a serverKey prevents being doxxed -- otherwise, hash(sub) must be stored on-chain and can easily be found if somebody knows a sub they want to doxx on-chain
// serverKey can be requested from the server by providing a valid JWT; serverKey can equal hmac(maskedSub, serverSecret), for example.
// The server can keep the serverSecret secret but give people their unique serverKeys based on their subs for the proof
// The purpose of the serverKey is pseudorandomness so that nobody except the server can brute-force the encryptedSub which is stored on-chain
//
// def encryptSub(u8[S] subMasked, serverKey) -> u32[8] {
// return;
// }
// Wrap everything together to check a JWT.
// Succeeds only when all four checks hold:
//   1. sha256(paddedJwt) equals jwtHash
//   2. the bytes at audIdx, masked with audMask, equal audMasked (and audMask passes the mask check)
//   3. the bytes at subIdx, masked with subMask, equal subMasked (and subMask passes the mask check)
//   4. the bytes at expIdx, masked with expMask, equal expMasked (and expMask passes the mask check)
// Original author's note on the signature: "allow JWT up to 512 bits".
// @returns true when the hash and all three masked-claim checks succeed
def checkJwt<N,A,S>(u32[N][16] paddedJwt, u32[8] jwtHash, private u8[A] audMask, u8[A] audMasked, private u32 audIdx, private u8[S] subMask, u8[S] subMasked, private u32 subIdx, private u8[24] expMask, u8[24] expMasked, private u32 expIdx) -> bool {
    // Byte view of the padded JWT, used to locate the aud, sub and exp substrings
    u8[N*128] flattened = flattenedToBytes(paddedJwt);
    // (a plain checkAud(flattened, aud, audIdx) was used here previously; aud is now checked through its mask)
    bool hashOk = sha256(paddedJwt) == jwtHash;
    bool audOk = hasMaskedString(flattened, audMask, audMasked, audIdx);
    bool subOk = hasMaskedString(flattened, subMask, subMasked, subIdx);
    bool expOk = hasMaskedString(flattened, expMask, expMasked, expIdx);
    return hashOk && audOk && subOk && expOk;
}
// Program entry point: calls checkJwt with a 3 x 16 word padded JWT, 24-byte aud mask/value,
// 48-byte sub mask/value and 24-byte exp mask/value, and returns its verdict.
def main(private u32[3][16] paddedJwt, u32[8] jwtHash, private u8[24] audMask, u8[24] audMasked, private u32 audIdx, private u8[48] subMask, u8[48] subMasked, private u32 subIdx, private u8[24] expMask, u8[24] expMasked, private u32 expIdx) -> bool {
    bool jwtIsValid = checkJwt(paddedJwt, jwtHash, audMask, audMasked, audIdx, subMask, subMasked, subIdx, expMask, expMasked, expIdx);
    return jwtIsValid;
}