
-- Emit a warning when LuaJIT reports an architecture other than 'x64'.
-- NOTE(review): 'BIT=64' is presumably a launcher/environment setting that
-- selects the 64-bit build -- confirm against the runtime's documentation.
if jit.arch ~= 'x64' then
    print('WARNING: please use BIT=64 for optimal OpenBLAS performance')
end

local ffi     = require 'ffi'
local bit     = require 'bit'
local time    = require 'time'
local alg     = require 'sci.alg'
local prng    = require 'sci.prng'
local stat    = require 'sci.stat'
local dist    = require 'sci.dist'
local complex = require 'sci.complex'

local min, sqrt, random, abs = math.min, math.sqrt, math.random, math.abs
local cabs = complex.abs
local rshift = bit.rshift
local format = string.format
local nowutc = time.nowutc
local rng = prng.std()
local vec, mat, join = alg.vec, alg.mat, alg.join
local sum, trace = alg.sum, alg.trace
local var, mean = stat.var, stat.mean

--------------------------------------------------------------------------------
--- Run `f` once and measure how long the call took.
-- @param f function called with no arguments; at most its first two
--          results are kept
-- @return duration of the call in milliseconds, followed by the first and
--         second results of `f`
local function elapsed(f)
    local started = nowutc()
    local first, second = f()
    -- Sample the clock immediately after f returns, before any conversion.
    local finished = nowutc()
    local millis = (finished - started):tomilliseconds()
    return millis, first, second
end

--- Repeatedly time `f` and print the fastest run as a CSV line.
-- The loop runs until more than five iterations have completed AND at least
-- two seconds have passed since the first iteration started.
-- @param f     function to benchmark (called with no arguments)
-- @param name  label written in the second CSV column
-- @param check optional function; when given it is called after every run
--              with the first two results of `f`
local function timeit(f, name, check)
    local best = 1/0          -- +infinity, so the first sample always wins
    local runs = 0
    local began = nowutc()
    repeat
        runs = runs + 1
        local took, val1, val2 = elapsed(f)
        best = min(best, took)
        if check then
            check(val1, val2)
        end
    until runs > 5 and (nowutc() - began):toseconds() >= 2
    io.write(format('lua,%s,%g\n', name, best))
end

--------------------------------------------------------------------------------
--- Naive doubly-recursive Fibonacci.
-- @param n index into the sequence; any value below 2 is returned unchanged
-- @return the n-th Fibonacci number (fib(0) = 0, fib(1) = 1)
local function fib(n)
    -- Base case first; everything else recurses.
    if n < 2 then return n end
    return fib(n-1) + fib(n-2)
end

-- Benchmark fib(20); every run must yield 6765.
timeit(
    function()
        return fib(20)
    end,
    'fib',
    function(result)
        assert(result == 6765)
    end
)

--- Hex-format and re-parse 1000 random unsigned 32-bit integers.
-- Each iteration draws an integer, formats it as a '0x...' hex string and
-- parses the string back with tonumber.
-- @return the last integer drawn and the value parsed back from its hex
--         string (asserted to be equal)
local function parseint()
    local lmt = 2^32 - 1
    local n, m
    for i = 1, 1000 do
        -- random(lmt) alone would return a value in [1, lmt]; the explicit
        -- lower bound makes the range really 0 .. 2^32 - 1, i.e. uint32_t.
        n = random(0, lmt)
        local s = format('0x%x', n)
        m = tonumber(s)
    end
    assert(n == m) -- Done here to be even with Julia benchmark.
    return n, m
end

-- Time parseint under the label 'parse_int'. No check callback is passed;
-- parseint performs its own assert before returning.
timeit(parseint, 'parse_int')

--- Escape-time iteration count for the map z -> z*z + c, with c = initial z.
-- @param z starting value (also used as the constant c)
-- @return number of completed iterations before cabs(z) first exceeded 2,
--         or 80 if that never happened within 80 iterations
local function mandel(z)
    local maxiter = 80
    local c = z
    local done = 0            -- iterations completed so far
    while done < maxiter do
        if cabs(z) > 2 then
            return done
        end
        z = z*z + c
        done = done + 1
    end
    return maxiter
end
--- Fill a 26x21 matrix with mandel() counts over a regular grid.
-- Row r maps to the real part (r - 21)*0.1 and column c to the imaginary
-- part (c - 11)*0.1. Integer loop counters are scaled by hand because Lua's
-- `for i=l,u,c` does not match Julia's `for i=l:c:u`.
-- @return the filled matrix
local function mandelperf()
    local rows, cols = 26, 21
    local grid = mat(rows, cols)
    for row = 1, rows do
        local re = (row - 21)*0.1
        for col = 1, cols do
            local im = (col - 11)*0.1
            grid[{row, col}] = mandel(re + im*1i)
        end
    end
    return grid
end

-- Benchmark mandelperf; the sum over the returned matrix must be 14791.
timeit(mandelperf, 'mandel', function(grid)
    local total = sum(grid)
    assert(total == 14791)
end)

--- In-place quicksort of a[lo..hi] in ascending order.
-- The left partition is sorted by a recursive call; the right partition is
-- handled by the next pass of the outer loop instead of a second call.
-- @param a  indexable container of mutually comparable values
-- @param lo first index of the range to sort
-- @param hi last index of the range to sort
-- @return `a` itself, sorted over the requested range
local function qsort(a, lo, hi)
    local left, right = lo, hi
    while left < hi do
        -- Pivot is the element at the midpoint of the current range.
        local pivot = a[rshift(lo + hi, 1)]
        -- left < right holds on entry to every pass, so the partition step
        -- always runs at least once.
        repeat
            while a[left] < pivot do
                left = left + 1
            end
            while a[right] > pivot do
                right = right - 1
            end
            if left <= right then
                local tmp = a[left]
                a[left] = a[right]
                a[right] = tmp
                left = left + 1
                right = right - 1
            end
        until left > right
        if lo < right then
            qsort(a, lo, right)
        end
        -- Continue with the right partition in the same frame.
        lo = left
        right = hi
    end
    return a
