from munkres import Munkres, DISALLOWED, UnsolvableMatrix
import munkres
import pytest

# One solver instance shared by every test below, both through _get_cost and
# through direct m.compute / m.pad_matrix calls.
m = Munkres()

def _get_cost(matrix):
    '''
    Solve ``matrix`` with the shared solver ``m`` and return the sum of the
    cells in the (row, column) assignment it reports.
    '''
    return sum(matrix[r][c] for r, c in m.compute(matrix))

def test_documented_example():
    '''
    The 3x3 matrix from the documented example has a minimum total cost of 12.
    '''
    example = [[5, 9, 1],
               [10, 3, 2],
               [8, 7, 4]]
    assert _get_cost(example) == 12

def test_float_example():
    '''
    Test a matrix with float values.

    The minimum cost is 13.5, reached by two tied assignments:
    1.3 + 3.5 + 8.7 and 5.1 + 3.5 + 4.9.
    '''
    matrix = [[5.1, 9.2, 1.3],
              [10.4, 3.5, 2.6],
              [8.7, 7.8, 4.9]]
    cost = _get_cost(matrix)
    # Float sums are not exact, so compare approximately.
    assert cost == pytest.approx(13.5)


# Former name, kept so anything referring to it still resolves. Without the
# ``test_`` prefix pytest does not collect it, so the test runs only once.
float_example = test_float_example

def test_5_x_5():
    '''Solve a 5x5 integer cost matrix; the minimum total cost is 51.'''
    grid = [[12, 9, 27, 10, 23],
            [7, 13, 13, 30, 19],
            [25, 18, 26, 11, 26],
            [9, 28, 26, 23, 13],
            [16, 16, 24, 6, 9]]
    assert _get_cost(grid) == 51

def test_5_x_5_float():
    '''
    Float variant of test_5_x_5.

    Entry (i, j) is the integer entry plus (5 * i + j + 1) * 0.01. Any
    complete assignment therefore picks up 0.65 of offsets: 51 + 0.65.
    '''
    grid = [[12.01, 9.02, 27.03, 10.04, 23.05],
            [7.06, 13.07, 13.08, 30.09, 19.1],
            [25.11, 18.12, 26.13, 11.14, 26.15],
            [9.16, 28.17, 26.18, 23.19, 13.2],
            [16.21, 16.22, 24.23, 6.24, 9.25]]
    assert _get_cost(grid) == pytest.approx(51.65)


def test_10_x_10():
    '''
    Solve a 10x10 integer cost matrix; the minimum total cost is 66.
    '''
    matrix = [[37, 34, 29, 26, 19, 8, 9, 23, 19, 29],
              [9, 28, 20, 8, 18, 20, 14, 33, 23, 14],
              [15, 26, 12, 28, 6, 17, 9, 13, 21, 7],
              [2, 8, 38, 36, 39, 5, 36, 2, 38, 27],
              [30, 3, 33, 16, 21, 39, 7, 23, 28, 36],
              [7, 5, 19, 22, 36, 36, 24, 19, 30, 2],
              [34, 20, 13, 36, 12, 33, 9, 10, 23, 5],
              [7, 37, 22, 39, 33, 39, 10, 3, 13, 26],
              [21, 25, 23, 39, 31, 37, 32, 33, 38, 1],
              [17, 34, 40, 10, 29, 37, 40, 3, 25, 3]]
    cost = _get_cost(matrix)
    assert cost == 66

