# Source code for breads.jwst_tools.reduction_utils

import os
import time
from glob import glob
from copy import copy
import fnmatch

import numpy as np
from scipy.stats import median_abs_deviation
from astropy.io import fits

import multiprocessing as mp
import matplotlib.pyplot as plt
import datetime
from scipy.ndimage import generic_filter, gaussian_filter
from scipy.ndimage import convolve1d
from scipy.interpolate import interp1d
import matplotlib.tri as tri

try:
    import jwst
    _HAS_OPTIONAL_DEPENDENCY_JWST = True

    from jwst.pipeline import Detector1Pipeline, Spec2Pipeline, Spec3Pipeline
    from jwst.associations import asn_from_list as afl  # Tools for creating association files
    from jwst.associations.lib.rules_level3_base import DMS_Level3_Base
except ImportError:
    _HAS_OPTIONAL_DEPENDENCY_JWST = False

from breads.instruments.instrument import Instrument
from breads.instruments.jwstnirspec_cal import JWSTNirspec_cal
from breads.instruments.jwst_IFUs import untangle_dq
from breads.instruments.jwst_IFUs import crop_trace_edges
from breads.instruments.jwst_IFUs import fitpsf
from breads.instruments.jwst_IFUs import get_contnorm_spec
from breads.instruments.jwst_IFUs import filter_big_triangles
from breads.instruments.jwstnirspec_multiple_cals import JWSTNirspec_multiple_cals
from breads.fit import fitfm
from breads.utils import get_spline_model
import breads.jwst_tools.plotting

from collections import defaultdict


###########################################################################
#                       JWST reduction tools
#
# This module contains utility functions for JWST reductions, particularly
# for invoking the JWST pipeline with some customizations and additions
# tuned for the kind of processing we want to do with breads.
#
# Top-level function is "run_complete_stage1_2_clean_reduction"
# This invokes lower-level functions for:
#   run_stage1
#   run_noise_clean
#   run_stage2

def find_files_to_process(input_dir, filetype='uncal.fits', exp_numbers=None, verbose=True):
    """Find JWST files of a given type in a directory.

    Parameters
    ----------
    input_dir : str
        Input directory to search for files.
    filetype : str
        Filename match pattern in shell-glob syntax (as understood by ``glob``),
        not a regular expression. Either a simple ending string such as
        ``'uncal.fits'``, which is expanded to ``'jw*_uncal.fits'``, or a full
        glob pattern that already starts with ``'jw'``.
    exp_numbers : list or ndarray of ints, optional
        If given, keep only files whose basename matches ``jw*_*_NNNNN_*``
        for one of these exposure numbers (zero-padded to 5 digits).
    verbose : bool
        Print the search pattern and the list of files that will be returned.

    Returns
    -------
    files : list of str
        Sorted list of paths in ``input_dir`` matching ``filetype`` (and
        ``exp_numbers``, if provided).
    """
    search_pattern = filetype if filetype.startswith('jw') else "jw*_" + filetype
    files = sorted(glob(os.path.join(input_dir, search_pattern)))

    if exp_numbers is not None:
        # Use fnmatch to filter only the wanted exposure numbers
        exposure_patterns = ["jw*_*_{0:05d}_*".format(num) for num in exp_numbers]
        files = [f for f in files
                 if any(fnmatch.fnmatch(os.path.basename(f), pattern) for pattern in exposure_patterns)]

    # Report after filtering, so the printed count/list matches what is returned.
    if verbose:
        print(f"Searching in {input_dir} for files matching {search_pattern}")
        print('\tFound ' + str(len(files)) + ' input files to process')
        for file in files:
            print("\t" + os.path.basename(file))
    return files
def check_instrument_grating(uncal_files):
    """Return the common GRATING keyword value of a set of files.

    Forward modeling should be done for only one grating at a time, but some
    JWST observations use multiple gratings in different activities within the
    same observation, so this verifies that every file shares one GRATING value.

    Parameters
    ----------
    uncal_files : iterable of str
        FITS filenames whose primary headers contain a GRATING keyword.

    Returns
    -------
    grating : str
        The single GRATING keyword value shared by all files.

    Raises
    ------
    ValueError
        If no files are given.
    RuntimeError
        If the files have more than one distinct GRATING value.
    """
    uncal_files = list(uncal_files)
    if not uncal_files:
        raise ValueError("No files were provided, so the GRATING value cannot be determined.")

    gratings = {fits.getheader(f)['GRATING'] for f in uncal_files}
    if len(gratings) > 1:
        raise RuntimeError(f"The specified list of files contains multiple different GRATING values: {gratings}. "
                           "Adjust your input file selection criteria to select files all with the same spectral "
                           "grating value.")
    return gratings.pop()


###########################################################################
# Functions for invoking the pipeline
def _stage1_step_parameters(maximum_cores):
    """Build a fresh dict of Detector1Pipeline step overrides used by run_stage1.

    Only steps whose parameters differ from the pipeline defaults are listed;
    group_scale, dq_init, ipc, superbias, linearity, dark_current and
    gain_scale all run with their defaults.
    """
    return {
        # Check for saturated pixels, but do not expand the flags to adjacent pixels
        'saturation': {'n_pix_grow_sat': 0},
        # This step does nothing; there are no nonzero parameters in the reference files yet
        'persistence': {'skip': True},
        # Parallelize the jump detection and ramp fitting
        'jump': {'maximum_cores': maximum_cores},
        'ramp_fit': {'maximum_cores': maximum_cores},
    }


def run_stage1(uncal_files, output_dir, overwrite=False, maximum_cores="all", save_plots=True):
    """Run JWST pipeline stage 1 (Detector1Pipeline) with breads-oriented settings.

    Intended for IFU high-contrast reductions with breads; currently only tested
    on NIRSpec IFU data. For each input, the expected ``*rate.fits`` output name
    is derived first; if that file already exists in ``output_dir`` the input is
    skipped unless ``overwrite`` is True.

    Parameters
    ----------
    uncal_files : list of str
        Filenames of uncal files to reduce.
    output_dir : str
        Directory for the output files; created if it does not exist.
    overwrite : bool
        Re-reduce and overwrite outputs that already exist. Default is to skip them.
    maximum_cores : str
        Passed to the multiprocessing-capable steps (jump, ramp_fit).
    save_plots : bool
        If True, plot the resulting rate images via breads.jwst_tools.plotting.

    Returns
    -------
    rate_files : list of str
        Expected rate filenames, one per input, whether newly made or pre-existing.
    """
    from jwst.pipeline import Detector1Pipeline

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    start = time.perf_counter()
    rate_files = []
    for number, uncal_path in enumerate(uncal_files, start=1):
        print(f"Stage 1 Processing file {number} of {len(uncal_files)}.")
        uncal_name = os.path.basename(uncal_path)
        rate_path = os.path.join(output_dir, uncal_name.replace('uncal.fits', 'rate.fits'))
        rate_files.append(rate_path)

        already_reduced = os.path.exists(rate_path) and not overwrite
        if already_reduced:
            print(f"\tStage 1 Output file {os.path.basename(rate_path)} already exists in output dir;\n\tskipping {uncal_name}.")
            continue

        pipeline = Detector1Pipeline()  # Instantiate the pipeline
        pipeline.call(uncal_path, save_results=True, output_dir=output_dir,
                      steps=_stage1_step_parameters(maximum_cores))
        print(f"\tStage 1 Runtime so far: {time.perf_counter() - start:0.4f} seconds")

    print(f"Stage 1 Total Runtime: {time.perf_counter() - start:0.4f} seconds")

    if save_plots:
        breads.jwst_tools.plotting.plot_2d_image_set(rate_files, output_dir=output_dir,
                                                     suptitle="Stage 1 pipeline reduction results",
                                                     plot_label='stage1')
    return rate_files
def run_stage2(rate_files, output_dir, skip_cubes=True, overwrite=False, TA=False, nsclean_skip=False,
               save_plots=True):
    """Run JWST pipeline stage 2 (Spec2Pipeline) with breads-oriented settings.

    Intended for IFU high-contrast reductions with breads; currently only tested
    on NIRSpec IFU data. For each input, the expected ``*cal.fits`` output name is
    derived first; if it already exists in ``output_dir`` the input is skipped
    unless ``overwrite`` is True. Background-subtracted products are also saved
    (``save_bsub=True``).

    Parameters
    ----------
    rate_files : list of str
        Stage 1 rate filenames to process.
    output_dir : str
        Directory for the output files; created if it does not exist.
    skip_cubes : bool
        Skip the cube_build step (interpolated cubes are not needed by breads).
    overwrite : bool
        Re-reduce and overwrite outputs that already exist.
    TA : bool
        Set True for target acquisition images; this skips the pathloss step.
    nsclean_skip : bool
        Skip the pipeline's nsclean step.
    save_plots : bool
        If True, plot the resulting cal images via breads.jwst_tools.plotting.

    Returns
    -------
    cal_files : list of str
        List of cal filenames produced by stage 2 (or already present).
    """
    from jwst.pipeline import Spec2Pipeline

    # We need to check that the desired output directories exist, and if not create them
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Start a timer to keep track of runtime
    time0 = time.perf_counter()
    cal_files = []
    for fid, rate_file in enumerate(rate_files):
        print(f"Stage 2 Processing file {fid + 1} of {len(rate_files)}.")
        outname = os.path.join(output_dir, os.path.basename(rate_file).replace('rate.fits', 'cal.fits'))
        cal_files.append(outname)
        if os.path.exists(outname) and not overwrite:
            print(f"\tStage 2 Output file {os.path.basename(outname)} already exists in output dir;\n\tskipping {os.path.basename(rate_file)}.")
            continue

        pathloss_skip = TA  # For target acq images, skip the pathloss step, otherwise don't skip it.
        # assign_wcs, bkg_subtract, imprint_subtract, msa_flagging, srctype, flat_field
        # and photom all run with pipeline defaults.
        step_parameters = {
            'pathloss': {'skip': pathloss_skip},
            'nsclean': {'skip': nsclean_skip},
            'cube_build': {'skip': skip_cubes},  # We do not want or need interpolated cubes
            'extract_1d': {'skip': True},
            # cube_build alternatives if cubes are wanted: coord_system='skyalign' or 'ifualign'
        }

        # Step.call() is a classmethod that builds a brand-new pipeline instance from
        # its arguments, so attributes set on a separately created instance (the old
        # ``spec2.save_bsub = True``) were silently ignored. Pipeline-level options
        # must be passed as keyword arguments to call().
        Spec2Pipeline.call(rate_file, save_results=True, output_dir=output_dir,
                           steps=step_parameters, save_bsub=True)

        # Print out the time benchmark
        time1 = time.perf_counter()
        print(f"\tStage 2 Runtime so far: {time1 - time0:0.4f} seconds")

    time1 = time.perf_counter()
    print(f"Stage 2 Total Runtime: {time1 - time0:0.4f} seconds")

    if save_plots:
        breads.jwst_tools.plotting.plot_2d_image_set(cal_files, output_dir=output_dir,
                                                     suptitle="Stage 2 pipeline reduction results",
                                                     plot_label='stage2')
    return cal_files
###########################################################################
# Function for centroid calibration
#
# Definition of the wavelength sampling on which the detector images are interpolated (for each detector)
# wv_for_cent_calib_dict = {}
# #"G140H","G235H","G395H"
# wv_for_cent_calib_dict["G140H nrs1"] = np.arange(0.96646905,1.4494654,0.003)
# wv_for_cent_calib_dict["G140H nrs2"] = []
# wv_for_cent_calib_dict["G235H nrs1"] = []#0.005
# wv_for_cent_calib_dict["G235H nrs2"] = []
# wv_for_cent_calib_dict["G395H nrs1"] = np.arange(2.859, 4.103, 0.01)
# wv_for_cent_calib_dict["G395H nrs2"] = np.arange(4.081, 5.280, 0.01)


def _polyfit_centroid_offsets(wv_sampling, offsets, wv_min, wv_max, deg=2):
    """Fit a polynomial in (wavelength - median wavelength) to finite offsets with wv_min < wavelength < wv_max."""
    x2fit = wv_sampling - np.nanmedian(wv_sampling)
    wherefinite = np.where(np.isfinite(offsets) * (wv_sampling > wv_min) * (wv_sampling < wv_max))
    return np.polyfit(x2fit[wherefinite], offsets[wherefinite], deg=deg)