end

--- Fill an array with 5000 samples from `rng` and sort it with qsort.
-- @return the sorted array; elements live at indices 1..5000
local function sortperf()
    local count = 5000
    -- One extra slot is allocated so indices 1..count can be used;
    -- index 0 is left untouched.
    local data = ffi.new('double[?]', count + 1)
    for k = 1, count do
        data[k] = rng:sample()
    end
    local sorted = qsort(data, 1, count)
    return sorted
end

-- Benchmark sortperf and verify the result is in non-decreasing order.
-- The bound 5000 mirrors the element count used inside sortperf.
timeit(sortperf, 'quicksort', function(sorted)
    local last = 5000
    for k = 1, last - 1 do
        assert(sorted[k] <= sorted[k+1])
    end
end)

--- Sum 1/k^2 for k = 1..10000, recomputed from scratch 500 times.
-- @return the partial sum produced by the final repetition
local function pisum()
    local total
    for _ = 1, 500 do
        local acc = 0
        for k = 1, 10000 do
            acc = acc + 1 / (k*k)
        end
        total = acc
    end
    return total
end

-- Benchmark pisum; the result must match the reference value to within 1e-12.
timeit(pisum, 'pi_sum', function(total)
    local err = abs(total - 1.644834071848065)
    assert(err < 1e-12)
end)

--- Create an r-by-c matrix whose elements are samples drawn from `rng`.
-- Elements are written through a single linear index running from 1 to the
-- length reported by `#`.
-- @param r first dimension passed to mat
-- @param c second dimension passed to mat
-- @return the filled matrix
local function rand(r, c)
    local out = mat(r, c)
    local count = #out
    for k = 1, count do
        out[k] = rng:sample()
    end
    return out
end

--- Create an r-by-c matrix whose elements are samples from dist.normal(0, 1).
-- @param r first dimension passed to mat
-- @param c second dimension passed to mat
-- @return the filled matrix
local function randn(r, c)
    local x = mat(r, c)
    -- The distribution object does not depend on the loop index, so build it
    -- once instead of once per element.
    local normal = dist.normal(0, 1)
    for i = 1, #x do
        x[i] = normal:sample(rng)
    end
    return x
end

--- Statistics over `t` trials built from random 5x5 normal matrices.
-- Each trial draws four matrices a, b, c, d with randn, combines them with
-- join in two layouts (P from `a..b..c..d`, Q from `a..b, c..d`) and stores
-- the trace of the `^^4` expression of each layout.
-- NOTE(review): the backtick, `**` and `^^` operators come from the array
-- syntax extension (presumably transpose, matrix product and matrix power)
-- -- confirm against its documentation.
-- @param t number of trials
-- @return sqrt(var)/mean of the P traces, then the same ratio for the Q traces
local function randmatstat(t)
    local n = 5
    local tracesP = vec(t)
    local tracesQ = vec(t)
    for k = 1, t do
        -- Draw order a, b, c, d matters for the random stream.
        local a = randn(n, n)
        local b = randn(n, n)
        local c = randn(n, n)
        local d = randn(n, n)
        local P = join(a..b..c..d)
        local Q = join(a..b, c..d)
        tracesP[k] = trace((P[]`**P[])^^4)
        tracesQ[k] = trace((Q[]`**Q[])^^4)
    end
    local ratioP = sqrt(var(tracesP))/mean(tracesP)
    local ratioQ = sqrt(var(tracesQ))/mean(tracesQ)
    return ratioP, ratioQ
end

-- Benchmark randmatstat over 1000 trials; both returned ratios must lie
-- strictly between 0.5 and 1.0.
timeit(
    function()
        return randmatstat(1000)
    end,
    'rand_mat_stat',
    function(s1, s2)
        assert(0.5 < s1 and s1 < 1.0)
        assert(0.5 < s2 and s2 < 1.0)
    end
)

--- Combine two n-by-n matrices from rand with the `**` operator.
-- NOTE(review): `**` is presumably the matrix product of the array syntax
-- extension -- confirm against its documentation.
-- @param n dimension of both square operands
-- @return the result of `lhs[]**rhs[]`
local function randmatmult(n)
    local lhs = rand(n, n)
    local rhs = rand(n, n)
    return lhs[]**rhs[]
end

-- Time randmatmult with n = 1000 under the label 'rand_mat_mul'; no check
-- callback is passed, so the result is not validated.
timeit(function() return randmatmult(1000) end, 'rand_mat_mul')

-- The printfd benchmark writes to /dev/null, so it is skipped when LuaJIT
-- reports the OS as 'Windows'.
if jit.os ~= 'Windows' then
    --- Write n formatted lines of the form "i i+1" to /dev/null.
    -- @param n number of lines to write
    local function printfd(n)
        -- io.open returns nil plus a message on failure; assert turns that
        -- into a descriptive error instead of a later nil-index failure.
        local f = assert(io.open('/dev/null', 'w'))
        for i = 1, n do
            f:write(format('%d %d\n', i, i+1))
        end
        f:close()
    end

    timeit(function() return printfd(100000) end, 'printfd')
end