def test_10_x_10_float():
    '''
    Float variant of test_10_x_10.

    Entry (i, j) is the integer entry plus (10 * i + j + 1) * 0.001. A
    complete assignment uses every row and every column exactly once, so the
    offsets always add up to 0.001 * (10 * 45 + 55) = 0.505. The expected
    cost is therefore 66 + 0.505.
    '''
    matrix = [[37.001, 34.002, 29.003, 26.004, 19.005, 8.006, 9.007, 23.008, 19.009, 29.01],
              [9.011, 28.012, 20.013, 8.014, 18.015, 20.016, 14.017, 33.018, 23.019, 14.02],
              [15.021, 26.022, 12.023, 28.024, 6.025, 17.026, 9.027, 13.028, 21.029, 7.03],
              [2.031, 8.032, 38.033, 36.034, 39.035, 5.036, 36.037, 2.038, 38.039, 27.04],
              [30.041, 3.042, 33.043, 16.044, 21.045, 39.046, 7.047, 23.048, 28.049, 36.05],
              [7.051, 5.052, 19.053, 22.054, 36.055, 36.056, 24.057, 19.058, 30.059, 2.06],
              [34.061, 20.062, 13.063, 36.064, 12.065, 33.066, 9.067, 10.068, 23.069, 5.07],
              [7.071, 37.072, 22.073, 39.074, 33.075, 39.076, 10.077, 3.078, 13.079, 26.08],
              [21.081, 25.082, 23.083, 39.084, 31.085, 37.086, 32.087, 33.088, 38.089, 1.09],
              [17.091, 34.092, 40.093, 10.094, 29.095, 37.096, 40.097, 3.098, 25.099, 3.1]]
    cost = _get_cost(matrix)
    assert cost == pytest.approx(66.505)

def test_20_x_20():
    '''
    Solve a 20x20 integer cost matrix; the minimum total cost is 22.
    '''
    matrix = [[5, 4, 3, 9, 8, 9, 3, 5, 6, 9, 4, 10, 3, 5, 6, 6, 1, 8, 10, 2],
              [10, 9, 9, 2, 8, 3, 9, 9, 10, 1, 7, 10, 8, 4, 2, 1, 4, 8, 4, 8],
              [10, 4, 4, 3, 1, 3, 5, 10, 6, 8, 6, 8, 4, 10, 7, 2, 4, 5, 1, 8],
              [2, 1, 4, 2, 3, 9, 3, 4, 7, 3, 4, 1, 3, 2, 9, 8, 6, 5, 7, 8],
              [3, 4, 4, 1, 4, 10, 1, 2, 6, 4, 5, 10, 2, 2, 3, 9, 10, 9, 9, 10],
              [1, 10, 1, 8, 1, 3, 1, 7, 1, 1, 2, 1, 2, 6, 3, 3, 4, 4, 8, 6],
              [1, 8, 7, 10, 10, 3, 4, 6, 1, 6, 6, 4, 9, 6, 9, 6, 4, 5, 4, 7],
              [8, 10, 3, 9, 4, 9, 3, 3, 4, 6, 4, 2, 6, 7, 7, 4, 4, 3, 4, 7],
              [1, 3, 8, 2, 6, 9, 2, 7, 4, 8, 10, 8, 10, 5, 1, 3, 10, 10, 2, 9],
              [2, 4, 1, 9, 2, 9, 7, 8, 2, 1, 4, 10, 5, 2, 7, 6, 5, 7, 2, 6],
              [4, 5, 1, 4, 2, 3, 3, 4, 1, 8, 8, 2, 6, 9, 5, 9, 6, 3, 9, 3],
              [3, 1, 1, 8, 6, 8, 8, 7, 9, 3, 2, 1, 8, 2, 4, 7, 3, 1, 2, 4],
              [5, 9, 8, 6, 10, 4, 10, 3, 4, 10, 10, 10, 1, 7, 8, 8, 7, 7, 8, 8],
              [1, 4, 6, 1, 6, 1, 2, 10, 5, 10, 2, 6, 2, 4, 5, 5, 3, 5, 1, 5],
              [5, 6, 9, 10, 6, 6, 10, 6, 4, 1, 5, 3, 9, 5, 2, 10, 9, 9, 5, 1],
              [10, 9, 4, 6, 9, 5, 3, 7, 10, 1, 6, 8, 1, 1, 10, 9, 5, 7, 7, 5],
              [2, 6, 6, 6, 6, 2, 9, 4, 7, 5, 3, 2, 10, 3, 4, 5, 10, 9, 1, 7],
              [5, 2, 4, 9, 8, 4, 8, 2, 4, 1, 3, 7, 6, 8, 1, 6, 8, 8, 10, 10],
              [9, 6, 3, 1, 8, 5, 7, 8, 7, 2, 1, 8, 2, 8, 3, 7, 4, 8, 7, 7],
              [8, 4, 4, 9, 7, 10, 6, 2, 1, 5, 8, 5, 1, 1, 1, 9, 1, 3, 5, 3]]
    cost = _get_cost(matrix)
    assert cost == 22