def run_coordinate_recenter(cal_files, utils_dir, init_centroid=(0, 0), wv_sampling=None, N_wvs_nodes=40,
                            mask_charge_transfer_radius=None, IWA=0.3, OWA=1.0, debug_init=None, debug_end=None,
                            mppool=None, save_plots=False, filename_suffix="_webbpsf", overwrite=False,
                            targetname=None):
    """Calibrate the star centroid versus wavelength by fitting a WebbPSF model.

    All cal files are preprocessed and interpolated onto ``wv_sampling``, then
    combined. A WebbPSF model is fit to the combined data with ``fitpsf``. The
    fitted RA and Dec offsets are each fit with a degree-2 polynomial in
    (wavelength - median wavelength), using the central 80% of the wavelength range.

    If the polynomial text file already exists and ``overwrite`` is False, the
    saved coefficients are loaded and returned without recomputation.

    Parameters
    ----------
    cal_files : list of str
        Filenames of stage 2 reduced cal files to process. All must be for the
        same detector (checked against the DETECTOR keyword of the first file).
    utils_dir : str
        Directory for intermediate and output products; created if missing.
    init_centroid : tuple of floats
        Starting guess for the centroid location, passed to fitpsf.
    wv_sampling : array-like, optional
        Wavelength sampling. Defaults to the first file's wavelength range with a
        step of median(wavelength)/300.
    N_wvs_nodes : int
        Number of wavelength spline nodes for the continuum-normalised star spectrum.
    mask_charge_transfer_radius : optional
        If given, passed as ``threshold2mask`` to compute_charge_bleeding_mask.
    IWA : float
        Radius around (0, 0) that is masked before fitting (inner working angle).
    OWA : float
        Outer working angle, passed to fitpsf.
    debug_init, debug_end : optional
        Passed through to fitpsf.
    mppool : multiprocessing pool, optional
        Passed to the preprocessing tasks, the WebbPSF model computation and fitpsf.
    save_plots : bool
        If True, save a PNG summary of the fit in ``utils_dir``.
    filename_suffix : str
        Suffix used in the output filenames and as fitpsf's RDI_folder_suffix.
    overwrite : bool
        Recompute even if a prior centroid file exists.
    targetname : str, optional
        Passed as ``targname`` to compute_coordinates_arrays.

    Returns
    -------
    poly_p_ra, poly_p_dec : ndarray
        ``np.polyfit`` coefficients (highest order first) for the RA and Dec
        offsets. The plot labels suggest arcsec; not verified here.

    Creates output files:
        [X]_fitpsf[filename_suffix].fits
        [X]_poly2d_centroid[filename_suffix].txt
    """
    if not os.path.exists(utils_dir):
        os.makedirs(utils_dir)

    # Science data: List of stage 2 cal.fits files
    for filename in cal_files:
        print(filename)
    print("N files: {0}".format(len(cal_files)))

    # Define the filename of the output file saved by fitpsf
    splitbasename = os.path.basename(cal_files[0]).split("_")
    file_prefix = splitbasename[0] + "_" + splitbasename[1] + "_" + splitbasename[3]
    fitpsf_filename = os.path.join(utils_dir, file_prefix + "_fitpsf" + filename_suffix + ".fits")
    poly2d_centroid_filename = os.path.join(utils_dir, file_prefix + "_poly2d_centroid" + filename_suffix + ".txt")

    with fits.open(cal_files[0]) as hdulist_sc:
        detector = hdulist_sc[0].header["DETECTOR"].strip().lower()
        if wv_sampling is None:
            wavelengths = hdulist_sc["WAVELENGTH"].data
            wv_sampling = np.arange(np.nanmin(wavelengths), np.nanmax(wavelengths),
                                    np.nanmedian(wavelengths) / 300)

    # If the output centroid file already exists, from a prior run of this function, then
    # just reload those results and return them, without doing any additional calculation,
    # unless overwrite is set.
    if not overwrite and len(glob(poly2d_centroid_filename)) == 1:
        print("Found centroid results from prior calculation. Loading and returning those.")
        output = np.loadtxt(poly2d_centroid_filename, delimiter=' ')
        poly_p_ra, poly_p_dec = output[0], output[1]
        print("RA correction " + detector, poly_p_ra)
        print("Dec correction " + detector, poly_p_dec)
        return poly_p_ra, poly_p_dec

    regwvs_dataobj_list = []
    for filename in cal_files:
        print(filename)
        if detector not in filename:
            raise Exception("The files in cal_files should all be for the same detector")

        # Define a series of processing tasks to be performed on each input file.
        preproc_task_list = []
        preproc_task_list.append(["compute_med_filt_badpix", {"window_size": 50, "mad_threshold": 50}, True, True])
        preproc_task_list.append(["compute_coordinates_arrays", {'targname': targetname}, True, True])
        preproc_task_list.append(["convert_MJy_per_sr_to_MJy"])
        preproc_task_list.append(["compute_starspectrum_contnorm",
                                  {"N_nodes": N_wvs_nodes, "threshold_badpix": 100, "mppool": mppool}, True, True])
        preproc_task_list.append(["compute_starsubtraction",
                                  {"starsub_dir": "starsub1d_tmp", "threshold_badpix": 10, "mppool": mppool},
                                  True, True])
        preproc_task_list.append(["compute_interpdata_regwvs", {"wv_sampling": wv_sampling}, True, True])

        # Load that file. (TODO: Does this invoke the preproc_task_list?)
        dataobj = JWSTNirspec_cal(filename, utils_dir=utils_dir, save_utils=True, load_utils=True,
                                  preproc_task_list=preproc_task_list)
        regwvs_dataobj_list.append(dataobj.reload_interpdata_regwvs())

    # Combine all input files into a joint dataset
    regwvs_combdataobj = JWSTNirspec_multiple_cals(regwvs_dataobj_list)
    if mask_charge_transfer_radius is not None:
        regwvs_combdataobj.compute_charge_bleeding_mask(threshold2mask=mask_charge_transfer_radius)

    # Load the webbPSF model (or compute if it does not yet exist)
    webbpsf_reload = regwvs_combdataobj.reload_webbpsf_model()
    if webbpsf_reload is None:
        webbpsf_reload = regwvs_combdataobj.compute_webbpsf_model(wv_sampling=regwvs_combdataobj.wv_sampling,
                                                                  image_mask=None, pixelscale=0.1, oversample=10,
                                                                  parallelize=False, mppool=mppool,
                                                                  save_utils=True)
    wpsfs, wpsfs_header, wepsfs, webbpsf_wvs, webbpsf_x, webbpsf_y, wpsf_oversample, wpsf_pixelscale = webbpsf_reload
    webbpsf_x = np.tile(webbpsf_x[None, :, :], (wepsfs.shape[0], 1, 1))
    webbpsf_y = np.tile(webbpsf_y[None, :, :], (wepsfs.shape[0], 1, 1))

    # Fit a model PSF (WebbPSF) to the combined point cloud of dataobj_list
    # Save output as fitpsf_filename
    ann_width = None
    padding = 0.0
    sector_area = None

    # Mask the central disk of radius IWA; fitpsf itself is then run with IWA=0.
    where_center_disk = regwvs_combdataobj.where_point_source((0.0, 0.0), IWA)
    regwvs_combdataobj.bad_pixels[where_center_disk] = np.nan

    fitpsf(regwvs_combdataobj, wepsfs, webbpsf_x, webbpsf_y, out_filename=fitpsf_filename, IWA=0.0, OWA=OWA,
           mppool=mppool, init_centroid=init_centroid, ann_width=ann_width, padding=padding,
           sector_area=sector_area, RDI_folder_suffix=filename_suffix,
           rotate_psf=regwvs_combdataobj.east2V2_deg, flipx=True, psf_spaxel_area=(wpsf_pixelscale) ** 2,
           debug_init=debug_init, debug_end=debug_end)

    with fits.open(fitpsf_filename) as hdulist:
        bestfit_coords = hdulist[0].data

    # Only fit the central 80% of the wavelength range to avoid edge effects.
    _wv_min = wv_sampling[0] + 0.1 * (wv_sampling[-1] - wv_sampling[0])
    _wv_max = wv_sampling[-1] - 0.1 * (wv_sampling[-1] - wv_sampling[0])
    print(_wv_min, _wv_max)
    poly_p_ra = _polyfit_centroid_offsets(wv_sampling, bestfit_coords[0, :, 2], _wv_min, _wv_max)
    print("RA correction " + detector, poly_p_ra)
    poly_p_dec = _polyfit_centroid_offsets(wv_sampling, bestfit_coords[0, :, 3], _wv_min, _wv_max)
    print("Dec correction " + detector, poly_p_dec)

    # Save centroids to a text file
    np.savetxt(poly2d_centroid_filename, [poly_p_ra, poly_p_dec], delimiter=' ')

    if save_plots:
        color_list = ["#ff9900", "#006699", "#6600ff", "#006699", "#ff9900", "#6600ff"]
        print(bestfit_coords.shape)
        fontsize = 12
        plt.figure(figsize=(12, 10))
        plt.subplot(3, 1, 1)
        plt.plot(wv_sampling, bestfit_coords[0, :, 0] * 1e9, linestyle="-", color=color_list[0],
                 label="Fixed centroid", linewidth=1)
        plt.plot(wv_sampling, bestfit_coords[0, :, 1] * 1e9, linestyle="--", color=color_list[2],
                 label="Free centroid", linewidth=1)
        plt.xlim([wv_sampling[0], wv_sampling[-1]])
        plt.xlabel("Wavelength ($\\mu$m)", fontsize=fontsize)
        plt.ylabel("Flux density (mJy)", fontsize=fontsize)
        plt.gca().tick_params(axis='x', labelsize=fontsize)
        plt.gca().tick_params(axis='y', labelsize=fontsize)
        plt.legend(loc="upper right")

        plt.subplot(3, 1, 2)
        plt.plot(wv_sampling, bestfit_coords[0, :, 2], label="bestfit centroid")
        poly_model = np.polyval(poly_p_ra, wv_sampling - np.nanmedian(wv_sampling))
        plt.plot(wv_sampling, poly_model, label="polyfit")
        plt.plot(wv_sampling, bestfit_coords[0, :, 2] - poly_model, label="residuals")
        plt.xlabel("Wavelength ($\\mu$m)", fontsize=fontsize)
        plt.ylabel("$\\Delta$RA (as)", fontsize=fontsize)
        plt.gca().tick_params(axis='x', labelsize=fontsize)
        plt.gca().tick_params(axis='y', labelsize=fontsize)
        plt.legend(loc="upper right")

        plt.subplot(3, 1, 3)
        plt.plot(wv_sampling, bestfit_coords[0, :, 3], label="bestfit centroid")
        poly_model = np.polyval(poly_p_dec, wv_sampling - np.nanmedian(wv_sampling))
        plt.plot(wv_sampling, poly_model, label="polyfit")
        plt.plot(wv_sampling, bestfit_coords[0, :, 3] - poly_model, label="residuals")
        plt.xlabel("Wavelength ($\\mu$m)", fontsize=fontsize)
        plt.ylabel("$\\Delta$Dec (as)", fontsize=fontsize)
        plt.gca().tick_params(axis='x', labelsize=fontsize)
        plt.gca().tick_params(axis='y', labelsize=fontsize)
        plt.tight_layout()

        now = datetime.datetime.now()
        formatted_datetime = now.strftime("%Y%m%d_%H%M%S")
        out_filename = os.path.join(utils_dir, formatted_datetime + "_centroid_calibration.png")
        print("Saving " + out_filename)
        plt.savefig(out_filename, dpi=300)
        plt.close()  # release the figure; otherwise repeated calls accumulate open figures

    return poly_p_ra, poly_p_dec


###########################################################################
# Functions for noise cleaning

def _merge_fixed_parameters(nonlin_paras, fix_parameters):
    """Return the full non-linear parameter vector.

    Entries of ``fix_parameters`` that are None are filled, in order, with the
    values of ``nonlin_paras``. If ``fix_parameters`` is None, ``nonlin_paras``
    is returned unchanged.
    """
    if fix_parameters is None:
        return nonlin_paras
    merged = np.array(fix_parameters, dtype=object)
    # Element-wise test; ``np.array(...) is None`` would compare the array object itself.
    is_free = np.array([p is None for p in fix_parameters], dtype=bool)
    merged[is_free] = nonlin_paras
    return merged


def _resolve_spline_nodes(nodes, x_min, x_max):
    """Return ``(x_knots, n_nodes)`` for ``nodes`` given as a node count or as explicit knot positions.

    An integer count gives equally spaced knots over [x_min, x_max]. A list/array
    is used as-is; if its first element is itself a list/array, n_nodes is the
    total size of all sub-lists.
    """
    if isinstance(nodes, (int, np.integer)):
        return np.linspace(x_min, x_max, nodes, endpoint=True).tolist(), nodes
    if isinstance(nodes, (list, np.ndarray)):
        if isinstance(nodes[0], (list, np.ndarray)):
            n_nodes = np.sum([np.size(n) for n in nodes])
        else:
            n_nodes = np.size(nodes)
        return nodes, n_nodes
    raise ValueError("Unknown format for nodes.")


