import numpy as np

def generate_problem(n, m):
    """Generate a random instance of a Z_2 (sign) synchronization problem.

    Parameters
    ----------
    n : int
        Number of unknown signs.
    m : int
        Number of distinct index pairs (i, j), i < j, to observe.

    Returns
    -------
    xsol : ndarray of shape (n, 1)
        Ground-truth vector with entries in {-1, 1}.
    Omega : list of (int, int)
        The m observed pairs, each with i < j, without repetition.
    y : ndarray of shape (m, 1)
        Measurements y[k] = xsol[i_k] * xsol[j_k] for (i_k, j_k) = Omega[k].

    Raises
    ------
    ValueError
        If m exceeds the number n (n - 1) / 2 of available pairs; the
        rejection sampling below could otherwise never terminate.
    """
    max_pairs = n * (n - 1) // 2
    if m > max_pairs:
        raise ValueError(
            f"cannot choose m={m} distinct pairs among n={n} indices "
            f"(at most {max_pairs} exist)"
        )

    # Generate xsol with random entries in {-1,1}
    xsol = np.random.rand(n, 1)
    xsol = 2 * (xsol > 0.5) - 1

    # Choose m pairs by rejection sampling. The set mirrors the list so
    # that the duplicate check is O(1) while Omega keeps the draw order.
    Omega = []
    seen = set()
    while len(Omega) < m:
        i = np.random.randint(n)
        j = np.random.randint(n)
        if i < j and (i, j) not in seen:
            seen.add((i, j))
            Omega.append((i, j))

    # Measurements: product of the two signs of each observed pair.
    # The reshape keeps a well-formed (0, 2) index array when m == 0.
    pairs = np.array(Omega, dtype=int).reshape(-1, 2)
    y = (xsol[pairs[:, 0], 0] * xsol[pairs[:, 1], 0]).astype(float)
    y = y.reshape(-1, 1)

    return xsol, Omega, y


def cost_matrix(n, y, Omega):
    """Assemble the symmetric n x n matrix holding the measurements.

    Parameters
    ----------
    n : int
        Size of the matrix.
    y : array_like
        One measurement per pair, given either as a flat sequence or as a
        column vector of shape (m, 1).
    Omega : sequence of pairs
        Omega[k] = (i, j) is the index pair that y[k] refers to.

    Returns
    -------
    Y : ndarray of shape (n, n)
        Matrix with Y[i, j] = Y[j, i] = y[k] for every pair, zero elsewhere.
    """
    Y = np.zeros((n, n))

    # Flatten once so that each assignment below stores a true scalar;
    # assigning the size-1 row y[k] of a column vector relies on an implicit
    # array-to-scalar conversion that recent NumPy versions deprecate.
    values = np.asarray(y).ravel()

    for k, pair in enumerate(Omega):
        i, j = pair[0], pair[1]
        Y[i, j] = values[k]
        Y[j, i] = values[k]

    return Y


def cost(Y, sig, U):
    """Evaluate the penalized Burer-Monteiro cost function.

    Computes -<Y, U U^T> + (sig/2) * sum_i (||U[i,:]||^2 - 1)^2.

    Parameters
    ----------
    Y : ndarray of shape (n, n)
        Cost matrix.
    sig : float
        Weight of the penalty on the squared row norms.
    U : ndarray of shape (n, p)
        Factor at which the cost is evaluated.

    Returns
    -------
    float
        Value of the cost at U.
    """
    # <Y, U U^T> = sum((Y U) * U), which avoids forming the n x n product.
    linear_part = -np.sum((Y @ U) * U)

    # Squared row norms computed directly, without a sqrt/square round trip.
    diff_norms = np.sum(U**2, axis=1) - 1

    return linear_part + (sig / 2) * np.sum(diff_norms**2)