def test_20_x_20_float():
    '''
    Float variant of test_20_x_20.

    Entry (i, j) is the integer entry plus (20 * i + j + 1) * 0.0001. A
    complete assignment uses every row and every column exactly once, so the
    offsets always add up to 0.0001 * (20 * 190 + 210) = 0.401. The optimum
    is therefore the integer optimum (22) plus 0.401.

    Row 15 previously held 21 values: a bare ``8`` stood where ``8.0312``
    belonged and shifted every later offset. That made the matrix irregular,
    so only 19 cells were summed.
    '''
    matrix = [[5.0001, 4.0002, 3.0003, 9.0004, 8.0005, 9.0006, 3.0007, 5.0008, 6.0009, 9.001, 4.0011, 10.0012, 3.0013, 5.0014, 6.0015, 6.0016, 1.0017, 8.0018, 10.0019, 2.002],
              [10.0021, 9.0022, 9.0023, 2.0024, 8.0025, 3.0026, 9.0027, 9.0028, 10.0029, 1.003, 7.0031, 10.0032, 8.0033, 4.0034, 2.0035, 1.0036, 4.0037, 8.0038, 4.0039, 8.004],
              [10.0041, 4.0042, 4.0043, 3.0044, 1.0045, 3.0046, 5.0047, 10.0048, 6.0049, 8.005, 6.0051, 8.0052, 4.0053, 10.0054, 7.0055, 2.0056, 4.0057, 5.0058, 1.0059, 8.006],
              [2.0061, 1.0062, 4.0063, 2.0064, 3.0065, 9.0066, 3.0067, 4.0068, 7.0069, 3.007, 4.0071, 1.0072, 3.0073, 2.0074, 9.0075, 8.0076, 6.0077, 5.0078, 7.0079, 8.008],
              [3.0081, 4.0082, 4.0083, 1.0084, 4.0085, 10.0086, 1.0087, 2.0088, 6.0089, 4.009, 5.0091, 10.0092, 2.0093, 2.0094, 3.0095, 9.0096, 10.0097, 9.0098, 9.0099, 10.01],
              [1.0101, 10.0102, 1.0103, 8.0104, 1.0105, 3.0106, 1.0107, 7.0108, 1.0109, 1.011, 2.0111, 1.0112, 2.0113, 6.0114, 3.0115, 3.0116, 4.0117, 4.0118, 8.0119, 6.012],
              [1.0121, 8.0122, 7.0123, 10.0124, 10.0125, 3.0126, 4.0127, 6.0128, 1.0129, 6.013, 6.0131, 4.0132, 9.0133, 6.0134, 9.0135, 6.0136, 4.0137, 5.0138, 4.0139, 7.014],
              [8.0141, 10.0142, 3.0143, 9.0144, 4.0145, 9.0146, 3.0147, 3.0148, 4.0149, 6.015, 4.0151, 2.0152, 6.0153, 7.0154, 7.0155, 4.0156, 4.0157, 3.0158, 4.0159, 7.016],
              [1.0161, 3.0162, 8.0163, 2.0164, 6.0165, 9.0166, 2.0167, 7.0168, 4.0169, 8.017, 10.0171, 8.0172, 10.0173, 5.0174, 1.0175, 3.0176, 10.0177, 10.0178, 2.0179, 9.018],
              [2.0181, 4.0182, 1.0183, 9.0184, 2.0185, 9.0186, 7.0187, 8.0188, 2.0189, 1.019, 4.0191, 10.0192, 5.0193, 2.0194, 7.0195, 6.0196, 5.0197, 7.0198, 2.0199, 6.02],
              [4.0201, 5.0202, 1.0203, 4.0204, 2.0205, 3.0206, 3.0207, 4.0208, 1.0209, 8.021, 8.0211, 2.0212, 6.0213, 9.0214, 5.0215, 9.0216, 6.0217, 3.0218, 9.0219, 3.022],
              [3.0221, 1.0222, 1.0223, 8.0224, 6.0225, 8.0226, 8.0227, 7.0228, 9.0229, 3.023, 2.0231, 1.0232, 8.0233, 2.0234, 4.0235, 7.0236, 3.0237, 1.0238, 2.0239, 4.024],
              [5.0241, 9.0242, 8.0243, 6.0244, 10.0245, 4.0246, 10.0247, 3.0248, 4.0249, 10.025, 10.0251, 10.0252, 1.0253, 7.0254, 8.0255, 8.0256, 7.0257, 7.0258, 8.0259, 8.026],
              [1.0261, 4.0262, 6.0263, 1.0264, 6.0265, 1.0266, 2.0267, 10.0268, 5.0269, 10.027, 2.0271, 6.0272, 2.0273, 4.0274, 5.0275, 5.0276, 3.0277, 5.0278, 1.0279, 5.028],
              [5.0281, 6.0282, 9.0283, 10.0284, 6.0285, 6.0286, 10.0287, 6.0288, 4.0289, 1.029, 5.0291, 3.0292, 9.0293, 5.0294, 2.0295, 10.0296, 9.0297, 9.0298, 5.0299, 1.03],
              [10.0301, 9.0302, 4.0303, 6.0304, 9.0305, 5.0306, 3.0307, 7.0308, 10.0309, 1.031, 6.0311, 8.0312, 1.0313, 1.0314, 10.0315, 9.0316, 5.0317, 7.0318, 7.0319, 5.032],
              [2.0321, 6.0322, 6.0323, 6.0324, 6.0325, 2.0326, 9.0327, 4.0328, 7.0329, 5.033, 3.0331, 2.0332, 10.0333, 3.0334, 4.0335, 5.0336, 10.0337, 9.0338, 1.0339, 7.034],
              [5.0341, 2.0342, 4.0343, 9.0344, 8.0345, 4.0346, 8.0347, 2.0348, 4.0349, 1.035, 3.0351, 7.0352, 6.0353, 8.0354, 1.0355, 6.0356, 8.0357, 8.0358, 10.0359, 10.036],
              [9.0361, 6.0362, 3.0363, 1.0364, 8.0365, 5.0366, 7.0367, 8.0368, 7.0369, 2.037, 1.0371, 8.0372, 2.0373, 8.0374, 3.0375, 7.0376, 4.0377, 8.0378, 7.0379, 7.038],
              [8.0381, 4.0382, 4.0383, 9.0384, 7.0385, 10.0386, 6.0387, 2.0388, 1.0389, 5.039, 8.0391, 5.0392, 1.0393, 1.0394, 1.0395, 9.0396, 1.0397, 3.0398, 5.0399, 3.04]]
    cost = _get_cost(matrix)
    # Any two distinct assignments differ by a whole number, so the default
    # approx tolerance is ample; no loosened relative tolerance is needed.
    assert cost == pytest.approx(22.401)