def fm_column_background(nonlin_paras, cubeobj, nodes=20, fix_parameters=None, return_where_finite=False,
                         regularization=None, badpixfraction=0.75, M_spline=None, spline_reg_std=1.0):
    """Linear forward model of a smooth (cubic spline) detector column background.

    Used by forward_model_noise_clean through ``fitfm``. The model has no
    non-linear parameters; ``nonlin_paras`` and ``fix_parameters`` are accepted
    only for interface compatibility with fitfm.

    Parameters
    ----------
    nonlin_paras : sequence
        Unused (the model is purely linear).
    cubeobj : object
        Must provide ``data``, ``noise`` and ``bad_pixels`` arrays of equal shape
        (1D detector column as used in this module); bad pixels are NaN.
    nodes : int or list/ndarray
        Number of equally spaced spline nodes over [0, size(data)], or explicit
        node positions. Ignored when ``M_spline`` is given.
    fix_parameters : optional
        Unused.
    return_where_finite : bool
        If True, add ``"where_trace_finite"`` (indices of the pixels used) to the extra outputs.
    regularization : None, "default" or "user"
        "default" adds a zero-mean Gaussian prior of width ``spline_reg_std`` on
        every spline coefficient; "user" is not implemented.
    badpixfraction : float
        If the fraction of usable pixels is at or below ``1 - badpixfraction``, empty arrays are returned.
    M_spline : ndarray, optional
        Precomputed spline design matrix of shape (size(data), n_nodes).
    spline_reg_std : float
        Prior width used by the "default" regularization.

    Returns
    -------
    d, m, s[, extra_outputs]
        Data of the usable pixels, model matrix (n_pixels, n_nodes) and noise,
        plus a dict of extras when any were requested.

    Raises
    ------
    NotImplementedError
        If ``regularization == "user"``.
    """
    if M_spline is None:
        x_knots, n_nodes = _resolve_spline_nodes(nodes, 0, np.size(cubeobj.data))
    else:
        n_nodes = M_spline.shape[1]
    # Number of linear parameters
    n_linpara = n_nodes

    data = cubeobj.data
    noise = cubeobj.noise
    bad_pixels = cubeobj.bad_pixels

    where_trace_finite = np.where(np.isfinite(data) * np.isfinite(bad_pixels) * (noise != 0))
    d = data[where_trace_finite]
    s = noise[where_trace_finite]

    if np.size(where_trace_finite[0]) <= (1 - badpixfraction) * np.size(data):
        # don't bother to do a fit if there are too many bad pixels
        return np.array([]), np.array([]).reshape(0, n_linpara), np.array([])

    if M_spline is None:
        x = np.arange(np.size(cubeobj.data))
        m_spline = get_spline_model(x_knots, x, spline_degree=3)
    else:
        m_spline = M_spline
    # Fancy indexing returns a new array, so the caller's M_spline is never modified.
    m_spline = m_spline[where_trace_finite[0], :]

    extra_outputs = {}
    if regularization == "default":
        s_reg = np.zeros(n_nodes) + spline_reg_std
        d_reg = np.zeros(n_nodes)
        extra_outputs["regularization"] = (d_reg, s_reg)
    elif regularization == "user":
        raise NotImplementedError("user defined regularisation not yet implemented")

    if return_where_finite:
        extra_outputs["where_trace_finite"] = where_trace_finite

    if extra_outputs:
        return d, m_spline, s, extra_outputs
    return d, m_spline, s


def fm_charge_transfer(nonlin_paras, cubeobj, charge_transfer_mask=None, nodes=40, fix_parameters=None,
                       return_where_finite=False, regularization=None, badpixfraction=0.75, M_spline=None,
                       spline_reg_std=1.0):
    """Forward model charge transfer ("bleeding") within the detector, particularly for bright/saturated sources.

    A spline in wavelength is multiplied by ``charge_transfer_mask`` and
    convolved along detector axis 0 with a Lorentzian kernel whose width is the
    first (merged) non-linear parameter.

    Parameters
    ----------
    nonlin_paras : sequence
        Free non-linear parameters; they fill the None entries of ``fix_parameters``.
    cubeobj : object
        Must provide 2D ``data``, ``noise``, ``bad_pixels`` and ``wavelengths`` arrays of equal shape.
    charge_transfer_mask : ndarray
        2D weight map (same shape as data) applied to the spline model before convolution.
    nodes : int or list/ndarray
        Number of equally spaced wavelength nodes, or explicit node positions.
        Ignored when ``M_spline`` is given.
    fix_parameters : list, optional
        Full non-linear parameter vector, with None marking the free entries.
        Element 0 is the Lorentzian width in pixels.
    return_where_finite : bool
        If True, add ``"where_finite"`` to the extra outputs.
    regularization : None, "default" or "user"
        Same meaning as in fm_column_background.
    badpixfraction : float
        If the fraction of usable pixels is at or below ``1 - badpixfraction``, empty arrays are returned.
    M_spline : ndarray, optional
        Precomputed spline matrix with one row per finite-wavelength pixel.
    spline_reg_std : float
        Prior width used by the "default" regularization.

    Returns
    -------
    d, m, s[, extra_outputs]
        Same conventions as fm_column_background.

    Raises
    ------
    NotImplementedError
        If ``regularization == "user"``.
    """
    _nonlin_paras = _merge_fixed_parameters(nonlin_paras, fix_parameters)

    if M_spline is None:
        x_knots, n_nodes = _resolve_spline_nodes(nodes, np.nanmin(cubeobj.wavelengths),
                                                 np.nanmax(cubeobj.wavelengths))
    else:
        n_nodes = M_spline.shape[1]
    # Number of linear parameters
    n_linpara = n_nodes

    where_finite = np.where(np.isfinite(cubeobj.data) * np.isfinite(cubeobj.bad_pixels) * (cubeobj.noise != 0))
    d = cubeobj.data[where_finite]
    s = cubeobj.noise[where_finite]

    if np.size(where_finite[0]) <= (1 - badpixfraction) * np.size(cubeobj.data):
        # don't bother to do a fit if there are too many bad pixels
        return np.array([]), np.array([]).reshape(0, n_linpara), np.array([])

    where_finite_wvs = np.where(np.isfinite(cubeobj.wavelengths))
    if M_spline is None:
        m_tmp = get_spline_model(x_knots, cubeobj.wavelengths[where_finite_wvs], spline_degree=3)
    else:
        m_tmp = M_spline  # only read below (copied into m_entire_image)
    m_entire_image = np.zeros((cubeobj.data.shape[0], cubeobj.data.shape[1], m_tmp.shape[1]))
    m_entire_image[where_finite_wvs[0], where_finite_wvs[1], :] = m_tmp
    m_entire_image = m_entire_image * charge_transfer_mask[:, :, None]
    m_entire_image[np.where(np.isnan(m_entire_image))] = 0

    kernel_scale = _nonlin_paras[0]
    x = np.arange(-cubeobj.data.shape[0], cubeobj.data.shape[0] + 1)
    charge_transfer_kernel = 1 / (1 + x ** 2 / kernel_scale ** 2)  # Lorentzian Function
    # Convolve each column with the kernel
    m_entire_image_convolved = convolve1d(m_entire_image, weights=charge_transfer_kernel, axis=0, mode='constant')
    m_output = m_entire_image_convolved[where_finite[0], where_finite[1], :]

    extra_outputs = {}
    if regularization == "default":
        s_reg = np.zeros(n_nodes) + spline_reg_std
        d_reg = np.zeros(n_nodes)
        extra_outputs["regularization"] = (d_reg, s_reg)
    elif regularization == "user":
        raise NotImplementedError("user defined regularisation not yet implemented")

    if return_where_finite:
        extra_outputs["where_finite"] = where_finite

    if extra_outputs:
        return d, m_output, s, extra_outputs
    return d, m_output, s


def _background_bad_pixel_map(im, noise, dq, cal_im, sep_im, rate_basename):
    """Build the background-only bad pixel map (NaN = bad, 1 = good) for column fitting.

    Modifies its inputs in place:
      - zero-noise entries of ``noise`` become NaN;
      - NaN entries of ``im`` become 0;
      - ``sep_im`` is extended from the slice edge to the detector edge.
    """
    # Simplifying bad pixel map following convention in this package as: nan = bad, 1 = good.
    # Only pixels outside the IFU slices (NaN separation in the cal file) are background.
    bad_pixels = np.full(cal_im.shape, np.nan)
    bad_pixels[np.where(np.isnan(sep_im))] = 1
    # Pixels marked as "do not use" are marked as bad (nan = bad, 1 = good):
    bad_pixels[np.where(untangle_dq(dq, verbose=True)[0, :, :])] = np.nan
    bad_pixels[np.where(np.isnan(im))] = np.nan
    # Removing any data with zero noise
    where_zero_noise = np.where(noise == 0)
    noise[where_zero_noise] = np.nan
    bad_pixels[where_zero_noise] = np.nan
    im[np.where(np.isnan(im))] = 0

    # Extend the slices masks to the edge of the detector. The detector is read from
    # the file basename so that directory names cannot select the wrong branch.
    if "nrs1" in rate_basename:
        for rowid in range(im.shape[0]):
            finite_ids = np.where(np.isfinite(sep_im[rowid, 0:450]))[0]
            if len(finite_ids) != 0:
                id_to_mask = np.min(finite_ids)
                bad_pixels[rowid, 0:id_to_mask] = np.nan
                sep_im[rowid, 0:id_to_mask] = sep_im[rowid, id_to_mask]
    elif "nrs2" in rate_basename:
        for rowid in range(im.shape[0]):
            finite_ids = np.where(np.isfinite(sep_im[rowid, 1550::]))[0]
            if len(finite_ids) != 0:
                id_to_mask = np.max(finite_ids)
                bad_pixels[rowid, 1550 + id_to_mask::] = np.nan
                sep_im[rowid, 1550 + id_to_mask::] = sep_im[rowid, 1550 + id_to_mask]

    # Row-wise outlier rejection: residuals from a running nan-median above 5 MAD are masked.
    mad_threshold = 5
    window_size = 50
    new_badpix = np.ones(bad_pixels.shape)
    for rowid in range(bad_pixels.shape[0]):
        row_data = im[rowid, :] - generic_filter(im[rowid, :] * bad_pixels[rowid, :], np.nanmedian,
                                                 size=window_size)
        row_data_masking = row_data / median_abs_deviation(row_data[np.where(np.isfinite(bad_pixels[rowid, :]))])
        new_badpix[rowid, np.where((row_data_masking > mad_threshold))[0]] = np.nan
    bad_pixels *= new_badpix
    return bad_pixels


def _subtract_charge_transfer(im, noise, bad_pixels, dq, sep_im, cal_dataobj, utils_dir, basename):
    """Fit the Lorentzian charge-transfer model (fm_charge_transfer) and subtract it from ``im`` in place."""
    data = Instrument()
    data.data = copy(im)
    data.noise = copy(noise)
    data.bad_pixels = copy(bad_pixels)
    data.wavelengths = copy(cal_dataobj.wavelengths)

    cal_model_filename = os.path.join(utils_dir, "RDI_model_webbpsf", basename.replace("_rate.fits", "_cal.fits"))
    if len(glob(cal_model_filename)) == 0:
        raise Exception(
            "Could not find the corresponding RDI model webbpsf cal file. "
            "Please run run_coordinate_recenter(...) first.")
    with fits.open(cal_model_filename) as hdul_model:
        webbpsf_im = hdul_model["SCI"].data

    saturated_mask = np.full(cal_dataobj.data.shape, np.nan)
    saturated_mask[np.where(untangle_dq(dq, verbose=False)[1, :, :])] = 1
    saturated_mask[np.where((sep_im > 0.5))] = np.nan

    # Define the saturation threshold in MJy/sr below. This is definitely not ideal, probably not accurate.
    # Will probably need to fix later.
    saturation_threshold = 1e5  # MJy/sr
    charge_transfer_mask = (webbpsf_im - saturation_threshold) * saturated_mask
    charge_transfer_mask[np.where(np.isnan(charge_transfer_mask))] = 0.0
    charge_transfer_mask = np.clip(charge_transfer_mask, 0, np.inf)

    # Define the spline nodes (in wavelength) for the charge transfer model
    n_nodes_charge_transfer = 5
    x_knots_charge_transfer = np.linspace(np.nanmin(cal_dataobj.wavelengths), np.nanmax(cal_dataobj.wavelengths),
                                          n_nodes_charge_transfer, endpoint=True).tolist()
    where_finite_wvs = np.where(np.isfinite(cal_dataobj.wavelengths))
    m_spline_charge_transfer = get_spline_model(x_knots_charge_transfer, cal_dataobj.wavelengths[where_finite_wvs],
                                                spline_degree=3)

    fix_parameters = [10]  # width of the lorentzian
    fm_paras = {"charge_transfer_mask": charge_transfer_mask, "fix_parameters": fix_parameters,
                "regularization": None, "badpixfraction": 0.75, "M_spline": m_spline_charge_transfer,
                "spline_reg_std": 1.0}
    nonlin_paras = []
    _, _, _, linparas, _ = fitfm(nonlin_paras, data, fm_charge_transfer, fm_paras, computeH0=False,
                                 scale_noise=False)

    # Evaluate the fitted model on every pixel (not just the fitted ones) and remove it.
    data.bad_pixels = np.ones(data.data.shape)
    _, m, _, _ = fm_charge_transfer(nonlin_paras, data, return_where_finite=True, **fm_paras)
    model_canvas = np.reshape(np.dot(m, linparas), data.data.shape)
    im -= model_canvas


