# Objective: apply the most simple optimization algorithm to the
# Burer-Monteiro factorization of toy problems, and see which type of
# points it returns (critical, optimal ...).

# The toy problems are of the form
#     Reconstruct X of size nxn such that
#                 X is sdp
#                 A(X) = y
# where A : R^nxn -> R^m is linear,
# with the a priori knowledge that there exists a rank 1 solution

# The Burer-Monteiro factorization (with factorization rank 1) is :
#     Find U such that
#          A(UU^T) = y

# Algorithm: gradient descent over
#     U -> (1/4) ||A(UU^T) - y||^2

using LinearAlgebra

"""
    evalA!(y, A, U)

Evaluate the linear operator on the factorized matrix `U * U'`, in place.

For every index `k` along the third dimension of `A`, the value
`sum((A[:, :, k] * U) .* U)`, i.e. the inner product of `A[:, :, k]` with `U * U'`,
is written to `y[k]`.

# Arguments
- `y`: output vector, overwritten; it is indexed by every `k` in `axes(A, 3)`.
- `A`: three-dimensional array whose slices `A[:, :, k]` are the measurement matrices.
- `U`: factor whose number of rows matches the first two dimensions of `A`.

# Returns
The mutated vector `y`.
"""
function evalA!(y, A, U)
    for k in axes(A, 3)
        # A view avoids allocating a fresh copy of the slice on every call;
        # this routine sits inside the line search of the descent loop.
        Ak = view(A, :, :, k)
        y[k] = sum((Ak * U) .* U)
    end
    return y
end

"""
    generate_problem(n, m)

Draw a random instance of the toy problem.

# Returns
A tuple `(A, Usol, y)` where
- `A` is an `n × n × m` array; each slice `A[:, :, k]` is a Gaussian matrix `G`
  replaced by `G + G'`, hence symmetric;
- `Usol` is an `n × 1` Gaussian matrix divided by `sqrt(n)`, the planted solution;
- `y` is the length-`m` vector filled by `evalA!(y, A, Usol)`.
"""
function generate_problem(n, m)
    # Measurement matrices: Gaussian slices, each symmetrized as G + G'
    A = randn(n, n, m)
    for k in 1:m
        G = A[:, :, k]
        A[:, :, k] = G + G'
    end

    # Planted rank-one factor
    Usol = randn(n, 1)
    Usol ./= sqrt(n)

    # Measurements of the planted solution
    y = zeros(m)
    evalA!(y, A, Usol)

    return A, Usol, y
end

"""
    gradient_descent(A, y; nb_its=1000, U_init=nothing)

Run gradient descent with a backtracking line search on
`U -> (1/4) * norm(A(U * U') - y)^2`, where `A(U * U')` is evaluated by `evalA!`,
and return the last iterate.

The descent direction is `(sum_k r[k] * A[:, :, k]) * U` with `r = A(U * U') - y`.
This equals the gradient of the objective when every slice `A[:, :, k]` is symmetric.

# Arguments
- `A`: three-dimensional array of measurement matrices (`n × n × m`).
- `y`: vector of the `m` target measurements.

# Keywords
- `nb_its`: number of descent iterations.
- `U_init`: starting point. When it is `nothing` or empty, a random `n × 1` Gaussian
  point divided by `sqrt(n)` is used. The given array is not modified.

# Returns
The final iterate, with the same shape as the starting point.

If the line search shrinks the step all the way to zero without finding an acceptable
point (non-finite values, or no representable decrease), a warning is emitted and the
current iterate is returned instead of looping forever.
"""
function gradient_descent(A, y; nb_its=1000, U_init=nothing)
    n = size(A, 1)
    m = size(A, 3)

    # Initialization
    if U_init === nothing || isempty(U_init)
        U = randn(n, 1) / sqrt(n)
    else
        # Work on a floating-point copy so the caller's array is left untouched
        U = float.(U_init)
    end
    step = 1.0  # floating point from the start, so its type never changes

    # Auxiliary storage (every buffer is fully overwritten before being read)
    residual = zeros(m)
    residual_new = zeros(m)
    weighted = zeros(n, n)
    grad = similar(U)
    U_new = similar(U)

    for _ in 1:nb_its
        # Gradient: (sum_k residual[k] * A_k) * U
        evalA!(residual, A, U)
        residual .-= y
        fill!(weighted, 0)
        for k in 1:m
            weighted .+= residual[k] .* view(A, :, :, k)
        end
        mul!(grad, weighted, U)

        # Quantities that do not change during the line search
        f_cur = (1/4) * norm(residual)^2
        grad_sq = norm(grad)^2

        # Backtracking: halve the step until the sufficient-decrease test passes
        accepted = false
        while step > 0
            U_new .= U .- step .* grad
            evalA!(residual_new, A, U_new)
            residual_new .-= y
            f_new = (1/4) * norm(residual_new)^2
            if f_new < f_cur - 0.2 * step * grad_sq + 1e-8
                accepted = true
                break
            end
            step /= 2
        end

        if !accepted
            # The step underflowed to zero: no further progress is possible
            @warn "gradient_descent: line search failed, returning the current iterate"
            break
        end

        # Be slightly more optimistic for the next iteration
        step *= 1.1

        # Update U
        U .= U_new
    end

    return U