def test_disallowed():
    '''
    The matrix contains DISALLOWED cells; the cheapest assignment that
    avoids them costs 19.
    '''
    blocked = [[5, 9, DISALLOWED],
               [10, DISALLOWED, 2],
               [8, DISALLOWED, 4]]
    assert _get_cost(blocked) == 19

def test_disallowed_float():
    '''
    Float variant of test_disallowed; the cheapest allowed assignment is
    9.2 + 2.4 + 8.5.
    '''
    blocked = [[5.1, 9.2, DISALLOWED],
               [10.3, DISALLOWED, 2.4],
               [8.5, DISALLOWED, 4.6]]
    assert _get_cost(blocked) == pytest.approx(20.1)

def test_profit():
    '''
    Maximise total profit by turning each profit into a cost
    (sys.maxsize - profit) and minimising that instead. Python ints are
    exact, so the large offset loses no information here.
    '''
    import sys

    profits = [[94, 66, 100, 18, 48],
               [51, 63, 97, 79, 11],
               [37, 53, 57, 78, 28],
               [59, 43, 97, 88, 48],
               [52, 19, 89, 60, 60]]

    def to_cost(profit):
        return sys.maxsize - profit

    chosen = m.compute(munkres.make_cost_matrix(profits, to_cost))
    assert sum(profits[r][c] for r, c in chosen) == 392