def _fit_column_backgrounds(im, noise, bad_pixels, N_nodes):
    """Fit and subtract a spline background per detector column (two passes, 5-MAD clipping).

    Returns a new float64 image. Columns whose fit fails keep their input values
    instead of being zeroed.
    """
    n_rows, n_cols = im.shape
    x = np.arange(n_rows)
    x_knots_column = np.linspace(0, n_rows, N_nodes, endpoint=True).tolist()
    m_spline_column = get_spline_model(x_knots_column, x, spline_degree=3)

    fm_paras = {"badpixfraction": 0.99, "nodes": N_nodes, "fix_parameters": None, "regularization": "default",
                "M_spline": m_spline_column}
    nonlin_paras = []
    data = Instrument()
    new_im = im.astype(np.float64)  # copy; un-fittable columns are left uncorrected
    for colid in range(n_cols):
        data.data = copy(im[:, colid])
        data.noise = copy(noise[:, colid])
        data.bad_pixels = copy(bad_pixels[:, colid])

        # First pass: fit, then flag pixels deviating by more than 5 MAD from the fit.
        out_log_prob, _, _, linparas, _ = fitfm(nonlin_paras, data, fm_column_background, fm_paras,
                                                computeH0=False, scale_noise=False)
        if not np.isfinite(out_log_prob):
            continue
        d_masked, _, _, extra_outputs = fm_column_background(nonlin_paras, data, return_where_finite=True,
                                                             **fm_paras)
        where_finite = extra_outputs["where_trace_finite"]
        data.bad_pixels = np.ones(data.data.shape)
        d, m, _, _ = fm_column_background(nonlin_paras, data, return_where_finite=True, **fm_paras)
        d_masked_canvas = np.full(d.shape, np.nan)
        d_masked_canvas[where_finite] = d_masked
        residuals = d_masked_canvas - np.dot(m, linparas)
        mad = median_abs_deviation(residuals[np.where(np.isfinite(d_masked_canvas))])
        data.bad_pixels = copy(bad_pixels[:, colid])
        data.bad_pixels[np.where(np.abs(residuals) > 5 * mad)] = np.nan

        # Second pass: refit with the outliers masked and subtract the model from the whole column.
        out_log_prob, _, _, linparas, _ = fitfm(nonlin_paras, data, fm_column_background, fm_paras,
                                                computeH0=False, scale_noise=False)
        if not np.isfinite(out_log_prob):
            continue
        data.bad_pixels = np.ones(data.data.shape)
        _, m, _, _ = fm_column_background(nonlin_paras, data, return_where_finite=True, **fm_paras)
        new_im[:, colid] = im[:, colid] - np.dot(m, linparas)
    return new_im


def forward_model_noise_clean(rate_file, cal_file_dir, clean_dir, N_nodes=40, model_charge_transfer=False,
                              utils_dir=None, coords_offset=(0, 0)):
    """Clean 1/f stripe noise from a NIRSpec IFU rate file.

    Inspired by NSClean but implemented independently:
      1. The matching stage 2 cal file gives the footprint of the IFU slices.
         Only pixels outside the slices, and not flagged or outliers, are used as background.
      2. Optionally, a Lorentzian charge-transfer model is fit and subtracted.
      3. Each detector column is fit with a smooth spline (two passes with
         5-MAD outlier rejection), and the fit is subtracted.
      4. The result is written to ``clean_dir`` with the same basename.

    Parameters
    ----------
    rate_file : str
        Stage 1 ``*_rate.fits`` file to clean.
    cal_file_dir : str
        Directory holding the corresponding ``*_cal.fits`` from a stage 2 run without cleaning.
    clean_dir : str
        Output directory (not created here).
    N_nodes : int
        Number of spline nodes along each detector column.
    model_charge_transfer : bool
        Also fit and subtract the charge-transfer model. This requires
        ``utils_dir/RDI_model_webbpsf/<cal basename>`` from run_coordinate_recenter.
    utils_dir : str, optional
        Utility directory used by JWSTNirspec_cal and for the RDI model.
    coords_offset : tuple
        Passed to ``apply_coords_offset`` of the cal data object.

    Returns
    -------
    new_rate_file : str
        Path of the cleaned rate file.

    Raises
    ------
    Exception
        If the cal file (or, with model_charge_transfer, the RDI model file) is missing.
    """
    basename = os.path.basename(rate_file)
    cal_filename = os.path.join(cal_file_dir, basename.replace("_rate.fits", "_cal.fits"))
    if len(glob(cal_filename)) == 0:
        raise Exception("Could not find the corresponding cal file. Please run stage 2 without cleaning first.")

    cal_dataobj = JWSTNirspec_cal(cal_filename, utils_dir=utils_dir, save_utils=True, load_utils=True)
    out = cal_dataobj.reload_coordinates_arrays()
    if out is None:
        cal_dataobj.compute_coordinates_arrays(save_utils=True)
    cal_dataobj.apply_coords_offset(coords_offset=coords_offset)
    ra_im, dec_im = cal_dataobj.get_sky_coords()
    sep_im = np.sqrt(ra_im ** 2 + dec_im ** 2)

    print(cal_filename)
    print(glob(cal_filename))
    with fits.open(cal_filename) as hdul_cal:
        cal_im = hdul_cal["SCI"].data

    # Get data. Read rate.fits file
    with fits.open(rate_file) as hdul:
        priheader = hdul[0].header
        im = hdul["SCI"].data
        noise = hdul["ERR"].data
        dq = hdul["DQ"].data

        bad_pixels = _background_bad_pixel_map(im, noise, dq, cal_im, sep_im, basename)
        if model_charge_transfer:
            _subtract_charge_transfer(im, noise, bad_pixels, dq, sep_im, cal_dataobj, utils_dir, basename)
        new_im = _fit_column_backgrounds(im, noise, bad_pixels, N_nodes)

        priheader['comment'] = 'Detector correlated noise removed by custom code'
        hdul["SCI"].data = new_im
        new_rate_file = os.path.join(clean_dir, basename)
        hdul.writeto(new_rate_file, overwrite=True)
    return new_rate_file
def run_noise_clean(rate_files, stage2_dir, output_dir, N_nodes=40, model_charge_transfer=False, utils_dir=None,
                    coords_offset=(0, 0), overwrite=False, save_plots=True):
    """Run forward_model_noise_clean on each rate file, skipping ones already cleaned.

    Parameters
    ----------
    rate_files : list of str
        Stage 1 rate files to clean.
    stage2_dir : str
        Directory holding the matching stage 2 cal files (produced without cleaning).
    output_dir : str
        Destination for cleaned rate files (same basenames); created if missing.
    N_nodes : int
        Spline nodes per detector column, forwarded to forward_model_noise_clean.
    model_charge_transfer : bool
        Forwarded to forward_model_noise_clean.
    utils_dir : str, optional
        Forwarded to forward_model_noise_clean.
    coords_offset : tuple
        Forwarded to forward_model_noise_clean.
    overwrite : bool
        Re-clean files whose output already exists.
    save_plots : bool
        If True, plot before/after images side by side via breads.jwst_tools.plotting.

    Returns
    -------
    cleaned_rate_files : list of str
        Output path for every input, whether newly cleaned or pre-existing.
    """
    # We need to check that the desired output directories exist, and if not create them
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    start = time.perf_counter()
    cleaned_rate_files = []
    for number, rate_file in enumerate(rate_files, start=1):
        rate_name = os.path.basename(rate_file)
        print(f"Noise Clean: Processing file {number} of {len(rate_files)}: {rate_name}")
        cleaned_path = os.path.join(output_dir, rate_name)
        cleaned_rate_files.append(cleaned_path)

        if overwrite or not os.path.exists(cleaned_path):
            forward_model_noise_clean(rate_file, stage2_dir, output_dir, N_nodes=N_nodes,
                                      model_charge_transfer=model_charge_transfer, utils_dir=utils_dir,
                                      coords_offset=coords_offset)
            print(f"\tNoise Clean Runtime so far: {time.perf_counter() - start:0.4f} seconds\n")
        else:
            print(f"\tOutput file {os.path.basename(cleaned_path)} already exists in the cleaned output directory; skipping {rate_name}.")

    print(f"Noise Clean Total Runtime: {time.perf_counter() - start:0.4f} seconds")

    if save_plots:
        breads.jwst_tools.plotting.plot_2d_image_sets_side_by_side(
            rate_files, cleaned_rate_files, output_dir=output_dir,
            suptitle="Noise Cleaning results. Left = Before, Right = After.", plot_label='noiseclean')
    return cleaned_rate_files
###########################################################################
# Host Star PSF Subtraction


def _read_detector_and_wv_nodes(cal_file, wv_nodes=None, n_nodes=40):
    """Read the detector name of a cal file and, if needed, default wavelength nodes.

    Parameters
    ----------
    cal_file : str
        Path to a cal FITS file whose primary header contains DETECTOR.
    wv_nodes : array-like or None
        Returned unchanged when given. When None, ``n_nodes`` evenly spaced
        nodes spanning the NaN-ignoring min/max of the WAVELENGTH extension
        are returned.
    n_nodes : int
        Number of default nodes.

    Returns
    -------
    detector : str
        DETECTOR keyword, stripped and lower-cased.
    wv_nodes : array-like
    """
    with fits.open(cal_file) as hdulist_sc:
        detector = hdulist_sc[0].header["DETECTOR"].strip().lower()
        if wv_nodes is None:
            wavelengths = hdulist_sc["WAVELENGTH"].data
            wv_nodes = np.linspace(np.nanmin(wavelengths), np.nanmax(wavelengths), n_nodes, endpoint=True)
    return detector, wv_nodes


def _combined_contnorm_spec_filename(utils_dir, cal_file, detector):
    """Path of the cached combined, continuum-normalized stellar spectrum for this dataset."""
    splitbasename = os.path.basename(cal_file).split("_")
    return os.path.join(utils_dir, splitbasename[0] + "_" + splitbasename[1] + "_" + detector
                        + "_starspec_contnorm_combined_1dspline.fits")


def _star_func_from_spectrum(wavelengths, fluxes):
    """Linear interpolator of a normalized spectrum that returns 1 outside its wavelength range."""
    return interp1d(wavelengths, fluxes, kind="linear", bounds_error=False, fill_value=1)


