Skip to content

Commit

Permalink
mem hash update
Browse files Browse the repository at this point in the history
  • Loading branch information
madMAx43v3r committed Nov 30, 2023
1 parent 47ce89e commit 8810518
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 35 deletions.
5 changes: 2 additions & 3 deletions include/mmx/pos/mem_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ namespace pos {
void gen_mem_array(uint32_t* mem, const uint8_t* key, const int key_size, const uint32_t mem_size);

/*
* M = log2 number of iterations
* mem = array of size (32 << B)
* mem = array of size 1024
* hash = array of size 128
*/
void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B);
void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int num_iter);


} // pos
Expand Down
29 changes: 13 additions & 16 deletions src/pos/mem_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,36 +65,33 @@ void gen_mem_array(uint32_t* mem, const uint8_t* key, const int key_size, const
}
}

void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B)
void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int num_iter)
{
static constexpr int N = 32;

const int num_iter = (1 << M);
const uint32_t index_mask = ((1 << B) - 1);
const uint32_t offset_mask = 31u << 5;

uint32_t state[N];
for(int i = 0; i < N; ++i) {
state[i] = mem[index_mask * N + i];
state[i] = mem[31 * N + i];
}

for(int k = 0; k < num_iter; ++k)
{
uint32_t tmp = 0;
uint32_t sum = 0;
for(int i = 0; i < N; ++i) {
tmp += rotl_32(state[i] ^ 0x55555555, (k + i) % N);
sum += rotl_32(state[i], (k + i) % 32);
}
tmp ^= 0x55555555;
const uint32_t dir = sum % 1193u; // mod by prime

const auto bits = tmp % 32;
// const auto offset = ((tmp >> 5) & index_mask) * N;
const auto offset = tmp & (index_mask << 5);
const uint32_t bits = dir % 32u;
const uint32_t offset = dir & offset_mask;

for(int i = 0; i < N; ++i) {
const int shift = (k + i) % N;
state[i] += rotl_32(mem[offset + shift] ^ state[i], bits);
}
for(int i = 0; i < N; ++i) {
mem[offset + i] = state[i];
for(int i = 0; i < N; ++i)
{
state[i] += rotl_32(mem[offset + (k + i) % N], bits) ^ sum;

mem[offset + i] ^= state[i];
}
}

Expand Down
8 changes: 3 additions & 5 deletions src/pos/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
namespace mmx {
namespace pos {

static constexpr int MEM_HASH_N = 32;
static constexpr int MEM_HASH_M = 8;
static constexpr int MEM_HASH_B = 5;
static constexpr int MEM_HASH_ITER = 256;

static constexpr uint64_t MEM_SIZE = uint64_t(MEM_HASH_N) << MEM_HASH_B;
static constexpr uint64_t MEM_SIZE = 32 * 32;

static std::mutex g_mutex;
static std::shared_ptr<vnx::ThreadPool> g_threads;
Expand Down Expand Up @@ -50,7 +48,7 @@ void compute_f1(std::vector<uint32_t>* X_tmp,
gen_mem_array(mem_buf.data(), key.data(), key.size(), MEM_SIZE);

uint8_t mem_hash[128 + 64] = {};
calc_mem_hash(mem_buf.data(), mem_hash, MEM_HASH_M, MEM_HASH_B);
calc_mem_hash(mem_buf.data(), mem_hash, MEM_HASH_ITER);

::memcpy(mem_hash + 128, key.data(), key.size());

Expand Down
40 changes: 29 additions & 11 deletions test/test_mem_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ using namespace mmx::pos;
int main(int argc, char** argv)
{
const int N = 32;
const int M = 8;
const int B = 5;

const int num_iter = argc > 1 ? std::atoi(argv[1]) : 1;
const int count = argc > 1 ? std::atoi(argv[1]) : 1;
const int num_iter = argc > 2 ? std::atoi(argv[2]) : 256;

const uint64_t mem_size = uint64_t(N) << B;
const uint64_t mem_size = uint64_t(N) * N;

std::cout << "N = " << N << std::endl;
std::cout << "M = " << M << std::endl;
std::cout << "B = " << B << std::endl;
std::cout << "count = " << count << std::endl;
std::cout << "num_iter = " << num_iter << std::endl;
std::cout << "mem_size = " << mem_size << " (" << mem_size * 4 / 1024 << " KiB)" << std::endl;

size_t pop_sum = 0;
size_t num_pass = 0;
size_t min_pop_count = 1024;

uint32_t* mem = new uint32_t[mem_size];

for(int iter = 0; iter < num_iter; ++iter)
for(int iter = 0; iter < count; ++iter)
{
uint8_t key[32] = {};
::memcpy(key, &iter, sizeof(iter));
Expand All @@ -42,7 +44,7 @@ int main(int argc, char** argv)
if(iter == 0) {
std::map<uint32_t, uint32_t> init_count;

for(int k = 0; k < (1 << B); ++k) {
for(int k = 0; k < 32; ++k) {
std::cout << "[" << k << "] " << std::hex;
for(int i = 0; i < N; ++i) {
init_count[mem[k * N + i]]++;
Expand All @@ -59,11 +61,27 @@ int main(int argc, char** argv)
}
mmx::bytes_t<128> hash;

calc_mem_hash(mem, hash.data(), M, B);
calc_mem_hash(mem, hash.data(), num_iter);

size_t pop = 0;
for(int i = 0; i < 1024; ++i) {
pop += (hash[i / 8] >> (i % 8)) & 1;
}
pop_sum += pop;

min_pop_count = std::min(min_pop_count, pop);

std::cout << "[" << iter << "] " << hash << std::endl;
if(pop <= 469) {
num_pass++;
}

std::cout << "[" << iter << "] " << hash << " (" << pop << ")" << std::endl;
}

std::cout << "num_pass = " << num_pass << " (" << num_pass / double(count) << ")" << std::endl;
std::cout << "min_pop_count = " << min_pop_count << std::endl;
std::cout << "avg_pop_count = " << pop_sum / double(count) << std::endl;

delete [] mem;

return 0;
Expand Down

0 comments on commit 8810518

Please sign in to comment.