# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Tests for the photometry module.
"""

from astropy.convolution.utils import discretize_model
from astropy.modeling import Fittable2DModel, Parameter
from astropy.modeling.fitting import LevMarLSQFitter, SimplexLSQFitter
from astropy.modeling.models import Gaussian2D, Moffat2D
from astropy.stats import SigmaClip, gaussian_sigma_to_fwhm
from astropy.table import Table
from astropy.tests.helper import catch_warnings
from astropy.utils.exceptions import AstropyUserWarning
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
import pytest

from ..groupstars import DAOGroup
from ..models import IntegratedGaussianPRF, FittableImageModel
from ..photometry import (BasicPSFPhotometry, DAOPhotPSFPhotometry,
                          IterativelySubtractedPSFPhotometry)
from ..sandbox import DiscretePRF
from ..utils import prepare_psf_model
from ...background import MMMBackground, StdBackgroundRMS
from ...datasets import make_gaussian_prf_sources_image, make_noise_image
from ...detection import DAOStarFinder

# scipy is an optional dependency: record whether it can be imported so
# that tests can be skipped when it is unavailable.
try:
    import scipy  # noqa
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True


def make_psf_photometry_objs(std=1, sigma_psf=1):
    """
    Produces baseline photometry objects which are then
    modified as-needed in specific tests below.

    Parameters
    ----------
    std : float, optional
        Noise level used to set the detection threshold (``5 * std``).
    sigma_psf : float, optional
        Gaussian sigma of the PSF.  It sets the finder FWHM, the
        grouping critical separation and the width of the PSF model.

    Returns
    -------
    tuple
        ``(basic_phot_obj, iter_phot_obj, dao_phot_obj)``, i.e. a
        `BasicPSFPhotometry`, an `IterativelySubtractedPSFPhotometry`
        and a `DAOPhotPSFPhotometry` instance.  The two iterative
        objects are created with ``niters=1`` and all three use an
        ``(11, 11)`` fit shape.
    """

    threshold = 5. * std
    fwhm = sigma_psf * gaussian_sigma_to_fwhm
    crit_separation = 1.5 * sigma_psf * gaussian_sigma_to_fwhm

    # the finder, grouper, background estimator, PSF model and fitter
    # are built once and shared by the basic and iterative objects
    daofind = DAOStarFinder(threshold=threshold, fwhm=fwhm)
    daogroup = DAOGroup(crit_separation)
    mode_bkg = MMMBackground()
    psf_model = IntegratedGaussianPRF(sigma=sigma_psf)
    fitter = LevMarLSQFitter()

    basic_phot_obj = BasicPSFPhotometry(finder=daofind,
                                        group_maker=daogroup,
                                        bkg_estimator=mode_bkg,
                                        psf_model=psf_model,
                                        fitter=fitter,
                                        fitshape=(11, 11))

    iter_phot_obj = IterativelySubtractedPSFPhotometry(finder=daofind,
                                                       group_maker=daogroup,
                                                       bkg_estimator=mode_bkg,
                                                       psf_model=psf_model,
                                                       fitter=fitter, niters=1,
                                                       fitshape=(11, 11))

    # DAOPhotPSFPhotometry receives the raw numbers instead of
    # pre-built finder/grouper objects
    dao_phot_obj = DAOPhotPSFPhotometry(crit_separation=crit_separation,
                                        threshold=threshold, fwhm=fwhm,
                                        psf_model=psf_model, fitshape=(11, 11),
                                        niters=1)

    return (basic_phot_obj, iter_phot_obj, dao_phot_obj)


# Gaussian sigma of the PSF for each source table below, in table order
sigma_psfs = [2, 2, 2]

_SOURCE_COLNAMES = ['flux', 'x_0', 'y_0', 'sigma', 'theta', 'id', 'group_id']

# A group of two overlapped stars and an isolated one
sources1 = Table([[800, 1000, 1200],
                  [13, 18, 25],
                  [16, 16, 25],
                  [sigma_psfs[0]] * 3,
                  [0] * 3,
                  [1, 2, 3],
                  [1, 1, 2]],
                 names=_SOURCE_COLNAMES)

# One single group with four stars
sources2 = Table([[700, 800, 700, 800],
                  [12, 17, 12, 17],
                  [15, 15, 20, 20],
                  [sigma_psfs[1]] * 4,
                  [0] * 4,
                  [1, 2, 3, 4],
                  [1, 1, 1, 1]],
                 names=_SOURCE_COLNAMES)

# One faint star and one brighter companion.  Although they are in the
# same group, the detection algorithm is not able to detect the fainter
# star, hence photometry should be performed with niters > 1 or
# niters=None.
sources3 = Table([[10000, 1000],
                  [18, 13],
                  [17, 19],
                  [sigma_psfs[2]] * 2,
                  [0] * 2,
                  [1] * 2,
                  [1] * 2,
                  [1, 2]],
                 names=_SOURCE_COLNAMES + ['iter_detected'])


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize("sigma_psf, sources", [(sigma_psfs[2], sources3)])
def test_psf_photometry_niters(sigma_psf, sources):
    """
    Run the iterative photometry objects with ``niters=None`` and compare
    the output table with the input source table.
    """
    img_shape = (32, 32)
    # simulated frame: sources, background noise (Poisson) and read-out
    # noise (Gaussian)
    stars = make_gaussian_prf_sources_image(img_shape, sources)
    poisson_noise = make_noise_image(img_shape, distribution='poisson',
                                     mean=6., seed=0)
    readout_noise = make_noise_image(img_shape, distribution='gaussian',
                                     mean=0., stddev=2., seed=0)
    image = stars + poisson_noise + readout_noise
    cp_image = image.copy()

    std = StdBackgroundRMS(SigmaClip(sigma=3.))(image)

    # upper limits on the reported uncertainties
    max_pos_unc = 1.96 * sigma_psf / np.sqrt(sources['flux'])
    max_flux_unc = 1.96 * np.sqrt(sources['flux'])

    # only the two iterative objects accept ``niters``
    _, iter_obj, dao_obj = make_psf_photometry_objs(std, sigma_psf)
    for iter_phot_obj in (iter_obj, dao_obj):
        iter_phot_obj.niters = None

        result_tab = iter_phot_obj(image)
        residual_image = iter_phot_obj.get_residual_image()

        for unc_col in ('x_0_unc', 'y_0_unc'):
            assert (result_tab[unc_col] < max_pos_unc).all()
        assert (result_tab['flux_unc'] < max_flux_unc).all()

        for fit_col, true_col in (('x_fit', 'x_0'), ('y_fit', 'y_0'),
                                  ('flux_fit', 'flux')):
            assert_allclose(result_tab[fit_col], sources[true_col],
                            rtol=1e-1)
        for col in ('id', 'group_id', 'iter_detected'):
            assert_array_equal(result_tab[col], sources[col])
        assert_allclose(np.mean(residual_image), 0.0, atol=1e1)

        # make sure image is not overwritten
        assert_array_equal(cp_image, image)


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize("sigma_psf, sources",
                         [(sigma_psfs[0], sources1),
                          (sigma_psfs[1], sources2),
                          # these ensure that the test *fails* if the model
                          # PSFs are the wrong shape
                          pytest.param(sigma_psfs[0] / 1.2, sources1,
                                       marks=pytest.mark.xfail()),
                          pytest.param(sigma_psfs[1] * 1.2, sources2,
                                       marks=pytest.mark.xfail())])
def test_psf_photometry_oneiter(sigma_psf, sources):
    """
    Run every baseline photometry object once on a simulated image,
    first with free positions and then with the positions fixed to the
    input values.
    """

    img_shape = (32, 32)
    # simulated frame: sources, background noise (Poisson) and read-out
    # noise (Gaussian)
    stars = make_gaussian_prf_sources_image(img_shape, sources)
    poisson_noise = make_noise_image(img_shape, distribution='poisson',
                                     mean=6., seed=0)
    readout_noise = make_noise_image(img_shape, distribution='gaussian',
                                     mean=0., stddev=2., seed=0)
    image = stars + poisson_noise + readout_noise
    cp_image = image.copy()

    std = StdBackgroundRMS(SigmaClip(sigma=3.))(image)

    # upper limits on the reported uncertainties
    max_pos_unc = 1.96 * sigma_psf / np.sqrt(sources['flux'])
    max_flux_unc = 1.96 * np.sqrt(sources['flux'])

    for phot_proc in make_psf_photometry_objs(std, sigma_psf):
        # -- free positions
        result_tab = phot_proc(image)
        residual_image = phot_proc.get_residual_image()
        for unc_col in ('x_0_unc', 'y_0_unc'):
            assert (result_tab[unc_col] < max_pos_unc).all()
        assert (result_tab['flux_unc'] < max_flux_unc).all()
        for fit_col, true_col in (('x_fit', 'x_0'), ('y_fit', 'y_0'),
                                  ('flux_fit', 'flux')):
            assert_allclose(result_tab[fit_col], sources[true_col],
                            rtol=1e-1)
        for col in ('id', 'group_id'):
            assert_array_equal(result_tab[col], sources[col])
        assert_allclose(np.mean(residual_image), 0.0, atol=1e1)

        # -- fixed positions, supplied as initial guesses
        phot_proc.psf_model.x_0.fixed = True
        phot_proc.psf_model.y_0.fixed = True

        pos = Table(names=['x_0', 'y_0'],
                    data=[sources['x_0'], sources['y_0']])
        cp_pos = pos.copy()

        result_tab = phot_proc(image, pos)
        residual_image = phot_proc.get_residual_image()
        for unc_col in ('x_0_unc', 'y_0_unc'):
            assert unc_col not in result_tab.colnames
        assert (result_tab['flux_unc'] < max_flux_unc).all()
        for fit_col, true_col in (('x_fit', 'x_0'), ('y_fit', 'y_0')):
            assert_array_equal(result_tab[fit_col], sources[true_col])
        assert_allclose(result_tab['flux_fit'], sources['flux'], rtol=1e-1)
        for col in ('id', 'group_id'):
            assert_array_equal(result_tab[col], sources[col])
        assert_allclose(np.mean(residual_image), 0.0, atol=1e1)

        # neither the image nor the initial guess table may be modified
        assert_array_equal(cp_image, image)
        assert_array_equal(cp_pos, pos)

        # free the positions again
        phot_proc.psf_model.x_0.fixed = False
        phot_proc.psf_model.y_0.fixed = False


@pytest.mark.skipif('not HAS_SCIPY')
def test_niters_errors():
    """Check the handling of values assigned to ``niters``."""
    _, iter_phot_obj, _ = make_psf_photometry_objs()

    # a float input must be stored as an integer
    iter_phot_obj.niters = 1.1
    assert_equal(iter_phot_obj.niters, 1)

    # niters <= 0 is rejected
    with pytest.raises(ValueError):
        iter_phot_obj.niters = 0

    # None is an accepted value
    iter_phot_obj.niters = None


@pytest.mark.skipif('not HAS_SCIPY')
def test_fitshape_errors():
    """Check the handling of values assigned to ``fitshape``."""
    basic_phot_obj, _, _ = make_psf_photometry_objs()

    # a scalar is accepted and expanded to both dimensions
    basic_phot_obj.fitshape = 11
    assert np.all(basic_phot_obj.fitshape == (11, 11))

    invalid_fitshapes = (
        (2, 2),      # even components
        2,           # even scalar
        (-1, 0),     # non-positive components
        (3, 3, 3),   # more than two dimensions
    )
    for bad_fitshape in invalid_fitshapes:
        with pytest.raises(ValueError):
            basic_phot_obj.fitshape = bad_fitshape


@pytest.mark.skipif('not HAS_SCIPY')
def test_aperture_radius_errors():
    """Check the default and the validation of ``aperture_radius``."""
    basic_phot_obj, _, _ = make_psf_photometry_objs()

    # the default value is None
    assert_equal(basic_phot_obj.aperture_radius, None)

    # a non-positive radius is rejected
    with pytest.raises(ValueError):
        basic_phot_obj.aperture_radius = -3


@pytest.mark.skipif('not HAS_SCIPY')
def test_finder_errors():
    """
    The iterative photometry class must reject ``finder=None``, both on
    assignment and at construction.
    """
    _, iter_phot_obj, _ = make_psf_photometry_objs()

    # assignment on an existing object
    with pytest.raises(ValueError):
        iter_phot_obj.finder = None

    # construction of a new object
    with pytest.raises(ValueError):
        IterativelySubtractedPSFPhotometry(
            finder=None, group_maker=DAOGroup(1),
            bkg_estimator=MMMBackground(),
            psf_model=IntegratedGaussianPRF(1), fitshape=(11, 11))


@pytest.mark.skipif('not HAS_SCIPY')
def test_finder_positions_warning():
    """
    Passing ``init_guesses`` to an object that also has a finder must
    emit an `~astropy.utils.exceptions.AstropyUserWarning`, and calling
    an object that has neither a finder nor initial guesses must raise
    a `ValueError`.
    """
    basic_phot_obj = make_psf_photometry_objs(sigma_psf=2)[0]
    positions = Table()
    positions['x_0'] = [12.8, 18.2, 25.3]
    positions['y_0'] = [15.7, 16.5, 25.1]

    image = (make_gaussian_prf_sources_image((32, 32), sources1) +
             make_noise_image((32, 32), distribution='poisson', mean=6.,
                              seed=0))

    # pytest.warns fails the test if no AstropyUserWarning is emitted,
    # whereas merely catching warnings would pass silently
    with pytest.warns(AstropyUserWarning):
        result_tab = basic_phot_obj(image=image, init_guesses=positions)

    assert_array_equal(result_tab['x_0'], positions['x_0'])
    assert_array_equal(result_tab['y_0'], positions['y_0'])
    assert_allclose(result_tab['x_fit'], positions['x_0'], rtol=1e-1)
    assert_allclose(result_tab['y_fit'], positions['y_0'], rtol=1e-1)

    # keep only the call inside the raises block so that the ValueError
    # must come from the photometry call, not from the assignment
    basic_phot_obj.finder = None
    with pytest.raises(ValueError):
        basic_phot_obj(image=image)


@pytest.mark.skipif('not HAS_SCIPY')
def test_aperture_radius():
    """
    Check that ``aperture_radius`` is set to the model ``fwhm`` value
    after photometry is run with a PSF model that has a ``fwhm``
    parameter.
    """

    # minimal fittable model carrying a ``fwhm`` parameter
    class PSFModelWithFWHM(Fittable2DModel):
        x_0 = Parameter(default=1)
        y_0 = Parameter(default=1)
        flux = Parameter(default=1)
        fwhm = Parameter(default=5)

        def __init__(self, fwhm=fwhm.default):
            super().__init__(fwhm=fwhm)

        def evaluate(self, x, y, x_0, y_0, flux, fwhm):
            dx_sq = (x - x_0)**2
            dy_sq = (y - y_0)**2
            return flux / (fwhm * dx_sq * dy_sq)

    img_shape = (32, 32)

    # simulated frame: sources, background noise (Poisson) and read-out
    # noise (Gaussian)
    stars = make_gaussian_prf_sources_image(img_shape, sources1)
    poisson_noise = make_noise_image(img_shape, distribution='poisson',
                                     mean=6., seed=0)
    readout_noise = make_noise_image(img_shape, distribution='gaussian',
                                     mean=0., stddev=2., seed=0)
    image = stars + poisson_noise + readout_noise

    basic_phot_obj, _, _ = make_psf_photometry_objs()

    psf_model = PSFModelWithFWHM()
    basic_phot_obj.psf_model = psf_model
    basic_phot_obj(image)
    assert_equal(basic_phot_obj.aperture_radius, psf_model.fwhm.value)


# Expected parameter-name mappings with position and flux only
PARS_TO_SET_0 = dict(x_0='x_0', y_0='y_0', flux_0='flux')
PARS_TO_OUTPUT_0 = dict(x_fit='x_0', y_fit='y_0', flux_fit='flux')

# Same mappings extended with the ``sigma`` parameter
PARS_TO_SET_1 = dict(PARS_TO_SET_0, sigma_0='sigma')
PARS_TO_OUTPUT_1 = dict(PARS_TO_OUTPUT_0, sigma_fit='sigma')


@pytest.mark.parametrize("actual_pars_to_set, "
                         "actual_pars_to_output,"
                         "is_sigma_fixed",
                         [(PARS_TO_SET_0, PARS_TO_OUTPUT_0, True),
                          (PARS_TO_SET_1, PARS_TO_OUTPUT_1, False)])
@pytest.mark.skipif('not HAS_SCIPY')
def test_define_fit_param_names(actual_pars_to_set, actual_pars_to_output,
                                is_sigma_fixed):
    """
    Check the name mappings built by ``_define_fit_param_names`` with
    ``sigma`` either fixed or free.
    """
    basic_phot_obj, _, _ = make_psf_photometry_objs()

    model = IntegratedGaussianPRF()
    model.sigma.fixed = is_sigma_fixed
    basic_phot_obj.psf_model = model

    basic_phot_obj._define_fit_param_names()
    assert_equal(basic_phot_obj._pars_to_set, actual_pars_to_set)
    assert_equal(basic_phot_obj._pars_to_output, actual_pars_to_output)


# tests previously written to psf_photometry

PSF_SIZE = 11
GAUSSIAN_WIDTH = 1.
IMAGE_SIZE = 101

# Positions and fluxes of the test sources
INTAB = Table(data=[[50., 23, 12, 86],
                    [50., 83, 80, 84],
                    [np.pi * 10, 3.654, 20., 80 / np.sqrt(3)]],
              names=['x_0', 'y_0', 'flux_0'])

# Test PSF: a Gaussian centred on a PSF_SIZE x PSF_SIZE stamp
psf_model = Gaussian2D(amplitude=1. / (2 * np.pi * GAUSSIAN_WIDTH ** 2),
                       x_mean=PSF_SIZE // 2, y_mean=PSF_SIZE // 2,
                       x_stddev=GAUSSIAN_WIDTH, y_stddev=GAUSSIAN_WIDTH)
test_psf = discretize_model(psf_model, x_range=(0, PSF_SIZE),
                            y_range=(0, PSF_SIZE), mode='oversample')

# Test image: start from an empty grid and add one Gaussian per source
image = np.zeros((IMAGE_SIZE, IMAGE_SIZE))
for x, y, flux in INTAB:
    model = Gaussian2D(amplitude=flux / (2 * np.pi * GAUSSIAN_WIDTH ** 2),
                       x_mean=x, y_mean=y,
                       x_stddev=GAUSSIAN_WIDTH, y_stddev=GAUSSIAN_WIDTH)
    image += discretize_model(model, x_range=(0, IMAGE_SIZE),
                              y_range=(0, IMAGE_SIZE), mode='oversample')

# Some tests require an image with wider sources.
WIDE_GAUSSIAN_WIDTH = 3.
WIDE_INTAB = Table(data=[[50, 23.2], [50.5, 1], [10, 20]],
                   names=['x_0', 'y_0', 'flux_0'])

# Wide-source image, built the same way as the image above
wide_image = np.zeros((IMAGE_SIZE, IMAGE_SIZE))
for x, y, flux in WIDE_INTAB:
    model = Gaussian2D(
        amplitude=flux / (2 * np.pi * WIDE_GAUSSIAN_WIDTH ** 2),
        x_mean=x, y_mean=y,
        x_stddev=WIDE_GAUSSIAN_WIDTH, y_stddev=WIDE_GAUSSIAN_WIDTH)
    wide_image += discretize_model(model, x_range=(0, IMAGE_SIZE),
                                   y_range=(0, IMAGE_SIZE),
                                   mode='oversample')


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_photometry_discrete():
    """Fit the test image with a `DiscretePRF` built from the test PSF."""

    discrete_prf = DiscretePRF(test_psf, subsampling=1)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=discrete_prf, fitshape=7)
    results = phot(image=image, init_guesses=INTAB)

    # fitted values must agree with the initial guesses
    for param in ('x', 'y', 'flux'):
        assert_allclose(results[param + '_0'], results[param + '_fit'],
                        rtol=1e-6)


@pytest.mark.skipif('not HAS_SCIPY')
def test_tune_coordinates():
    """
    Fit the test image with a `DiscretePRF` whose positions are free,
    starting from x coordinates offset from the input table.
    """

    discrete_prf = DiscretePRF(test_psf, subsampling=1)
    for coord in ('x_0', 'y_0'):
        getattr(discrete_prf, coord).fixed = False

    # work on a copy so that the shared table is left untouched, and
    # shift the x position of all sources by 0.3 pixels
    shifted_guesses = INTAB.copy()
    shifted_guesses['x_0'] += 0.3

    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=discrete_prf, fitshape=7)
    results = phot(image=image, init_guesses=shifted_guesses)

    for param in ('x', 'y', 'flux'):
        assert_allclose(results[param + '_0'], results[param + '_fit'],
                        rtol=1e-3)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_boundary():
    """
    Fit a `DiscretePRF` at a position next to the corner of the image;
    the fitted flux is expected to be zero.
    """

    discrete_prf = DiscretePRF(test_psf, subsampling=1)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=discrete_prf, fitshape=7,
                              aperture_radius=5.5)

    corner_guess = Table(data=[[1], [1]], names=['x_0', 'y_0'])
    results = phot(image=image, init_guesses=corner_guess)
    assert_allclose(results['flux_fit'], 0, atol=1e-8)


@pytest.mark.skipif('not HAS_SCIPY')
def test_aperture_radius_value_error():
    """
    A tabular PSF model used without ``aperture_radius`` must raise a
    ValueError.
    """

    discrete_prf = DiscretePRF(test_psf, subsampling=1)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=discrete_prf, fitshape=7)

    # no initial guesses at all
    with pytest.raises(ValueError):
        phot(image=image)

    # initial guesses that lack a "flux_0" column
    guesses_without_flux = Table(data=[[1], [1]], names=['x_0', 'y_0'])
    with pytest.raises(ValueError):
        phot(image=image, init_guesses=guesses_without_flux)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_boundary_gaussian():
    """
    Fit an `IntegratedGaussianPRF` at a position next to the corner of
    the image; the fitted flux is expected to be zero.
    """

    gaussian_prf = IntegratedGaussianPRF(GAUSSIAN_WIDTH)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=gaussian_prf, fitshape=7)

    corner_guess = Table(data=[[1], [1]], names=['x_0', 'y_0'])
    results = phot(image=image, init_guesses=corner_guess)
    assert_allclose(results['flux_fit'], 0, atol=1e-8)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_photometry_gaussian():
    """Fit the test image with an `IntegratedGaussianPRF` model."""

    gaussian_prf = IntegratedGaussianPRF(sigma=GAUSSIAN_WIDTH)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=gaussian_prf, fitshape=7)
    results = phot(image=image, init_guesses=INTAB)

    # fitted values must agree with the initial guesses
    for param in ('x', 'y', 'flux'):
        assert_allclose(results[param + '_0'], results[param + '_fit'],
                        rtol=1e-3)


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize("renormalize_psf", (True, False))
def test_psf_photometry_gaussian2(renormalize_psf):
    """
    Fit the test image with an Astropy `Gaussian2D` wrapped by
    `prepare_psf_model`, with and without PSF renormalization.
    """

    half_size = PSF_SIZE // 2
    amplitude = 1. / (2 * np.pi * GAUSSIAN_WIDTH ** 2)
    gaussian = Gaussian2D(amplitude, half_size, half_size, GAUSSIAN_WIDTH,
                          GAUSSIAN_WIDTH)
    wrapped_psf = prepare_psf_model(gaussian, xname='x_mean',
                                    yname='y_mean',
                                    renormalize_psf=renormalize_psf)

    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=wrapped_psf, fitshape=7)
    results = phot(image=image, init_guesses=INTAB)

    for param in ('x', 'y', 'flux'):
        assert_allclose(results[param + '_0'], results[param + '_fit'],
                        rtol=1e-1)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_photometry_moffat():
    """
    Test psf_photometry with Moffat PSF model from Astropy.

    The model is wrapped with `prepare_psf_model` (no renormalization)
    and fitted to the Gaussian test image starting from ``INTAB``.
    """

    psf = Moffat2D(1. / (2 * np.pi * GAUSSIAN_WIDTH ** 2), PSF_SIZE // 2,
                   PSF_SIZE // 2, 1, 1)
    psf = prepare_psf_model(psf, xname='x_0', yname='y_0',
                            renormalize_psf=False)

    basic_phot = BasicPSFPhotometry(group_maker=DAOGroup(2),
                                    bkg_estimator=None, psf_model=psf,
                                    fitshape=7)
    f = basic_phot(image=image, init_guesses=INTAB)

    # positions are compared with a tight tolerance
    for n in ['x', 'y']:
        assert_allclose(f[n + '_0'], f[n + '_fit'], rtol=1e-3)
    # image was created with a gaussian, so flux won't match exactly
    assert_allclose(f['flux_0'], f['flux_fit'], rtol=1e-1)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_fitting_data_on_edge():
    """
    Fit the wide-source image, whose second source lies close to the
    image edge.  No mask is passed in explicitly; the original note for
    this test states that the subarray extracted for that source gets a
    mask internally.
    """

    psf_guess = IntegratedGaussianPRF(flux=1, sigma=WIDE_GAUSSIAN_WIDTH)
    # flux and position are all free parameters
    for par_name in ('flux', 'x_0', 'y_0'):
        getattr(psf_guess, par_name).fixed = False

    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=psf_guess, fitshape=7)
    outtab = phot(image=wide_image, init_guesses=WIDE_INTAB)

    for param in ('x', 'y', 'flux'):
        assert_allclose(outtab[param + '_0'], outtab[param + '_fit'],
                        rtol=0.05, atol=0.1)


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize("sigma_psf, sources", [(sigma_psfs[2], sources3)])
def test_psf_extra_output_cols(sigma_psf, sources):
    """
    Test the handling of a non-None extra_output_cols

    `DAOPhotPSFPhotometry` is run four times on the same simulated
    image: without initial guesses, and with initial-guess tables that
    contain all, none, or only one of the requested extra columns.
    """

    psf_model = IntegratedGaussianPRF(sigma=sigma_psf)
    tshape = (32, 32)
    # simulated frame: sources plus Poisson and Gaussian noise
    image = (make_gaussian_prf_sources_image(tshape, sources) +
             make_noise_image(tshape, distribution='poisson', mean=6.,
                              seed=0) +
             make_noise_image(tshape, distribution='gaussian', mean=0.,
                              stddev=2., seed=0))

    # no initial guesses
    init_guess1 = None
    # initial guesses carrying every extra column
    init_guess2 = Table(names=['x_0', 'y_0', 'sharpness', 'roundness1',
                               'roundness2'],
                        data=[[17.4], [16], [0.4], [0], [0]])
    # initial guesses carrying none of the extra columns
    init_guess3 = Table(names=['x_0', 'y_0'],
                        data=[[17.4], [16]])
    # initial guesses carrying only one of the extra columns
    init_guess4 = Table(names=['x_0', 'y_0', 'sharpness'],
                        data=[[17.4], [16], [0.4]])
    for i, init_guesses in enumerate([init_guess1, init_guess2, init_guess3,
                                      init_guess4]):
        # a fresh photometry object is built for every case
        dao_phot = DAOPhotPSFPhotometry(crit_separation=8, threshold=40,
                                        fwhm=4 * np.sqrt(2 * np.log(2)),
                                        psf_model=psf_model, fitshape=(11, 11),
                                        extra_output_cols=['sharpness',
                                                           'roundness1',
                                                           'roundness2'])
        phot_results = dao_phot(image, init_guesses=init_guesses)
        # test that the original required columns are also passed back, as well
        # as extra_output_cols
        assert np.all([name in phot_results.colnames for name in
                       ['x_0', 'y_0']])
        assert np.all([name in phot_results.colnames for name in
                       ['sharpness', 'roundness1', 'roundness2']])
        assert len(phot_results) == 2
        # checks to verify that half-passing init_guesses results in NaN output
        # for extra_output_cols not passed as initial guesses
        #
        # NOTE(review): in the two asserts below the outer np.all() is
        # handed a generator expression rather than a sequence, so the
        # per-column NaN checks are presumably never evaluated and the
        # asserts always pass.  Confirm the intended per-row/per-column
        # behaviour before tightening them (e.g. with the builtin all()).
        if i == 2:  # init_guess3
            assert(np.all(np.all(np.isnan(phot_results[o])) for o in
                   ['sharpness', 'roundness1', 'roundness2']))
        if i == 3:  # init_guess4
            assert(np.all(np.all(np.isnan(phot_results[o])) for o in
                   ['roundness1', 'roundness2']))
            # this check is effective: every 'sharpness' entry must be
            # a non-NaN value
            assert(np.all(~np.isnan(phot_results['sharpness'])))


@pytest.mark.skipif('not HAS_SCIPY')
def test_finder_return_none():
    """
    Run iterative PSF photometry with a custom finder that always
    returns a table (possibly with no rows) instead of None.
    """
    def tophatfinder(image):
        """ Simple top hat finder function for use with a top hat PRF"""
        found = Table(names=['id', 'xcentroid', 'ycentroid', 'flux'],
                      dtype=[int, float, float, float])
        # every distinct pixel value above 1 is treated as one source
        levels = np.unique(image[image > 1])
        for ident, level in enumerate(levels, start=1):
            rows, cols = np.where(image == level)
            found.add_row([ident, np.mean(cols), np.mean(rows), level * 9])
        found.sort(['flux'])

        return found

    # 3x3 top hat of unit total flux centred in a 7x7 stamp
    kernel = np.zeros((7, 7), float)
    kernel[2:5, 2:5] = 1 / 9
    prf = FittableImageModel(kernel)

    # three 3x3 top hat sources with different fluxes
    x0 = [38, 20, 35]
    y0 = [20, 5, 40]
    f0 = [50, 100, 200]
    img = np.zeros((50, 50), float)
    for xc, yc, total_flux in zip(x0, y0, f0):
        img[yc - 1:yc + 2, xc - 1:xc + 2] = total_flux / 9

    intab = Table(data=[[37, 19.6, 34.9], [19.6, 4.5, 40.1], [45, 103, 210]],
                  names=['x_0', 'y_0', 'flux_0'])

    iter_phot = IterativelySubtractedPSFPhotometry(
        finder=tophatfinder, group_maker=DAOGroup(2), bkg_estimator=None,
        psf_model=prf, fitshape=7, niters=2, aperture_radius=3)

    results = iter_phot(image=img, init_guesses=intab)
    assert_allclose(results['flux_fit'], f0, rtol=0.05)


@pytest.mark.skipif('not HAS_SCIPY')
def test_psf_photometry_uncertainties():
    """
    Use an Astropy fitter that does not return a parameter covariance
    matrix (param_cov) and check that the output table has no
    flux_unc, x_0_unc or y_0_unc column.
    """
    gaussian_prf = IntegratedGaussianPRF(sigma=GAUSSIAN_WIDTH)
    phot = BasicPSFPhotometry(group_maker=DAOGroup(2), bkg_estimator=None,
                              psf_model=gaussian_prf,
                              fitter=SimplexLSQFitter(), fitshape=7)
    phot_tbl = phot(image=image, init_guesses=INTAB)

    for unc_column in ('flux_unc', 'x_0_unc', 'y_0_unc'):
        assert unc_column not in phot_tbl.colnames