def compute_normalized_stellar_spectrum(cal_files, utils_dir, coords_offset=(0, 0), wv_nodes=None,
                                        mask_charge_transfer_radius=None, mppool=None, ra_dec_point_sources=None,
                                        overwrite=False, targetname=None):
    """Build (or reload) the combined continuum-normalized stellar spectrum of a set of cal files.

    Parameters
    ----------
    cal_files : list of str
        NIRSpec cal files. The first one provides DETECTOR, the default
        wavelength nodes, and the cache filename.
    utils_dir : str
        Directory for intermediate products; created if missing.
    coords_offset : tuple
        Passed to the "apply_coords_offset" preprocessing task.
    wv_nodes : array-like or None
        Spline nodes (``x_nodes``) for "compute_starspectrum_contnorm".
        Defaults to 40 nodes spanning the first file's wavelengths.
    mask_charge_transfer_radius : float or None
        If not None, passed as ``threshold2mask`` to ``compute_charge_bleeding_mask``.
    mppool : multiprocessing pool or None
        Forwarded to the preprocessing tasks.
    ra_dec_point_sources : iterable of (ra, dec) or None
        Point sources to mask (radius 0.16 passed to ``where_point_source``)
        before combining. Values are divided by 1000 before use; presumably
        given in mas. TODO confirm.
    overwrite : bool
        If False and the cached combined spectrum exists, it is reloaded instead of recomputed.
    targetname : str or None
        Passed as ``targname`` to "compute_coordinates_arrays".

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolator of the combined normalized spectrum; 1 outside its range.
    """
    if not os.path.exists(utils_dir):
        os.makedirs(utils_dir, exist_ok=True)

    detector, wv_nodes = _read_detector_and_wv_nodes(cal_files[0], wv_nodes)
    combined_contnorm_spec_filename = _combined_contnorm_spec_filename(utils_dir, cal_files[0], detector)

    # Reuse the saved spectrum. os.path.exists (not glob) so that paths
    # containing glob metacharacters such as '[' still match.
    if not overwrite and os.path.exists(combined_contnorm_spec_filename):
        with fits.open(combined_contnorm_spec_filename) as hdulist:
            return _star_func_from_spectrum(hdulist[0].data, hdulist[1].data)

    dataobj_list = []
    for filename in cal_files:
        print(filename)

        preproc_task_list = [
            ["compute_med_filt_badpix", {"window_size": 50, "mad_threshold": 50}],
            ["compute_coordinates_arrays", {'targname': targetname}],
            ["convert_MJy_per_sr_to_MJy"],
            ["apply_coords_offset", {"coords_offset": coords_offset}],
            ["compute_starspectrum_contnorm", {"x_nodes": wv_nodes, "threshold_badpix": 100, "mppool": mppool}],
            ["compute_starsubtraction", {"starsub_dir": "starsub1d", "threshold_badpix": 10, "mppool": mppool}],
        ]

        dataobj = JWSTNirspec_cal(filename, utils_dir=utils_dir, save_utils=True, load_utils=False,
                                  preproc_task_list=preproc_task_list)

        # Do some masking
        dataobj.bad_pixels = crop_trace_edges(dataobj.bad_pixels, N_pix=1, trace_id_map=dataobj.trace_id_map)
        if mask_charge_transfer_radius is not None:
            dataobj.compute_charge_bleeding_mask(threshold2mask=mask_charge_transfer_radius)

        # Mask planets before the combined star spectrum is computed below.
        if ra_dec_point_sources is not None:
            for ra_pl, dec_pl in ra_dec_point_sources:
                where_pl = dataobj.where_point_source([ra_pl / 1000., dec_pl / 1000.], 0.16)
                dataobj.bad_pixels[where_pl] = np.nan

        dataobj_list.append(dataobj)

    new_wavelengths, combined_fluxes, _combined_errors = get_contnorm_spec(
        dataobj_list, spline2d=False, load_utils=False, out_filename=combined_contnorm_spec_filename,
        spec_R_sampling=2700 * 4, interpolation="linear")
    return _star_func_from_spectrum(new_wavelengths, combined_fluxes)


def compute_starlight_subtraction(cal_files, utils_dir, wv_nodes=None, combined_star_func=None, coords_offset=(0, 0),
                                  mppool=None, targetname=None):
    """Load or compute the 1D starlight subtraction for each cal file.

    Parameters
    ----------
    cal_files : list of str
        NIRSpec cal files. The first one also provides the default wavelength nodes.
    utils_dir : str
        Directory for intermediate products.
    wv_nodes : array-like or None
        Spline nodes for "compute_starspectrum_contnorm" (only used when
        ``combined_star_func`` is None). Defaults to 40 nodes over the first
        file's wavelengths.
    combined_star_func : callable or None
        If given, installed as ``dataobj.star_func`` (after reloading each
        file's own contnorm products) instead of computing a per-file star spectrum.
    coords_offset : tuple
        Passed to "apply_coords_offset".
    mppool : multiprocessing pool or None
        Forwarded to the computations.
    targetname : str or None
        Passed as ``targname`` to "compute_coordinates_arrays".

    Returns
    -------
    list of JWSTNirspec_cal
        One data object per cal file, in input order.
    """
    # The DETECTOR read keeps the original fail-fast behavior on malformed inputs.
    _detector, wv_nodes = _read_detector_and_wv_nodes(cal_files[0], wv_nodes)

    dataobj_list = []
    for filename in cal_files:
        print(filename)
        preproc_task_list = [
            ["compute_med_filt_badpix", {"window_size": 50, "mad_threshold": 50}, True, True],
            ["compute_coordinates_arrays", {'targname': targetname}, True, True],
            ["convert_MJy_per_sr_to_MJy"],
            ["apply_coords_offset", {"coords_offset": coords_offset}],
        ]
        if combined_star_func is None:
            preproc_task_list.append(["compute_starspectrum_contnorm",
                                      {"x_nodes": wv_nodes, "threshold_badpix": 100, "mppool": mppool}, True, True])

        dataobj = JWSTNirspec_cal(filename, utils_dir=utils_dir, save_utils=True, load_utils=True,
                                  preproc_task_list=preproc_task_list)
        if combined_star_func is not None:
            dataobj.reload_starspectrum_contnorm()
            dataobj.star_func = combined_star_func

        outputs = dataobj.reload_starsubtraction()
        if outputs is None:
            outputs = dataobj.compute_starsubtraction(save_utils=True, starsub_dir="starsub1d", threshold_badpix=10,
                                                      mppool=mppool)
        # Unpacking documents (and checks) the expected 4-tuple layout.
        subtracted_im, star_model, spline_paras0, _wv_nodes = outputs

        dataobj_list.append(dataobj)

    return dataobj_list


###########################################################################
# Regular Wavelength Grids

def get_combined_regwvs(dataobj_list, wv_sampling=None, mask_charge_transfer_radius=None, use_starsub=False,
                        recompute=False, starsub_dir='starsub1d'):
    """Interpolate each data object onto a regular wavelength grid and combine them.

    Parameters
    ----------
    dataobj_list : list of JWSTNirspec_cal
    wv_sampling : array-like or None
        Wavelength grid passed to ``compute_interpdata_regwvs``.
    mask_charge_transfer_radius : float or None
        If not None, passed as ``threshold2mask`` to the combined object's
        ``compute_charge_bleeding_mask``.
    use_starsub : bool
        If True, use the star-subtracted cal file in ``<utils_dir>/<starsub_dir>/``
        as replacement data (converted to the data object's unit).
    recompute : bool
        If True, skip reloading previously saved regular-grid products.
    starsub_dir : str
        Sub-directory of ``utils_dir`` holding the star-subtracted files.

    Returns
    -------
    JWSTNirspec_multiple_cals

    Raises
    ------
    Exception
        If the data are in MJy/sr but the star-subtracted data are in MJy (not implemented).
    ValueError
        For any other unsupported unit combination.
    """
    regwvs_dataobj_list = []
    for dataobj in dataobj_list:
        if use_starsub:
            starsub_filename = os.path.join(dataobj.utils_dir, starsub_dir, os.path.basename(dataobj.filename))
            starsub_dataobj = JWSTNirspec_cal(starsub_filename, utils_dir=dataobj.utils_dir)
            data_unit, starsub_unit = dataobj.data_unit, starsub_dataobj.data_unit
            if data_unit == 'MJy' and starsub_unit == 'MJy/sr':
                replace_data = dataobj.convert_MJy_per_sr_to_MJy(data_in_MJy_per_sr=starsub_dataobj.data)
            elif data_unit == starsub_unit and data_unit in ('MJy', 'MJy/sr'):
                replace_data = starsub_dataobj.data
            elif data_unit == 'MJy/sr' and starsub_unit == 'MJy':
                print('Exception: data obj in MJy/sr and starsub in MJy')
                raise Exception('conversion from MJy to MJy/sr not implemented yet.')
            else:
                # Previously fell through, silently reusing the prior item's replace_data.
                raise ValueError(f"Unsupported unit combination: data in {data_unit!r}, "
                                 f"star-subtracted data in {starsub_unit!r} ({starsub_filename})")
            regwvs_filename = dataobj.default_filenames["compute_interpdata_regwvs"].replace(
                "_regwvs.fits", "_" + starsub_dir + "_regwvs.fits")
        else:
            replace_data = None
            regwvs_filename = dataobj.default_filenames["compute_interpdata_regwvs"]

        if not recompute:
            regwvs_dataobj = dataobj.reload_interpdata_regwvs(load_filename=regwvs_filename)
        else:
            print('RECOMPUTING GET_combined_REGWVS...')
            regwvs_dataobj = None
        if regwvs_dataobj is None:
            regwvs_dataobj = dataobj.compute_interpdata_regwvs(save_utils=regwvs_filename, wv_sampling=wv_sampling,
                                                               replace_data=replace_data)
        regwvs_dataobj_list.append(regwvs_dataobj)

    regwvs_combdataobj = JWSTNirspec_multiple_cals(regwvs_dataobj_list)
    if mask_charge_transfer_radius is not None:
        regwvs_combdataobj.compute_charge_bleeding_mask(threshold2mask=mask_charge_transfer_radius)

    return regwvs_combdataobj


def save_combined_regwvs(regwvs_combdataobj, out_filename):
    """Write the arrays of a combined regular-wavelength-grid object to a FITS file.

    Extensions written, in order: DATA, ERR (from ``.noise``), RA (``dra_as_array``),
    DEC (``ddec_as_array``), WAVE, WV_SAMPLING, BADPIX. An existing file is overwritten.
    """
    hdulist = fits.HDUList()
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.data, name='DATA'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.noise, name='ERR'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.dra_as_array, name='RA'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.ddec_as_array, name='DEC'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.wavelengths, name='WAVE'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.wv_sampling, name='WV_SAMPLING'))
    hdulist.append(fits.ImageHDU(data=regwvs_combdataobj.bad_pixels, name='BADPIX'))
    hdulist.writeto(out_filename, overwrite=True)
    hdulist.close()


def _load_regwvs_arrays(regwvs_combdataobj, miri=False):
    """Return (data, dra, ddec, bad_pixels, wv_sampling) from a FITS path or an in-memory object.

    A string is read as a file written by ``save_combined_regwvs``. With
    ``miri=True`` the four 2D arrays are transposed so that the wavelength
    index is the second axis, as for NIRSpec.
    """
    if isinstance(regwvs_combdataobj, str):
        with fits.open(regwvs_combdataobj) as hdulist:
            data = hdulist["DATA"].data
            dra_as_array = hdulist["RA"].data
            ddec_as_array = hdulist["DEC"].data
            bad_pixels = hdulist["BADPIX"].data
            wv_sampling = hdulist["WV_SAMPLING"].data
    else:
        data = regwvs_combdataobj.data
        dra_as_array = regwvs_combdataobj.dra_as_array
        ddec_as_array = regwvs_combdataobj.ddec_as_array
        bad_pixels = regwvs_combdataobj.bad_pixels
        wv_sampling = regwvs_combdataobj.wv_sampling

    if miri:
        data = data.transpose()
        dra_as_array = dra_as_array.transpose()
        ddec_as_array = ddec_as_array.transpose()
        bad_pixels = bad_pixels.transpose()
    return data, dra_as_array, ddec_as_array, bad_pixels, wv_sampling


def get_2D_point_cloud_interpolator(regwvs_combdataobj, wv0, miri=False):
    """Build a linear interpolator of the spatial point cloud at the grid wavelength nearest *wv0*.

    Parameters
    ----------
    regwvs_combdataobj : str or object
        Path to a file written by ``save_combined_regwvs``, or an object with
        ``data``, ``dra_as_array``, ``ddec_as_array``, ``bad_pixels`` and ``wv_sampling``.
    wv0 : float
        Target wavelength; the closest entry of ``wv_sampling`` is used.
    miri : bool
        Transpose the 2D arrays first (MIRI layout).

    Returns
    -------
    matplotlib.tri.LinearTriInterpolator
        Interpolator over the finite-bad-pixel points, triangulated after
        removing triangles flagged by ``filter_big_triangles(x, y, 0.2)``.
    """
    data, dra_as_array, ddec_as_array, bad_pixels, wv_sampling = _load_regwvs_arrays(regwvs_combdataobj, miri=miri)

    wv0_index = np.argmin(np.abs(wv_sampling - wv0))
    good_rows = np.where(np.isfinite(bad_pixels[:, wv0_index]))[0]
    x = dra_as_array[good_rows, wv0_index]
    y = ddec_as_array[good_rows, wv0_index]
    z = data[good_rows, wv0_index]

    filtered_triangles = filter_big_triangles(x, y, 0.2)
    # Create filtered triangulation
    filtered_tri = tri.Triangulation(x, y, triangles=filtered_triangles)
    # Perform LinearTriInterpolator for filtered triangulation
    return tri.LinearTriInterpolator(filtered_tri, z)