def U_to_sol(U):
    """Round a Burer-Monteiro factor to a candidate sign vector.

    The entries of the leading left singular vector of U are mapped to +1
    when strictly positive and to -1 otherwise. Singular vectors are only
    defined up to sign, so the result is determined up to a global sign.

    Parameters
    ----------
    U : ndarray of shape (n, p)
        Solution of the Burer-Monteiro problem.

    Returns
    -------
    x : ndarray of shape (n, 1)
        Vector with entries in {-1, 1}.
    """
    # Only the first left singular vector is used, so the reduced SVD is
    # enough; the full one would build an n x n matrix for nothing.
    left_vecs, _, _ = np.linalg.svd(U, full_matrices=False)
    x = 2 * (left_vecs[:, 0] > 0) - 1

    return x.reshape(-1, 1)


def BM_GD(Y, sig, p, nb_its=100):
    """Minimize the penalized Burer-Monteiro cost by gradient descent.

    Attempts to minimize U -> -<Y, U U^T> + (sig/2) ||diag(U U^T) - 1||^2
    over R^(n x p), starting from a random Gaussian point and using a
    backtracking line search on the step size.

    Parameters
    ----------
    Y : ndarray of shape (n, n)
        Cost matrix.
    sig : float
        Weight of the penalty on the squared row norms.
    p : int
        Number of columns of the factor U.
    nb_its : int, optional
        Maximal number of gradient steps.

    Returns
    -------
    U : ndarray of shape (n, p)
        Last iterate of the descent.
    """
    n = np.shape(Y)[0]

    # Random initialization
    U = np.random.randn(n, p) / np.sqrt(p)
    current_cost = cost(Y, sig, U)

    for k_it in range(nb_its):

        # Compute gradient
        # 2 (-YU + sig Diag(diag(UU^T) - 1)U)
        grad = - 2 * Y @ U
        diff_norms = np.sum(U**2, axis=1) - 1
        grad = grad + 2 * sig * diff_norms.reshape((n, 1)) * U
        grad_norm = np.linalg.norm(grad)

        # Stationary point: no step can change U. Stopping here also avoids
        # a division by zero in the initial step size below.
        if grad_norm == 0:
            break

        if k_it == 0:
            step = 0.1 * np.linalg.norm(U) / grad_norm

        # Backtrack: halve the step until the sufficient-decrease test
        # passes (the 1e-8 slack absorbs rounding errors). The current cost
        # and the squared gradient norm do not depend on the trial step, so
        # they are computed once outside the loop.
        grad_norm_sq = grad_norm**2
        while True:

            U_new = U - step * grad
            new_cost = cost(Y, sig, U_new)
            if new_cost < current_cost - 0.2*step*grad_norm_sq + 1e-8:
                break
            step = step / 2

        U = U_new
        current_cost = new_cost
        # Let the step grow again so that it can recover after backtracking
        step = step * 1.1

    return U

# Draw a random instance: n unknown signs, m observed pairwise products
n = 20
m = 30
xsol, Omega, y = generate_problem(n, m)
Y = cost_matrix(n, y, Omega)

# Run gradient descent on the penalized Burer-Monteiro factorization
p = 2
sig = 10 * m / n
U = BM_GD(Y, sig, p)

# Round the factor to a sign vector, then fix the global sign so that the
# first entry agrees with xsol (x and -x are equally valid roundings)
x = U_to_sol(U)
x = (x[0] * xsol[0]) * x

# Report the two terms of the objective and whether recovery succeeded
print("Linear part of the cost: ", np.sum(-(Y @ U) * U))
print("Average distance between row norm and 1: ",
      np.abs(np.sqrt((U**2).sum(axis=1)) - 1).sum() / n)
# Both vectors have entries in {-1, 1}, so a single mismatch already gives
# a distance of 2; a norm below 1 therefore means exact equality.
print("xsol exactly recovered." if np.linalg.norm(x - xsol) < 1
      else "xsol not recovered.")

# Debug aid: print(x, xsol) to compare the two vectors entry by entry