end

"""
    critical_type(U, Usol, A, y)

Classify the point `U` for the objective `U -> (1/4) * norm(A(U * U') - y)^2`.

# Returns
One of the following strings:
- `"Global Optimum"`: `U` is within a relative distance `1e-3` of `Usol` or `-Usol`;
- `"Second-order"`: the gradient norm is below `1e-3 * sqrt(m)` and the Hessian plus
  `1e-3 * m * I` is positive definite;
- `"First-order"`: the gradient norm is below `1e-3 * sqrt(m)` but the Hessian test
  fails;
- `"Not critical"`: otherwise.

# Notes
The gradient `(sum_k r[k] * A_k) * U` and the Hessian
`sum_k r[k] * A_k + 2 * sum_k (A_k * U) * (A_k * U)'`, with `r = A(U * U') - y`, are
those of the objective above when every slice `A_k = A[:, :, k]` is symmetric.
NOTE(review): the `n × n` Hessian formula is written for a single-column `U`; confirm
before using this function with a factor of higher rank.
"""
function critical_type(U, Usol, A, y)
    # The factor is only determined up to sign
    if (norm(U - Usol) < 1e-3 * norm(Usol)) || (norm(U + Usol) < 1e-3 * norm(Usol))
        return "Global Optimum"
    end

    n = size(A, 1)
    m = size(A, 3)

    # Residual r = A(U U') - y
    residual = zeros(m)
    evalA!(residual, A, U)
    residual .-= y

    # hess starts as sum_k r[k] * A_k, which also yields the gradient
    hess = zeros(n, n)
    for k in 1:m
        hess .+= residual[k] .* view(A, :, :, k)
    end
    grad = hess * U

    # Written with `<` so that a NaN gradient is reported as "Not critical"
    norm(grad) < 1e-3 * sqrt(m) || return "Not critical"

    # At least first-order critical; complete the Hessian with the term
    # (1/2) * sum_k grad(r_k) * grad(r_k)', where grad(r_k) = 2 * A_k * U.
    for k in 1:m
        AkU = view(A, :, :, k) * U
        hess .+= 2 .* (AkU * AkU')
    end
    # Remove rounding asymmetry so that isposdef sees an exactly symmetric matrix
    hess = (hess + hess') / 2

    if isposdef(hess + 1e-3 * m * I)
        return "Second-order"
    else
        return "First-order"
    end
end

# Experiment grid: problem sizes `ns`, numbers of measurements `ms`, and the
# number of random instances drawn for every (n, m) pair.
ns = 5:3:35
ms = 5:5:80
nbtest = 10

# results[km, kn, c] counts the runs with m = ms[km] and n = ns[kn] that ended in
# category c: 1 = not critical, 2 = first-order, 3 = second-order,
# 4 = global optimum.
results = zeros(length(ms), length(ns), 4)

for (kn, n) in enumerate(ns)
    for (km, m) in enumerate(ms)
        # Pairs with n > m are skipped; their counters stay at zero
        n <= m || continue
        for _ in 1:nbtest
            A, Usol, y = generate_problem(n, m)
            U = gradient_descent(A, y)
            ctype = critical_type(U, Usol, A, y)
            # The labels must match the strings returned by critical_type exactly
            if ctype == "Not critical"
                results[km, kn, 1] += 1
            elseif ctype == "First-order"
                results[km, kn, 2] += 1
            elseif ctype == "Second-order"
                results[km, kn, 3] += 1
            elseif ctype == "Global Optimum"
                results[km, kn, 4] += 1
            else
                # Fail loudly instead of counting an unknown label as a success
                error("unexpected label returned by critical_type: $ctype")
            end
        end
    end
end

# One heat map per slice of `results` along its third index, with the values of
# `ns` on the horizontal axis and the values of `ms` on the vertical axis; the
# figure is then written to disk.
# NOTE(review): `plt` and `savefig` are not defined or imported in the visible part
# of this file; presumably PyPlot is loaded elsewhere. Confirm before running.
fig, ax = plt.subplots(1, 4)
for panel in 1:4
    counts = results[:, :, panel]
    ax[panel].imshow(counts; vmin=0, vmax=nbtest, origin="lower",
                     extent=[minimum(ns), maximum(ns), minimum(ms), maximum(ms)])
end
savefig("points_critiques_output.png")