def get_2D_point_cloud_per_dither(regwvs_combdataobj, wv0, miri=False, detector_row_size=1032):
    """Return the point cloud at the wavelength nearest *wv0*, split per dither.

    Rows of the combined arrays are split into consecutive chunks of
    ``detector_row_size`` rows, one per dither. Rows with non-finite
    ``bad_pixels`` at that wavelength are dropped.

    Parameters
    ----------
    regwvs_combdataobj : str or object
        See ``get_2D_point_cloud_interpolator``.
    wv0 : float
        Target wavelength.
    miri : bool
        Transpose the 2D arrays first (MIRI layout).
    detector_row_size : int
        Number of rows contributed by each dither (default 1032).

    Returns
    -------
    x, y, z : list of ndarray
        Per-dither RA offsets, Dec offsets and data values.
    """
    data, dra_as_array, ddec_as_array, bad_pixels, wv_sampling = _load_regwvs_arrays(regwvs_combdataobj, miri=miri)

    wv0_index = np.argmin(np.abs(wv_sampling - wv0))
    print(data.shape)
    n_dithers = data.shape[0] // detector_row_size
    print("n_dithers:", n_dithers)

    x, y, z = [], [], []
    for i in range(n_dithers):
        start = i * detector_row_size
        chunk_bad = bad_pixels[start:start + detector_row_size, wv0_index]
        good_rows = start + np.where(np.isfinite(chunk_bad))[0]
        x.append(dra_as_array[good_rows, wv0_index])
        y.append(ddec_as_array[good_rows, wv0_index])
        z.append(data[good_rows, wv0_index])
    return x, y, z


############################################################################
# Function to invoke all reduction steps in one go
def run_complete_stage1_2_clean_reduction(input_dir, output_root_dir=None, overwrite=False):
    """Run stage 1, 1/f noise cleaning, and stage 2 in sequence, from uncal to cal files.

    This can take a long time. Steps whose outputs already exist are skipped
    by the underlying functions unless ``overwrite`` is True.

    Parameters
    ----------
    input_dir : str
        Directory searched for ``jw*_uncal.fits`` files.
    output_root_dir : str or None
        Root under which the "stage1", "stage2", "stage1_clean" and
        "stage2_clean" sub-directories are used. Defaults to ``input_dir``.
    overwrite : bool
        Forwarded to every step.

    Returns
    -------
    list of str
        Cal files produced by stage 2 from the noise-cleaned rate files.
    """
    root = input_dir if output_root_dir is None else output_root_dir

    # Sub-directory layout under the output root.
    det1_dir = os.path.join(root, "stage1")              # Detector1 outputs
    spec2_dir = os.path.join(root, "stage2")             # initial Spec2 outputs
    clean_det1_dir = os.path.join(root, "stage1_clean")  # noise-cleaned rate files
    clean_spec2_dir = os.path.join(root, "stage2_clean") # Spec2 outputs of the cleaned rates

    uncal_files = find_files_to_process(input_dir, 'uncal.fits')

    rate_files = run_stage1(uncal_files, output_dir=det1_dir, overwrite=overwrite)
    # NOTE(review): nothing in this function writes to spec2_dir before it is
    # passed to run_noise_clean; confirm whether an initial run_stage2 on
    # rate_files is expected to have populated it.
    cleaned_rate_files = run_noise_clean(rate_files, spec2_dir, clean_det1_dir, overwrite=overwrite)
    return run_stage2(cleaned_rate_files, output_dir=clean_spec2_dir, overwrite=overwrite)
###########################################################################
# Functions for invoking the MIRI/MRS pipeline

# MIRI/MRS channel+sub-band names, in the order returned by ``sort_calfiles``.
_MIRI_ALL_BANDS = ('12A', '12B', '12C', '34A', '34B', '34C')
# Header BAND value -> sub-band letter used in the directory names above.
_MIRI_SUBBAND_LETTERS = {'SHORT': 'A', 'MEDIUM': 'B', 'LONG': 'C'}


def mkdir_miri_files(path):
    """Create directory *path* (and missing parents) if it does not exist, and return it.

    Raises
    ------
    TypeError
        If *path* is not a string.
    """
    if not isinstance(path, str):
        raise TypeError("'path' must be a string")
    if not os.path.exists(path):
        # exist_ok guards against a concurrent creation between the check and the call.
        os.makedirs(path, exist_ok=True)
    return path


def sort_by_target_name(input_dir, filetype='uncal.fits'):
    """Group the files found by ``find_files_to_process`` by primary-header TARGNAME.

    Files without TARGNAME are grouped under 'UNKNOWN'.

    Returns
    -------
    dict
        TARGNAME -> list of file paths.
    """
    targname_groups = defaultdict(list)
    for file in find_files_to_process(input_dir, filetype):
        targname = fits.getheader(file).get('TARGNAME', 'UNKNOWN')
        targname_groups[targname].append(file)
    return dict(targname_groups)


def select_miri_output_directory(uncal_path, target_name, channel, band):
    """Return ``<uncal_path>/<target_name>/<channel><A|B|C>/stage1`` for a header BAND value.

    Raises
    ------
    ValueError
        If *band* is not SHORT, MEDIUM or LONG.
    """
    if band not in _MIRI_SUBBAND_LETTERS:
        raise ValueError(f"Band {band} is not supported for stage 1 forward modeling")
    return os.path.join(uncal_path, target_name, channel + _MIRI_SUBBAND_LETTERS[band], 'stage1')


def run_stage1_miri(uncal_path, list_bands=None, overwrite=False, maximum_cores="1", skip_dark=False):
    """Run Detector1Pipeline on every uncal file in *uncal_path*, grouped by target.

    Each rate file goes to ``<uncal_path>/<TARGNAME>/<channel><sub-band>/stage1``.

    Parameters
    ----------
    uncal_path : str
    list_bands : list of str or None
        Band directories to pre-create per target (default: all six).
    overwrite : bool
        Re-run files whose rate output already exists.
    maximum_cores : str
        Passed to the jump and ramp_fit steps.
    skip_dark : bool
        Skip the dark_current step.

    Returns
    -------
    rate_files : list of str
        Expected rate-file paths for all targets (existing, skipped ones included).
    target_names : list of str
    """
    if list_bands is None:
        list_bands = list(_MIRI_ALL_BANDS)

    dict_files_by_target_names = sort_by_target_name(uncal_path)
    target_names = list(dict_files_by_target_names.keys())
    print("DEBUG target_names", target_names)

    time0 = time.perf_counter()
    print(time0)
    # Accumulated over all targets. It used to be reset per target, which
    # returned only the last target's files.
    rate_files = []
    for target_name in target_names:
        print("DEBUG target_name", target_name)
        uncal_files = dict_files_by_target_names[target_name]
        for band in list_bands:
            mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'stage1'))

        for i, file in enumerate(uncal_files):
            print(f"Processing file {i + 1} of {len(uncal_files)}.")
            header = fits.getheader(file)  # avoids leaving the file open
            output_dir = mkdir_miri_files(
                select_miri_output_directory(uncal_path, target_name, header['CHANNEL'], header['BAND']))
            new_name = os.path.basename(file).replace('uncal.fits', 'rate.fits')
            out_name = os.path.join(output_dir, new_name)
            rate_files.append(out_name)
            if os.path.exists(out_name) and not overwrite:
                print(f"Output file {out_name} already exists in output dir;\n\tskipping {file}.")
                continue

            # Only step parameters that differ from the defaults are listed.
            step_parameters = {
                # group_scale - run with defaults
                # dq_init - run with defaults
                'saturation': {'n_pix_grow_sat': 0},  # flag saturated pixels, but do not grow to neighbours
                # ipc - run with defaults
                # superbias - run with defaults
                # linearity - run with defaults
                'emicorr': {'skip': False},
                'persistence': {'skip': True},  # reference files have no nonzero parameters yet
                'dark_current': {'skip': skip_dark},
                'jump': {'maximum_cores': maximum_cores},  # parallelize
                'ramp_fit': {'maximum_cores': maximum_cores},  # parallelize
                # gain_scale : run with defaults
            }
            # ``call`` is a classmethod that builds and runs a fresh pipeline instance.
            Detector1Pipeline.call(file, save_results=True, output_dir=output_dir, steps=step_parameters)

    time1 = time.perf_counter()
    print(f"Total Runtime: {time1 - time0:0.4f} seconds")
    return rate_files, target_names


def run_bkg_subtraction(uncal_path, target_name, list_bands=None, overwrite=False):
    """Subtract a median master background from the stage1 rate files of each band.

    In ``<uncal_path>/<target_name>/<band>/stage1``, rate files whose OBSLABEL
    contains 'BACKGROUND' or 'BKG' are median-combined (NaN-ignoring) into a
    master background. It is displayed, saved to ``master_bkg/``, and subtracted
    from the first data extension of every other rate file. Results go to
    ``stage1_sub_bkg/`` with primary keyword BKG_SUB='CUSTOM'. Bands without any
    background exposure are skipped, instead of writing all-NaN products.
    """
    if list_bands is None:
        list_bands = list(_MIRI_ALL_BANDS)

    for band in list_bands:
        output_dir = mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'stage1_sub_bkg'))
        background_outputdir = mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'master_bkg'))

        rate_files_all = find_files_to_process(os.path.join(uncal_path, target_name, band, 'stage1'),
                                               filetype='rate.fits')
        bkg_files, rate_files = [], []
        for f in rate_files_all:
            obslabel = fits.getheader(f)['OBSLABEL']  # read once per file
            (bkg_files if ('BACKGROUND' in obslabel or 'BKG' in obslabel) else rate_files).append(f)

        if not bkg_files:
            print(f"No background exposures found for band {band}; skipping background subtraction.")
            continue

        bkg_stack = np.stack([np.asarray(fits.getdata(bkg_file), dtype=float) for bkg_file in bkg_files])
        bkg_master = np.nanmedian(bkg_stack, axis=0)
        plt.imshow(bkg_master, origin='lower')
        plt.show()
        fits.writeto(os.path.join(background_outputdir, f"background_master_{band}.fits"), bkg_master,
                     overwrite=overwrite)

        for rate_file in rate_files:
            out_name = os.path.join(output_dir, os.path.basename(rate_file))
            rate = fits.getdata(rate_file)
            with fits.open(rate_file, mode="readonly") as hdu:
                hdu_copy = fits.HDUList([hd.copy() for hd in hdu])
            hdu_copy[1].data = rate - bkg_master
            hdu_copy[0].header['BKG_SUB'] = 'CUSTOM'
            hdu_copy.writeto(out_name, overwrite=overwrite)


def flat_fringing_stage1(uncal_path, target_name, list_bands=None, flat_path=None, flat_extended=False, bkg_sub=False,
                         overwrite=False):
    """Divide stage1 rate files by the best-matching custom fringe flat.

    Inputs come from ``stage1_sub_bkg`` (if *bkg_sub*) or ``stage1``; outputs go to
    ``stage1_flat`` in the same band directory. SCI and ERR are divided by the
    flat (NaNs in the flat replaced by 1). The primary header records S_FLAT and
    FLAT_STD_MIN.

    Parameters
    ----------
    flat_path : str or None
        Root of the flat library, with sub-directories 12A..34C. When None, the
        FLAT_PATH environment variable is used.

    Raises
    ------
    ValueError
        If no flat path is available, or a file's BAND is not SHORT/MEDIUM/LONG.
    """
    if list_bands is None:
        list_bands = list(_MIRI_ALL_BANDS)
    for band in list_bands:
        mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'stage1_flat'))

    # Start a timer to keep track of runtime
    time0 = time.perf_counter()
    print(time0)

    for band in list_bands:
        input_subdir = 'stage1_sub_bkg' if bkg_sub else 'stage1'
        rate_files = find_files_to_process(os.path.join(uncal_path, target_name, band, input_subdir),
                                           filetype='rate.fits')
        output_dir = os.path.join(uncal_path, target_name, band, 'stage1_flat')

        for fid, rate_file in enumerate(rate_files):
            print(fid, rate_file)
            out_name = os.path.join(output_dir, os.path.basename(rate_file))
            if os.path.exists(out_name) and not overwrite:
                print(f"Output file {out_name} already exists;\n\tskipping {rate_file}.")
                continue

            hdr = fits.getheader(rate_file)
            detector = hdr['DETECTOR']
            file_band = hdr['BAND']  # distinct name: don't shadow the loop's band

            flat_path_rate = os.getenv("FLAT_PATH") if flat_path is None else flat_path
            if flat_path_rate is None:  # previously tested output_dir, which is never None
                raise ValueError("No FLAT_PATH specified to apply the fringe flat")
            print("Searching fringes flat files in:", flat_path_rate)

            if file_band not in _MIRI_SUBBAND_LETTERS:
                raise ValueError(f'Unsupported band for file: {rate_file} must be either SHORT, MEDIUM or LONG')
            if detector == 'MIRIFUSHORT':
                channel_prefix, channel = '12', 'CH2'
            else:
                channel_prefix, channel = '34', 'CH3'
            flat_path_rate = os.path.join(flat_path_rate, channel_prefix + _MIRI_SUBBAND_LETTERS[file_band])

            best_flat, flat_name, std_min = best_flat_selection(rate_file, flat_path_rate, channel,
                                                                flat_extended=flat_extended)
            best_flat[np.isnan(best_flat)] = 1

            rate_file_data = fits.getdata(rate_file)
            with fits.open(rate_file, mode="readonly") as hdu:
                hdu_copy = fits.HDUList([hd.copy() for hd in hdu])
            hdu_copy[1].data = rate_file_data / best_flat
            hdu_copy['ERR'].data /= best_flat
            hdu_copy[0].header['S_FLAT'] = flat_name
            hdu_copy[0].header['FLAT_STD_MIN'] = std_min
            hdu_copy.writeto(out_name, overwrite=overwrite)
            print(f"==> Wrote fringe-corrected file to {out_name}")


