legendre_decomp.module_mba
==========================

.. py:module:: legendre_decomp.module_mba

.. autoapi-nested-parse::

   CUDA-enabled Legendre decomposition calculations.



Functions
---------

.. autoapisummary::

   legendre_decomp.module_mba.xp_get
   legendre_decomp.module_mba.kl
   legendre_decomp.module_mba.get_eta
   legendre_decomp.module_mba.get_h
   legendre_decomp.module_mba.get_q
   legendre_decomp.module_mba.get_slice
   legendre_decomp.module_mba.initialize_theta
   legendre_decomp.module_mba.compute_G
   legendre_decomp.module_mba.LD_MBA
   legendre_decomp.module_mba.get_weight
   legendre_decomp.module_mba.compute_nbody
   legendre_decomp.module_mba.recons_nbody


Module Contents
---------------

.. py:function:: xp_get(val)

.. py:function:: kl(P, Q, xp = cp)

   Kullback-Leibler divergence.

   :param P: P tensor.
   :param Q: Q tensor.
   :param xp: Array module, either numpy (CPU) or cupy (GPU).
   :type xp: ModuleType

   :returns: KL divergence.


.. py:function:: get_eta(Q, D, xp = cp)

   Eta tensor.

   :param Q: Q tensor.
   :param D: Dimensionality.
   :param xp: Array module, either numpy (CPU) or cupy (GPU).
   :type xp: ModuleType

   :returns: Eta tensor.


.. py:function:: get_h(theta, D, xp = cp)

   H tensor.

   :param theta: Theta tensor.
   :param D: Dimensionality.
   :param xp: Array module, either numpy (CPU) or cupy (GPU).
   :type xp: ModuleType

   :returns: H tensor.


.. py:function:: get_q(h, gpu=True, xp = cp)

   Q tensor.

   :param h: H tensor.
   :param xp: Array module, either numpy (CPU) or cupy (GPU).
   :type xp: ModuleType

   :returns: Updated Q.


.. py:function:: get_slice(key, D)

.. py:function:: initialize_theta(keys, shape, theta0_flag=False, xp = cp)

.. py:function:: compute_G(eta, mask, xp = cp)

.. py:function:: LD_MBA(X, I = None, order = 2, n_iter = 100, lr = 1.0, eps = 1e-05, error_tol = 1e-05, init_theta = None, init_theta_mask = None, ngd = True, ngd_lstsq=True, verbose = True, gpu = True, dtype = None)

   Compute a many-body tensor approximation.

   :param X: Input tensor.
   :param I: A list of index tuples that represent slices with nonzero elements in the parameter tensor, e.g. ``[(0,1),(2,),(1,3)]``.
   :param n_iter: Maximum number of iterations.
   :param lr: Learning rate.
   :param eps: (see paper).
   :param error_tol: Tolerance on the KL divergence for stopping the iteration.
   :param ngd: Use the natural gradient.
   :param verbose: Print debug messages.

   :returns: * **all_history_kl** -- KL divergence history.
             * **scaleX** -- Scaled X tensor.
             * **Q** -- Q tensor.
             * **theta** -- Theta tensor.


.. py:function:: get_weight(shape, I_x=None, order=2, xp = cp)

.. py:function:: compute_nbody(theta, shape, I_x=None, order=2, dtype=None, gpu=True, verbose=False)

.. py:function:: recons_nbody(X_out, D, rescale=True, dtype=None, gpu=True)
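
For reference, ``kl`` reports the Kullback-Leibler divergence :math:`D_{\mathrm{KL}}(P \,\|\, Q) = \sum_v P_v \log(P_v / Q_v)`. The sketch below computes this textbook definition for normalized tensors. The name ``kl_sketch`` is hypothetical, and whether ``kl`` uses exactly this form (rather than, for example, a generalized KL for unnormalized tensors) is an assumption, not something taken from the module's source.

.. code-block:: python

   import numpy as np


   def kl_sketch(P, Q, xp=np):
       """Textbook KL divergence sum_v P[v] * log(P[v] / Q[v]).

       Hypothetical helper, not the module's ``kl``. Assumes P and Q have the
       same shape, each sums to 1, and Q > 0 wherever P > 0. ``xp`` mirrors the
       module's convention: pass numpy (CPU) or cupy (GPU).
       """
       P = xp.asarray(P, dtype=xp.float64)
       Q = xp.asarray(Q, dtype=xp.float64)
       support = P > 0  # entries with P[v] == 0 contribute 0 (0 * log 0 := 0)
       return float(xp.sum(P[support] * xp.log(P[support] / Q[support])))


   P = np.array([[0.5, 0.0], [0.25, 0.25]])
   Q = np.full((2, 2), 0.25)
   print(kl_sketch(P, Q))  # 0.5 * log(2), about 0.3466

Restricting the sum to the support of P keeps zero entries of P from producing ``nan`` via ``0 * log(0)``.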
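
The ``get_h``, ``get_q`` and ``get_eta`` helpers work in the coordinate system of Legendre decomposition. A minimal NumPy sketch, assuming the standard relations :math:`\log Q_v = \sum_{u \le v} \theta_u - \psi` and :math:`\eta_v = \sum_{u \ge v} Q_u`, where :math:`u \le v` compares multi-indices elementwise and :math:`\psi` normalizes Q to sum to 1. The ``*_sketch`` names are hypothetical, and whether the module's functions match these formulas exactly (for example, how ``get_q`` normalizes) is an assumption.

.. code-block:: python

   import numpy as np


   def h_sketch(theta, D, xp=np):
       """H[v] = sum of theta[u] over all u <= v (cumulative sum along each of D axes)."""
       H = theta
       for axis in range(D):
           H = xp.cumsum(H, axis=axis)
       return H


   def q_sketch(H, xp=np):
       """Q = exp(H) / sum(exp(H)); the division plays the role of psi."""
       Q = xp.exp(H - H.max())  # shifting by the max avoids overflow and cancels on division
       return Q / Q.sum()


   def eta_sketch(Q, D, xp=np):
       """eta[v] = sum of Q[u] over all u >= v (reverse cumulative sum along each axis)."""
       E = Q
       for axis in range(D):
           E = xp.flip(xp.cumsum(xp.flip(E, axis=axis), axis=axis), axis=axis)
       return E


   theta = np.zeros((2, 3))
   theta[1, 1] = np.log(2.0)  # doubles every entry Q[v] with v >= (1, 1)
   Q = q_sketch(h_sketch(theta, D=2))
   eta = eta_sketch(Q, D=2)
   print(Q)                  # equals [[1, 1, 1], [1, 2, 2]] / 8
   print(float(eta[0, 0]))   # about 1.0: eta at the bottom index is the total mass

Because ``q_sketch`` normalizes explicitly, ``theta[0, 0]`` has no effect here; it plays the role of the normalizer :math:`-\psi`.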
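
The ``I`` argument of ``LD_MBA`` selects which entries of the parameter tensor may be nonzero. One plausible reading, taken from the many-body approximation literature rather than from this module's source, is that a key such as ``(0, 1)`` enables the entries whose index is at least 1 on axes 0 and 1 and exactly 0 on every other axis (a two-body interaction between modes 0 and 1), while ``(2,)`` enables a one-body term. The sketch below builds that boolean mask. The names ``slice_for_key`` and ``theta_mask`` are hypothetical, and whether ``get_slice`` behaves this way is an assumption.

.. code-block:: python

   import numpy as np


   def slice_for_key(key, D):
       """Index tuple: axes in `key` range over 1.., every other axis is pinned to 0."""
       return tuple(slice(1, None) if axis in key else 0 for axis in range(D))


   def theta_mask(I, shape):
       """Boolean mask of theta entries allowed to be nonzero under the interactions in I."""
       mask = np.zeros(shape, dtype=bool)
       for key in I:
           mask[slice_for_key(key, len(shape))] = True
       return mask


   mask = theta_mask([(0, 1), (2,), (1, 3)], shape=(3, 3, 3, 3))
   print(int(mask.sum()))  # 4 + 2 + 4 = 10 free parameters

Under this reading the number of free parameters grows with the product of the sizes of the interacting axes, not with the full tensor size.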
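
With ``ngd=True`` the update uses the Fisher information of theta. For the log-linear model behind Legendre decomposition this matrix has the closed form :math:`G_{uw} = \eta_{u \vee w} - \eta_u \eta_w`, where :math:`u \vee w` is the elementwise maximum of the two multi-indices, and the gradient of :math:`D_{\mathrm{KL}}(P \,\|\, Q)` with respect to :math:`\theta_u` is :math:`\eta^Q_u - \eta^P_u`. The sketch assumes that ``compute_G(eta, mask)`` returns this matrix over the masked entries and that the ``ngd_lstsq`` option solves :math:`G \delta = \nabla` by least squares; both are assumptions, and ``fisher_sketch`` and ``ngd_step_sketch`` are hypothetical names.

.. code-block:: python

   import numpy as np


   def fisher_sketch(eta, mask):
       """G[i, j] = eta[u_i v u_j] - eta[u_i] * eta[u_j] over the masked multi-indices u_i."""
       idx = np.argwhere(mask)                                # (n, D), same C order as eta[mask]
       e = eta[tuple(idx.T)]                                  # eta at each parameter index
       join = np.maximum(idx[:, None, :], idx[None, :, :])    # (n, n, D) elementwise max
       e_join = eta[tuple(np.moveaxis(join, -1, 0))]          # (n, n)
       return e_join - np.outer(e, e)


   def ngd_step_sketch(theta, mask, eta_Q, eta_P, lr=1.0):
       """One natural-gradient step on the masked entries of theta."""
       G = fisher_sketch(eta_Q, mask)
       grad = eta_Q[mask] - eta_P[mask]                       # dKL(P || Q) / dtheta
       delta, *_ = np.linalg.lstsq(G, grad, rcond=None)       # least-squares solve of G @ delta = grad
       new_theta = theta.copy()
       new_theta[mask] -= lr * delta
       return new_theta


   eta_uniform = np.array([[1.0, 0.5], [0.5, 0.25]])   # eta of a uniform 2x2 Q
   mask = np.array([[False, True], [True, True]])      # theta[0, 0] acts as the normalizer
   print(fisher_sketch(eta_uniform, mask))
   # rows: [0.25, 0, 0.125], [0, 0.25, 0.125], [0.125, 0.125, 0.1875]

A least-squares solve keeps the step defined even when G is singular, for instance when two masked parameters are redundant.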