def test_profit_float():
    '''
    Maximise total profit on a float profit matrix.

    Profits are inverted around the matrix maximum, not sys.maxsize. For
    floats, sys.maxsize - 94.01, sys.maxsize - 18.04, etc. all round to the
    same double (2**63), because neighbouring doubles there are at least 1024
    apart. That made every cost identical and let an arbitrary assignment
    "win".

    Entry (i, j) is the test_profit entry plus (5 * i + j + 1) * 0.01. Any
    complete assignment therefore carries 0.65 of offsets, and the optimum
    is 392 + 0.65.
    '''
    profit_matrix = [[94.01, 66.02, 100.03, 18.04, 48.05],
                     [51.06, 63.07, 97.08, 79.09, 11.1],
                     [37.11, 53.12, 57.13, 78.14, 28.15],
                     [59.16, 43.17, 97.18, 88.19, 48.2],
                     [52.21, 19.22, 89.23, 60.24, 60.25]]
    # A small base keeps the inverted costs well within float precision.
    maximum = max(max(row) for row in profit_matrix)
    cost_matrix = munkres.make_cost_matrix(
        profit_matrix, lambda profit: maximum - profit
    )
    indices = m.compute(cost_matrix)
    profit = sum(profit_matrix[row][column] for row, column in indices)
    assert profit == pytest.approx(392.65)

def test_irregular():
    '''
    Rows of unequal length are solved; the expected total cost is 43.
    '''
    ragged = [[12, 26, 17],
              [49, 43, 36, 10, 5],
              [97, 9, 66, 34],
              [52, 42, 19, 36],
              [15, 93, 55, 80]]
    assert _get_cost(ragged) == 43

def test_irregular_float():
    '''
    Float variant of test_irregular; the expected cost is
    15.17 + 9.1 + 19.15.
    '''
    ragged = [[12.01, 26.02, 17.03],
              [49.04, 43.05, 36.06, 10.07, 5.08],
              [97.09, 9.1, 66.11, 34.12],
              [52.13, 42.14, 19.15, 36.16],
              [15.17, 93.18, 55.19, 80.2]]
    assert _get_cost(ragged) == pytest.approx(43.42)

def test_rectangular():
    '''
    A 5x4 matrix gives the same cost whether it is solved directly or
    zero-padded to square first; that cost is 70.
    '''
    tall = [[34, 26, 17, 12],
            [43, 43, 36, 10],
            [97, 47, 66, 34],
            [52, 42, 19, 36],
            [15, 93, 55, 80]]

    via_padding = _get_cost(m.pad_matrix(tall, 0))
    direct = _get_cost(tall)
    assert via_padding == direct
    assert direct == 70

def test_rectangular_float():
    '''
    Float variant of test_rectangular: solving directly and after zero
    padding must agree, at 70.42.
    '''
    tall = [[34.01, 26.02, 17.03, 12.04],
            [43.05, 43.06, 36.07, 10.08],
            [97.09, 47.1, 66.11, 34.12],
            [52.13, 42.14, 19.15, 36.16],
            [15.17, 93.18, 55.19, 80.2]]

    via_padding = _get_cost(m.pad_matrix(tall, 0))
    direct = _get_cost(tall)
    assert via_padding == pytest.approx(direct)
    assert direct == pytest.approx(70.42)

def test_unsolvable():
    '''
    The last row is entirely DISALLOWED, so compute() is expected to raise
    UnsolvableMatrix.
    '''
    impossible = [[5, 9, DISALLOWED],
                  [10, DISALLOWED, 2],
                  [DISALLOWED, DISALLOWED, DISALLOWED]]
    # Only the call under test sits inside the raises-context.
    with pytest.raises(UnsolvableMatrix):
        m.compute(impossible)

def test_unsolvable_float():
    '''
    Float variant of test_unsolvable: a fully DISALLOWED row is expected to
    make compute() raise UnsolvableMatrix.
    '''
    impossible = [[5.1, 9.2, DISALLOWED],
                  [10.3, DISALLOWED, 2.4],
                  [DISALLOWED, DISALLOWED, DISALLOWED]]
    # Only the call under test sits inside the raises-context.
    with pytest.raises(UnsolvableMatrix):
        m.compute(impossible)
