# Copyright (c) 2024 Radio Astronomy Software Group
# Licensed under the 2-clause BSD License
"""Code to apply flags to calibration or visibility data."""
import numpy as np
def _uvflag_axis_mismatch(n_uvf, n_uvd, uvf_values, uvd_values, tols):
    """
    Check whether a UVFlag time or frequency axis is incompatible with a UVData.

    Parameters
    ----------
    n_uvf, n_uvd : int
        Axis lengths (e.g. ``Ntimes`` or ``Nfreqs``) of the UVFlag and UVData.
    uvf_values, uvd_values : array_like
        Axis values of the UVFlag and UVData. Only their sorted unique values
        are compared.
    tols : array_like
        Tolerances attached to the UVFlag parameter. The largest magnitude is
        used as an absolute threshold on each element-wise difference.

    Returns
    -------
    bool
        False if the UVFlag axis has length 1 (it is broadcast) or if the
        unique values agree within tolerance; True otherwise.
    """
    if n_uvf == 1:
        # a length-1 axis is broadcast across the whole UVData axis
        return False
    if n_uvf != n_uvd:
        return True
    uvf_unique = np.unique(uvf_values)
    uvd_unique = np.unique(uvd_values)
    if uvf_unique.shape != uvd_unique.shape:
        # e.g. duplicated values: cannot be matched element by element
        return True
    # use magnitudes so offsets in either direction are caught
    diff = np.abs(uvf_unique - uvd_unique)
    return bool(np.any(diff > np.max(np.abs(tols))))


def _align_uvflag_pols(uvd, uvf, force_pol):
    """
    Make a baseline-type UVFlag's polarizations match a UVData's, in place.

    Parameters
    ----------
    uvd : UVData object
        Reference object whose ``polarization_array`` order is adopted.
    uvf : UVFlag object
        Baseline-type UVFlag to edit in place.
    force_pol : bool
        If True and ``uvf.Npols == 1``, relabel uvf with all of uvd's
        polarizations when the polarization sets differ.

    Raises
    ------
    ValueError
        If the polarization sets differ and cannot be forced to match.
    """
    uvd_pols = uvd.polarization_array.tolist()
    uvf_pols = uvf.polarization_array.tolist()
    if set(uvd_pols) != set(uvf_pols):
        if uvf.Npols == 1 and force_pol:
            # The single-pol flag axis is left as-is: when the flags are ORed
            # into uvd it broadcasts against every uvd polarization.
            uvf.polarization_array = uvd.polarization_array.copy()
            uvf.Npols = len(uvf.polarization_array)
            return
        raise ValueError("Input uvf and uvd polarizations do not match")
    if uvf_pols != uvd_pols:
        # order[i] is the uvf position of the i-th uvd polarization. The flags
        # must be permuted together with the metadata; polarization is assumed
        # to be the last flag_array axis (pyuvdata's baseline-type layout).
        order = [uvf_pols.index(pol) for pol in uvd_pols]
        uvf.polarization_array = uvf.polarization_array[order]
        uvf.flag_array = uvf.flag_array[..., order]


def apply_uvflag(
    uvd, uvf, *, inplace=True, unflag_first=False, flag_missing=True, force_pol=True
):
    """
    Apply flags from a UVFlag to a UVData instance.

    Note that if uvf.Nfreqs or uvf.Ntimes is 1, it will broadcast flags across
    that axis.

    Parameters
    ----------
    uvd : UVData object
        UVData object to add flags to.
    uvf : UVFlag object
        A UVFlag object in flag mode. It is copied and never modified.
    inplace : bool
        If True overwrite flags in uvd, otherwise return new object.
    unflag_first : bool
        If True, completely unflag the UVData before applying flags.
        Else, OR the inherent uvd flags with uvf flags.
    flag_missing : bool
        If input uvf is a baseline type and antpairs in uvd do not exist in uvf,
        flag them in uvd. Otherwise leave them untouched.
    force_pol : bool
        If True, broadcast flags to all polarizations if they do not match.
        Only works if uvf.Npols == 1.

    Returns
    -------
    UVData
        If not inplace, returns new UVData object with flags applied.

    Raises
    ------
    ValueError
        If uvf is not in flag mode, if the polarizations cannot be matched, or
        if the time or frequency axes do not agree within tolerance.
    """
    if uvf.mode != "flag":
        raise ValueError("UVFlag must be flag mode")
    if not inplace:
        uvd = uvd.copy()
    # uvf is edited in place below, so always work on a copy of it
    uvf = uvf.copy()

    if uvf.type != "baseline":
        # converts the copy to baseline type in place
        uvf.to_baseline(uvd, force_pol=force_pol)
    else:
        _align_uvflag_pols(uvd, uvf, force_pol)

    if _uvflag_axis_mismatch(
        uvf.Ntimes, uvd.Ntimes, uvf.time_array, uvd.time_array, uvf._time_array.tols
    ):
        raise ValueError("UVFlag and UVData have mismatched time arrays.")
    if _uvflag_axis_mismatch(
        uvf.Nfreqs, uvd.Nfreqs, uvf.freq_array, uvd.freq_array, uvf._freq_array.tols
    ):
        raise ValueError("UVFlag and UVData have mismatched frequency arrays.")

    if unflag_first:
        uvd.flag_array[:] = False

    # TODO: handle conjugated antpairs
    uvf_antpairs = set(uvf.get_antpairs())  # set: O(1) membership per antpair
    for ap in uvd.get_antpairs():
        uvd_ap_inds = uvd.antpair2ind(ap)
        if ap not in uvf_antpairs:
            if flag_missing:
                uvd.flag_array[uvd_ap_inds] = True
            continue
        uvf_ap_inds = uvf.antpair2ind(*ap)
        # Rows are paired positionally. NOTE(review): this assumes both objects
        # list each baseline's blts in the same time order -- confirm.
        uvd.flag_array[uvd_ap_inds] |= uvf.flag_array[uvf_ap_inds]

    uvd.history += "\nFlagged with pyuvdata.utils.apply_uvflags."
    if not inplace:
        return uvd