def run_stage2_miri(uncal_path, target_name, list_bands=None, custom_flatted=True, custom_bkg_sub=False,
                    skip_cubes=True, skip_fringe=False, skip_residual_fringes=False, skip_flatfield=False,
                    skip_straylight=True, overwrite=False):
    """Run Spec2Pipeline on each band's rate files, writing cal files to ``<band>/stage2``.

    Input rate files come from ``stage1_flat`` if *custom_flatted*, else from
    ``stage1_sub_bkg`` if *custom_bkg_sub*, else from ``stage1``.

    Returns
    -------
    list of str
        Expected cal-file paths (existing, skipped ones included).
    """
    if list_bands is None:
        list_bands = list(_MIRI_ALL_BANDS)
    for band in list_bands:
        mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'stage2'))

    time0 = time.perf_counter()
    print(time0)
    cal_files = []
    for band in list_bands:
        if custom_flatted:
            rate_file_band_path = os.path.join(uncal_path, target_name, band, 'stage1_flat')
            print(f"Processing the custom flatted rate files in {rate_file_band_path} for stage 2.")
        elif custom_bkg_sub:
            rate_file_band_path = os.path.join(uncal_path, target_name, band, 'stage1_sub_bkg')
            print(f"Processing the background subtracted rate files in {rate_file_band_path} for stage 2.")
        else:
            rate_file_band_path = os.path.join(uncal_path, target_name, band, 'stage1')
            print(f"Processing the rate files in {rate_file_band_path} for stage 2.")

        outputdir = os.path.join(uncal_path, target_name, band, 'stage2')
        rate_files = find_files_to_process(rate_file_band_path, filetype='rate.fits')
        for fid, rate_file in enumerate(rate_files):
            print(fid, rate_file)
            out_name = os.path.join(outputdir, os.path.basename(rate_file).replace('rate.fits', 'cal.fits'))
            cal_files.append(out_name)
            if os.path.exists(out_name) and not overwrite:
                print(f"Output file {out_name} already exists;\n\tskipping {rate_file}.")
                continue

            step_parameters = {
                'straylight': {'skip': skip_straylight},
                'flat_field': {'skip': skip_flatfield},
                'fringe': {'skip': skip_fringe},
                'residual_fringe': {'skip': skip_residual_fringes},
                'cube_build': {'skip': skip_cubes},  # interpolated cubes are not needed by default
                'extract_1d': {'skip': True},
            }
            # ``call`` builds a fresh pipeline instance, so save_bsub must be passed
            # here; setting it on a separate instance (as before) had no effect.
            Spec2Pipeline.call(rate_file, save_results=True, output_dir=outputdir, steps=step_parameters,
                               save_bsub=True)

            # Print out the time benchmark
            time1 = time.perf_counter()
            print(f"Runtime so far: {time1 - time0:0.4f} seconds")

    time1 = time.perf_counter()
    print(f"Total Runtime: {time1 - time0:0.4f} seconds")
    return cal_files


def run_stage3_miri(uncal_path, target_name, list_bands=None, overwrite=False):
    """Build level-3 association files and run Spec3Pipeline for each band.

    Cal files are read from ``<band>/stage2``. Products go to ``<band>/stage3``.
    Association JSON files are written to the current working directory.
    A band is skipped if its ``Level3_ch<N>-<short|medium|long>_s3d.fits`` cube
    already exists (unless *overwrite*), or if it has no cal files.
    """
    if list_bands is None:
        list_bands = list(_MIRI_ALL_BANDS)
    subband_names = {'A': 'short', 'B': 'medium', 'C': 'long'}

    for band in list_bands:
        outputdir = mkdir_miri_files(os.path.join(uncal_path, target_name, band, 'stage3'))
        cube_name = f'Level3_ch{band[0]}-{subband_names[band[-1]]}_s3d.fits'
        if os.path.exists(os.path.join(outputdir, cube_name)) and not overwrite:
            print(f"Output file {cube_name} already exists;\n\tskipping.")
            continue

        inputdir = os.path.join(uncal_path, target_name, band, 'stage2')
        # Channels 1-2 are on MIRIFUSHORT, channels 3-4 on MIRIFULONG.
        detector_tag = 'mirifushort' if band.startswith('12') else 'mirifulong'
        calfiles = np.array(sorted(find_files_to_process(inputdir, filetype=f'{detector_tag}_cal.fits')))
        print(calfiles)
        print('Found ' + str(len(calfiles)) + ' input files to process for stage 3')
        if len(calfiles) == 0:
            print(f"No cal files for band {band} in {inputdir}; skipping stage 3.")
            continue

        # Split them up into bands
        sortfiles = sort_calfiles(calfiles)
        asnlist = []
        for band_name, thesefiles in zip(_MIRI_ALL_BANDS, sortfiles):
            if len(thesefiles) > 0:
                filename = 'l3asn-' + band_name + '.json'
                asnlist.append(filename)
                writel3asn(thesefiles, filename, 'Level3')
        print("asnlist", asnlist)
        if not asnlist:
            print(f"No association could be built for band {band}; skipping stage 3.")
            continue
        runspec3(asnlist[0], outputdir)


def run_full_miri_default_pipeline(uncal_path, target_name, list_bands=None, overwrite=False):
    """Run stages 1-3 with mostly default pipeline settings; returns 1 on completion."""
    run_stage1_miri(uncal_path, list_bands=list_bands, overwrite=overwrite, maximum_cores="1", skip_dark=False)
    run_stage2_miri(uncal_path, target_name, list_bands=list_bands, custom_flatted=False, skip_cubes=False,
                    skip_fringe=False, skip_residual_fringes=True, skip_flatfield=False, skip_straylight=False,
                    overwrite=overwrite)
    run_stage3_miri(uncal_path, target_name, list_bands=list_bands, overwrite=overwrite)
    return 1


def writel3asn(files, asnfile, prodname, **kwargs):
    """Write a level-3 association JSON file for *files* with product name *prodname*.

    An optional ``bg=[...]`` keyword adds those files as 'background' members.
    """
    # Define the basic association of science files
    asn = afl.asn_from_list(files, rule=DMS_Level3_Base, product_name=prodname)

    # Add any background files to the association
    if 'bg' in kwargs:
        print("bg in kwargs")
        for bgfile in kwargs['bg']:
            asn['products'][0]['members'].append({'expname': bgfile, 'exptype': 'background'})

    # Write the association to a json file
    _, serialized = asn.dump()
    with open(asnfile, 'w') as outfile:
        outfile.write(serialized)


def sort_calfiles(files):
    """Split cal files by primary-header CHANNEL/BAND.

    Returns
    -------
    tuple of 6 ndarray
        Files for 12A, 12B, 12C, 34A, 34B, 34C (SHORT/MEDIUM/LONG), in that order.
    """
    files = np.asarray(files)  # also accept plain lists
    headers = [fits.getheader(f) for f in files]  # getheader closes each file
    channel = np.array([h['CHANNEL'] for h in headers])
    band = np.array([h['BAND'] for h in headers])
    combos = [('12', 'SHORT'), ('12', 'MEDIUM'), ('12', 'LONG'),
              ('34', 'SHORT'), ('34', 'MEDIUM'), ('34', 'LONG')]
    return tuple(files[np.where((channel == ch) & (band == bd))] for ch, bd in combos)


def runspec3(filename, outputdir):
    """Run Spec3Pipeline on association *filename*, writing products to *outputdir*.

    Master background is skipped; outlier_detection, mrs_imatch, cube_build
    and extract_1d are run.
    """
    # Build the config from the parameter reference files so the latest ones are used.
    # The exact association does not matter, so use the one being processed
    # (a hard-coded 'l3asn-12A.json' does not exist when only channels 3/4 are reduced).
    crds_config = Spec3Pipeline.get_config_from_reference(filename)
    spec3 = Spec3Pipeline.from_config_section(crds_config)
    spec3.output_dir = outputdir
    spec3.save_results = True
    spec3.master_background.skip = True  # master background subtraction
    spec3.outlier_detection.skip = False  # flag outliers in overlapping regions of cube space
    spec3.mrs_imatch.skip = False  # match backgrounds between exposures
    spec3.cube_build.skip = False  # build the composite data cubes
    spec3.extract_1d.skip = False  # extract 1d spectra from the cubes
    spec3.process(filename)


def best_flat_selection(cal_file, flat_dir, channel, flat_extended=False, save_png=True, full_output=False):
    """For MIRI, pick the fringe flat in *flat_dir* that best matches a science file.

    Only ``*.fits`` flats whose primary-header PATT_NUM, PATTTYPE, DITHDIRC and
    BAND equal the science file's are considered. For each candidate, the
    science column with the largest median (see ``column_median_max_channel``)
    is divided by the same flat column and high-pass filtered (minus a sigma=8
    Gaussian smoothing). The flat with the smallest NaN-ignoring std of that
    residual is selected.

    Parameters
    ----------
    cal_file : str
        Science file; data are read from extension 1.
    flat_dir : str
        Directory of candidate flats.
    channel : str
        'CH1'..'CH4', selects which half of the detector is searched.
    flat_extended : bool
        Use the FLAT_EXTENDED extension instead of FLAT.
    save_png : bool
        Save ``./fig_fringes_<FILENAME>.png`` showing the data fringes and candidate flats.
    full_output : bool
        Also return all candidate names and their stds.

    Returns
    -------
    best_flat : ndarray
    flat_name : str
    std_min : float
    file, std : list
        Only when *full_output*: candidate filenames and their residual stds.

    Raises
    ------
    ValueError
        If no flat in *flat_dir* matches the science file's dither/band keywords.
    """
    flat_ext = 'FLAT_EXTENDED' if flat_extended else 'FLAT'

    with fits.open(cal_file) as hdu:
        data = np.array(hdu[1].data)
        hdr = hdu[0].header.copy()
    pattern_type = hdr['PATTTYPE']
    dither_direction = hdr['DITHDIRC']
    dither_numero = hdr['PATT_NUM']
    band = hdr['BAND']
    print(f"Band: {band}")

    std = []
    file = []
    brightest_col = column_median_max_channel(data, channel=channel)
    print(f"Brightest column for Channel {channel}: {brightest_col}")

    if save_png:
        xlim = [450, 500]
        plt.title(f"Best fringes flat pattern selection\nFor channel {channel} {band}, using brighest column: {brightest_col} ")
        plt.xlabel("Row index")
        plt.ylabel("Fringes transmission")
        plt.xlim(*xlim)
        plt.ylim([0.5, 1.5])
        file_name = hdr['FILENAME']
        col_data = data[:, brightest_col]
        # Fill NaNs from their neighbours; an all-NaN column would never converge.
        if np.any(np.isfinite(col_data)):
            # ``import astropy`` alone does not load astropy.convolution.
            from astropy.convolution import interpolate_replace_nans
            while np.any(np.isnan(col_data)):
                col_data = interpolate_replace_nans(col_data, kernel=np.ones(3))
        continuum = gaussian_filter(col_data, sigma=8)
        fringe_data = col_data / continuum
        plt.plot(fringe_data, label='data')

    for filename in sorted(os.listdir(flat_dir)):
        if not filename.endswith(".fits"):
            continue
        with fits.open(os.path.join(flat_dir, filename)) as flat_hdu:
            hdr_flat = flat_hdu[0].header
            if not (hdr_flat['PATT_NUM'] == dither_numero and hdr_flat['PATTTYPE'] == pattern_type
                    and hdr_flat['DITHDIRC'] == dither_direction and hdr_flat['BAND'] == band):
                continue
            flat_col = np.array(flat_hdu[flat_ext].data[:, brightest_col])

        d_f = data[:, brightest_col] / flat_col
        d_f_hf = d_f - gaussian_filter(d_f, sigma=8)
        fringe_residuals_std = np.nanstd(d_f_hf)
        std.append(fringe_residuals_std)
        file.append(filename)
        print(filename, fringe_residuals_std)
        if save_png and fringe_residuals_std < 40:
            plt.plot(flat_col, label=filename)

    if not std:
        raise ValueError(f"No fringe flat in {flat_dir} matches PATT_NUM={dither_numero}, PATTTYPE={pattern_type}, "
                         f"DITHDIRC={dither_direction}, BAND={band}")
    idx = np.nanargmin(std)
    std_min = np.nanmin(std)
    flat_name = file[idx]
    print("Flat selected:", flat_name)

    if save_png:
        plt.text(0.05, 0.05, "Flat selected: " + flat_name, transform=plt.gca().transAxes)
        plt.legend()
        plt.tight_layout()
        out_name = f"./fig_fringes_{os.path.splitext(os.path.basename(file_name))[0]}.png"
        plt.savefig(out_name)
        print(f"==> Plot saved to {out_name}")
        plt.close()

    with fits.open(os.path.join(flat_dir, flat_name)) as hdu_best_flat:
        # Copy so callers may modify it freely after the file is closed.
        best_flat = np.array(hdu_best_flat[flat_ext].data)

    if full_output:
        return best_flat, flat_name, std_min, file, std
    return best_flat, flat_name, std_min


def column_median_max(mat):
    """Index of the column of *mat* with the largest NaN-ignoring median."""
    medianes = np.nanmedian(mat, axis=0)
    return np.nanargmax(medianes)


def column_median_max_channel(data, channel='CH1'):
    """Brightest column (largest median) within the detector half used by *channel*.

    CH1/CH4 search columns [0, 500); CH2/CH3 search columns [500, end).

    Raises
    ------
    ValueError
        For any other channel name.
    """
    if channel in ('CH1', 'CH4'):
        return column_median_max(data[:, :500])
    if channel in ('CH2', 'CH3'):
        return column_median_max(data[:, 500:]) + 500
    raise ValueError('Channel must be CH1, CH2, CH3 or CH4')


## Breads function functions for miri

def _miri_base_preproc_tasks(target_name):
    """Common first MIRI preprocessing tasks: bad-pixel filter, coordinates, unit conversion, webbpsf model."""
    tasks = [["compute_med_filt_badpix", {"window_size": 10, "mad_threshold": 20}, True, True]]
    if target_name is not None:
        tasks.append(["compute_coordinates_arrays", {"targname": target_name}])
    else:
        tasks.append(["compute_coordinates_arrays"])
    tasks.append(["convert_MJy_per_sr_to_MJy"])
    tasks.append(["compute_quick_webbpsf_model"])
    return tasks


def _observation_number(filename):
    """Characters [-6:-3] of the first '_'-separated token of the basename (the file's observation number)."""
    return os.path.basename(filename).split('_')[0][-6:-3]


def _unit_star_func():
    """Interpolator equal to 1 everywhere, used when no stellar high-frequency model is wanted.

    The previous ``np.arange(0, 30, 100)`` grid held a single point, which
    scipy's linear ``interp1d`` rejects. 100 nodes over 0-30 keep the value 1
    inside the grid, and ``fill_value=1`` keeps it 1 outside.
    """
    nodes = np.linspace(0, 30, 100)
    return interp1d(nodes, np.ones_like(nodes), kind="linear", bounds_error=False, fill_value=1)


def compute_coordinates_offset(path_cal_files, channel, utils_dir, target_name=None, IWA=None, OWA=None):
    """Fit a WebbPSF model per observation and return the resulting coordinate offsets.

    Returns
    -------
    list of [observation_number, ra_offset, dec_offset]
        One entry per observation number found in the cal files, in first-seen order.
    """
    from breads.instruments.jwstmiri_cal import JWSTMiri_cal
    from breads.instruments.jwstmiri_multiple_cals import JWSTMiri_multiple_cals

    if not os.path.exists(utils_dir):
        os.makedirs(utils_dir, exist_ok=True)

    cal_files = find_files_to_process(path_cal_files, filetype="cal.fits")
    # Unique observation numbers, preserving first-seen order.
    observation_numbers = list(dict.fromkeys(_observation_number(f) for f in cal_files))
    print('Observation numbers:', observation_numbers)

    coords_offset = []
    for observation_number in observation_numbers:
        dataobj_list = []
        for filename in cal_files:
            obs_number_file = _observation_number(filename)
            print(filename, 'Observation number:', obs_number_file)
            if obs_number_file != observation_number:
                continue
            print('yes', filename)
            dataobj = JWSTMiri_cal(filename, channel_reduction=channel, utils_dir=utils_dir, save_utils=True,
                                   load_utils=True, preproc_task_list=_miri_base_preproc_tasks(target_name))
            dataobj_list.append(dataobj)

        dataobj_combined = JWSTMiri_multiple_cals(dataobj_list)
        ra_offset, dec_offset = dataobj_combined.compute_new_coords_from_webbPSFfit(IWA=IWA, OWA=OWA)
        print(observation_number, ra_offset, dec_offset)
        coords_offset.append([observation_number, ra_offset, dec_offset])
    return coords_offset


def compute_normalized_stellar_spectrum_miri(cal_files, channel, utils_dir, coords_offset=(0, 0), wv_nodes=None,
                                             target_name=None, star_hf_subtraction=True, mppool=None,
                                             ra_dec_point_sources=None, overwrite=False):
    """MIRI version of the combined continuum-normalized stellar spectrum.

    If the cached combined spectrum exists and *overwrite* is False, it is
    reloaded. When *star_hf_subtraction* is False, the freshly computed result
    is replaced by a function equal to 1 everywhere.

    Returns
    -------
    scipy.interpolate.interp1d
    """
    from breads.instruments.jwstmiri_cal import JWSTMiri_cal

    if not os.path.exists(utils_dir):
        os.makedirs(utils_dir, exist_ok=True)
    with fits.open(cal_files[0]) as hdulist_sc:
        detector = hdulist_sc[0].header["DETECTOR"].strip().lower()

    splitbasename = os.path.basename(cal_files[0]).split("_")
    combined_contnorm_spec_filename = os.path.join(
        utils_dir,
        splitbasename[0] + "_" + splitbasename[1] + "_" + detector + "_starspec_contnorm_combined_1dspline.fits")

    # os.path.exists rather than glob, so paths with glob metacharacters still match.
    if not overwrite and os.path.exists(combined_contnorm_spec_filename):
        with fits.open(combined_contnorm_spec_filename) as hdulist:
            return interp1d(hdulist[0].data, hdulist[1].data, kind="linear", bounds_error=False, fill_value=1)

    dataobj_list = []
    for filename in cal_files:
        print(filename)
        preproc_task_list = _miri_base_preproc_tasks(target_name)
        preproc_task_list.append(["apply_coords_offset", {"coords_offset": coords_offset}])
        preproc_task_list.append(["compute_starspectrum_contnorm",
                                  {"x_nodes": wv_nodes, "threshold_badpix": 100, "mppool": mppool,
                                   "star_hf_subtraction": star_hf_subtraction}, True, True])
        preproc_task_list.append(["compute_starsubtraction",
                                  {"starsub_dir": "starsub1d", "threshold_badpix": 10, "mppool": mppool}, True, True])

        dataobj = JWSTMiri_cal(filename, channel_reduction=channel, utils_dir=utils_dir, save_utils=True,
                               load_utils=True, preproc_task_list=preproc_task_list)

        # Mask planets before the combined star spectrum is computed below.
        if ra_dec_point_sources is not None:
            for ra_pl, dec_pl in ra_dec_point_sources:
                where_pl = dataobj.where_point_source([ra_pl / 1000., dec_pl / 1000.], 0.16)
                dataobj.bad_pixels[where_pl] = np.nan

        dataobj_list.append(dataobj)

    new_wavelengths, combined_fluxes, _combined_errors = get_contnorm_spec(
        dataobj_list, spline2d=False, load_utils=False, out_filename=combined_contnorm_spec_filename,
        spec_R_sampling=2700 * 4, interpolation="linear")
    if star_hf_subtraction:
        return interp1d(new_wavelengths, combined_fluxes, kind="linear", bounds_error=False, fill_value=1)
    return _unit_star_func()


def compute_starlight_subtraction_miri(cal_files, channel, utils_dir, wv_nodes=None, target_name=None,
                                       combined_star_func=None, star_hf_subtraction=True, coords_offset=(0, 0),
                                       mppool=None):
    """Load or compute the 1D starlight subtraction for each MIRI cal file.

    If *combined_star_func* is given, it becomes each object's ``star_func``.
    When *star_hf_subtraction* is False, a function equal to 1 everywhere is
    used instead. Returns the list of data objects, in input order.
    """
    from breads.instruments.jwstmiri_cal import JWSTMiri_cal

    dataobj_list = []
    for filename in cal_files:
        print(filename)
        preproc_task_list = [
            ["compute_med_filt_badpix", {"window_size": 50, "mad_threshold": 50}, True, True],
            ["compute_coordinates_arrays", {"targname": target_name}, True, True],
            ["convert_MJy_per_sr_to_MJy"],
            ["apply_coords_offset", {"coords_offset": coords_offset}],
        ]
        if combined_star_func is None:
            preproc_task_list.append(["compute_starspectrum_contnorm",
                                      {"x_nodes": wv_nodes, "threshold_badpix": 100, "mppool": mppool,
                                       "star_hf_subtraction": star_hf_subtraction}, True, True])

        dataobj = JWSTMiri_cal(filename, channel_reduction=channel, utils_dir=utils_dir, save_utils=True,
                               load_utils=True, preproc_task_list=preproc_task_list)
        if combined_star_func is not None:
            if star_hf_subtraction:
                dataobj.reload_starspectrum_contnorm()
            else:
                combined_star_func = _unit_star_func()
            dataobj.star_func = combined_star_func

        outputs = dataobj.reload_starsubtraction()
        if outputs is None:
            outputs = dataobj.compute_starsubtraction(save_utils=True, starsub_dir="starsub1d", threshold_badpix=10,
                                                      mppool=mppool)
        dataobj_list.append(dataobj)
    return dataobj_list


def get_combined_regwvs_miri(dataobj_list, channel, wv_sampling=None, use_starsub1d=False, reload=False):
    """Interpolate MIRI data objects onto a regular wavelength grid and combine them.

    With *use_starsub1d*, data are replaced by the star-subtracted files in
    ``<utils_dir>/starsub1d/``. With *reload*, previously saved products are
    reused when available.

    Raises
    ------
    ValueError
        If a data object's unit is neither 'MJy' nor 'MJy/sr' while *use_starsub1d* is set.
    """
    from breads.instruments.jwstmiri_cal import JWSTMiri_cal
    from breads.instruments.jwstmiri_multiple_cals import JWSTMiri_multiple_cals

    regwvs_dataobj_list = []
    for dataobj in dataobj_list:
        if use_starsub1d:
            starsub_filename = os.path.join(dataobj.utils_dir, "starsub1d", os.path.basename(dataobj.filename))
            print("starsub1d path for combined regwvs miri:", starsub_filename)
            starsub_dataobj = JWSTMiri_cal(starsub_filename, channel_reduction=channel,
                                           utils_dir=dataobj.utils_dir)
            if dataobj.data_unit == 'MJy':
                replace_data = dataobj.convert_MJy_per_sr_to_MJy(data_in_MJy_per_sr=starsub_dataobj.data)
            elif dataobj.data_unit == "MJy/sr":
                replace_data = starsub_dataobj.data
            else:
                # Previously fell through, silently reusing the prior item's replace_data.
                raise ValueError(f"Unsupported data unit {dataobj.data_unit!r} for {dataobj.filename}; "
                                 f"expected 'MJy' or 'MJy/sr'.")
            regwvs_filename = dataobj.default_filenames["compute_interpdata_regwvs"].replace(
                "_regwvs.fits", "_starsub1d_regwvs.fits")
        else:
            replace_data = None
            regwvs_filename = dataobj.default_filenames["compute_interpdata_regwvs"]
        print("regwvs path for combined regwvs miri:", regwvs_filename)

        regwvs_dataobj = dataobj.reload_interpdata_regwvs(load_filename=regwvs_filename) if reload else None
        if regwvs_dataobj is None:
            # wv_sampling may be None (the default), which has no .shape.
            print("[DEBUG] get combined regwvs miri checking wv_sampling",
                  None if wv_sampling is None else np.shape(wv_sampling))
            regwvs_dataobj = dataobj.compute_interpdata_regwvs(save_utils=regwvs_filename, wv_sampling=wv_sampling,
                                                               replace_data=replace_data)
        regwvs_dataobj_list.append(regwvs_dataobj)

    return JWSTMiri_multiple_cals(regwvs_dataobj_list)