﻿#!/usr/bin/env python

import os, sys
import numpy as np

from pyraf import iraf
from iraf import (
    images,  # immatch, imutil
    immatch, # imcombine
    imutil,  # imarith, imstatistics, imreplace, imcopy
)
from astropy.io import fits

# fiber flux template 
#import lib_fiber_flux

### set constants
# Total number of fibers; fiber_pos_x and fiber_pos_y below each hold
# n_fiber entries, indexed by (fiber ID - 1).
n_fiber = 110

# Fiber layout on a 21-row (y) x 22-column (x) grid:
# fiber_pos_2d[pos_y - 1][pos_x - 1] is the fiber ID at the 1-based grid
# position (pos_x, pos_y), and 0 where no fiber sits (see fiber_pos_to_id).
# The last row and the last column contain no fibers.  fiber_make_image()
# takes its output image size from this grid and spreads each fiber over
# 2x2 pixels with np.roll, so this empty margin means only zeros are
# wrapped around to the opposite edge.
# Must be kept consistent with fiber_pos_x / fiber_pos_y.
fiber_pos_2d = ( #from 2022B without sky fibres, edit by K.Isogai on 3 Jul 2022
     (26,  0,   0,  0,  40,  0,   0,  0,  73,  0,   0,  0,  31,  0,   0,  0,  25,  0,   0,  0,  101,  0),
     (0,   0,  81,  0,   0,  0,  45,  0,   0,  0,  42,  0,   0,  0,  22,  0,   0,  0,  10,  0,    0,  0),
     (3,   0,   0,  0,  49,  0,   0,  0,  61,  0,   0,  0,  54,  0,   0,  0,  17,  0,   0,  0,  100,  0),
     (0,   0, 96,  0,   0,  0,  91,  0,   0,  0,  53,  0,   0,  0,  37,  0,   0,  0, 105,  0,    0,  0),
     (6,   0,   0,  0,  69,  0,   0,  0, 98,  0,   0,  0,  76,  0,   0,  0,   8,  0,   0,  0,  104,  0),
     (0,   0,  20,  0,   0,  0, 107,  0,   0,  0,  70,  0,   0,  0,  19,  0,   0,  0,  12,  0,    0,  0),
     (109, 0,   0,  0,  90,  0,   0,  0,  57,  0,   0,  0,  63,  0,   0,  0,  39,  0,   0,  0,  110,  0),
     (0,   0,  18,  0,   0,  0,  52,  0,   0,  0,  75,  0,   0,  0,  41,  0,   0,  0,  32,  0,    0,  0),
     (14,  0,   0,  0,  30,  0,   0,  0,  99,  0,   0,  0,   4,  0,   0,  0,  34,  0,   0,  0,    7,  0),
     (0,   0,  47,  0,   0,  0,  59,  0,   0,  0, 108,  0,   0,  0,   1,  0,   0,  0,  38,  0,    0,  0),
     (83,  0,   0,  0,  84,  0,   0,  0,  64,  0,   0,  0,  67,  0,   0,  0,  28,  0,   0,  0,    5,  0),
     (0,   0,  27,  0,   0,  0,  72,  0,   0,  0,  74,  0,   0,  0,  56,  0,   0,  0,  35,  0,    0,  0),
     (95,  0,   0,  0,  87,  0,   0,  0,  51,  0,   0,  0,  71,  0,   0,  0,  13,  0,   0,  0,   33,  0),
     (0,   0,  65,  0,   0,  0,  85,  0,   0,  0, 102,  0,   0,  0,  46,  0,   0,  0,  15,  0,    0,  0),
     (106, 0,   0,  0,  82,  0,   0,  0,  44,  0,   0,  0,  55,  0,   0,  0,  21,  0,   0,  0,   11,  0),
     (0,   0,  60,  0,   0,  0,  97,  0,   0,  0,  68,  0,   0,  0,  50,  0,   0,  0,  36,  0,    0,  0),
     (94,  0,   0,  0,  29,  0,   0,  0,  58,  0,   0,  0,  77,  0,   0,  0,  43,  0,   0,  0,    9,  0),
     (0,   0,  88,  0,   0,  0, 103,  0,   0,  0,  78,  0,   0,  0,  66,  0,   0,  0,  24,  0,    0,  0),
     (86,  0,   0,  0,  48,  0,   0,  0,  89,  0,   0,  0,  62,  0,   0,  0,  80,  0,   0,  0,   79,  0),
     (0,   0,  16,  0,   0,  0,  93,  0,   0,  0,  92,  0,   0,  0,  23,  0,   0,  0,   2,  0,    0,  0),
     (0,   0,   0,  0,   0,  0,   0,  0,   0,  0,   0,  0,   0,  0,   0,  0,   0,  0,   0,  0,    0,  0))

# 1-based grid x position of each fiber: fiber_pos_x[ID - 1] for fiber ID.
# The trailing comment on each line is the index of its first entry.
fiber_pos_x = ( #from 2022B without sky fibres, edit by K.Isogai on 3 Jul 2022
  15, 19,  1, 13, 21,   1, 21, 17, 21, 19, #  0
  21, 19, 17,  1, 19,   3, 17,  3, 15,  3, # 10
  17, 15, 15, 19, 17,   1,  3, 17,  5,  5, # 20
  13, 19, 21, 17, 19,  19, 15, 19, 17,  5, # 30
  15, 11, 17,  9,  7,  15,  3,  5,  5, 15, # 40
   9,  7, 11, 13, 13,  15,  9,  9,  7,  3, # 50
   9, 13, 13,  9,  3,  15, 13, 11,  5, 11, # 60
  13,  7,  9, 11, 11,  13, 13, 11, 21, 17, # 70
   3,  5,  1,  5,  7,  1,  5,  3,   9,  5, # 80
   7, 11,  7,  1,  1,  3,  7,  9,   9, 21, # 90
   21, 11, 7, 21, 19,  1,  7, 11,   1, 21 ) #100

# 1-based grid y position of each fiber: fiber_pos_y[ID - 1] for fiber ID.
# The trailing comment on each line is the index of its first entry.
fiber_pos_y = ( #from 2022B without sky fibres, edit by K.Isogai on 3 Jul 2022
  10, 20,   3,  9, 11,  5,  9,  5, 17,  2, #0
  15,  6,  13,  9, 14, 20,  3,  8,  6,  6, #10
  15,  2,  20, 18,  1,  1, 12, 11, 17,  9, #20
   1,  8,  13,  9, 12, 16,  4, 10,  7,  1, #30
   8,  2,  17, 15,  2, 14, 10, 19,  3, 16, #40
  13,  8,   4,  3, 15, 12,  7, 17, 10, 16, #50
   3, 19,   7, 11, 14, 18, 11, 16,  5,  6, #60
  13, 12,   1, 12,  8,  5, 17, 18, 19, 19, #70
   2, 15,  11, 11, 14, 19, 13, 18, 19,  7, #80
   4, 20,  20, 17, 13,  4, 16,  5,  9,  3, #90
   1, 14,  18,  5,  4, 15,  6, 10,  7,  7 ) #100

### Common functions
def input_file_list(input_file):
  """Expand a file argument into a list of file names.

  If input_file starts with '@', the rest of the string is the path of a
  text file listing one file name per line.  Surrounding whitespace is
  stripped from each name and blank lines are skipped.  Otherwise the
  argument itself is returned as a one-element list.

  Returns None (after printing a message) when the resulting list is empty.
  """
  if input_file.startswith('@'):
    with open(input_file.lstrip('@'), 'r') as fin:
      # blank lines (e.g. trailing empty lines in a list file) would
      # otherwise be passed on to IRAF as empty file names
      file_list = [line.strip() for line in fin if line.strip()]
  else:
    file_list = [input_file]

  if not file_list:
    print('Input argument is NULL.')
    return None

  return file_list

### Subtract and cut overscan regions
def sub_cut_overscan(argvs):
  """Subtract overscan counts and then cut away the overscan regions.

  argvs = ["input", "output"]; either element may be an "@list" file naming
  one FITS file per line (see input_file_list).  Each input is processed by
  subtract_overscan() into a temporary file, which cut_overscan() then trims
  into the corresponding output.  Processing stops at the first unreadable
  input.  The temporary file is removed in every case.
  """
# show usage
  if (len(argvs) != 2):
    print('Usage: kools_ifu_red3.sub_cut_overscan(["input","output"])')
    print('  In case of file lists, put "@" at the head of arguments.')
    return

# set input files
  # each argument is expanded on its own, so a plain file name and an
  # "@list" may be mixed without one of them being misread
  input_list = input_file_list(argvs[0])
  output_list = input_file_list(argvs[1])
  if input_list is None or output_list is None:
    return
  if len(input_list) != len(output_list):
    print('Warning: %d input(s) but %d output(s); extra entries are ignored.'
          % (len(input_list), len(output_list)))

# set constants
  tmp_fits = 'tmp_overscan.fits'

# subtract ovserscan count and cut overscan regions
  try:
    for fits_input, fits_output in zip(input_list, output_list):

      if not os.access(fits_input, os.R_OK):
        print('Cannot open %s.' % fits_input)
        return

      subtract_overscan([fits_input, tmp_fits])
      cut_overscan([tmp_fits, fits_output])

  finally:
    # remove temporary files, also when stopping early on a bad input
    if os.path.exists(tmp_fits):
      os.remove(tmp_fits)

  return

### subtract overscan region count
def subtract_overscan(argvs):
  """Subtract the overscan (bias) level from each CCD channel.

  argvs = ["input", "output"]; either element may be an "@list" file naming
  one FITS file per line (see input_file_list).

  For each input frame:
    1. the mean of the box [523:536,21:120] is subtracted from the whole
       frame (written to a temporary file);
    2. for each of the four 536-column channels, the means of two boxes
       (rows 21:120 and 1521:1620) of the pre-subtracted frame are
       averaged, and that level is subtracted from the channel's columns.
  Processing stops at the first unreadable input.  Temporary files are
  removed in every case.
  """
# show usage
  if (len(argvs)!= 2):
    print('Usage: kools_ifu_red3.subtract_overscan(["input","output"])')
    print('  In case of file lists, put "@" for file names.')
    return

# set input files
  # each argument is expanded on its own, so a plain file name and an
  # "@list" may be mixed without one of them being misread
  input_list = input_file_list(argvs[0])
  output_list = input_file_list(argvs[1])
  if input_list is None or output_list is None:
    return
  if len(input_list) != len(output_list):
    print('Warning: %d input(s) but %d output(s); extra entries are ignored.'
          % (len(input_list), len(output_list)))

# set constants
  # measurement boxes: columns x1_imstat[i]:x2_imstat[i] (one per channel),
  # rows y1_imstat[j]:y2_imstat[j]
  x1_imstat = (523, 537, 1595, 1609)
  x2_imstat = (536, 550, 1608, 1622)
  y1_imstat = (21, 1521)
  y2_imstat = (120, 1620)
  y_min = 1
  y_max = 1640
  channel_width = 536

  fits_template = 'template_subtract_overscan.fits'
  tmp_fits = 'tmp_subtract.fits'

  iraf.imstat.field = 'mean'

# frame loop
  try:
    for fits_input, fits_output in zip(input_list, output_list):

      if not os.access(fits_input, os.R_OK):
        print('Cannot open %s.' % fits_input)
        return

      # pre-subtract bias count
      # NOTE(review): imstat's printed output is parsed assuming line 0 is
      # the column header and line 1 the mean -- confirm for other formats
      image = '%s[%d:%d,%d:%d]' % (fits_input, x1_imstat[0], x2_imstat[0], y1_imstat[0], y2_imstat[0])
      out_imstat = iraf.imstat(image, Stdout = 1)

      if os.path.exists(tmp_fits):
        os.remove(tmp_fits)
      iraf.imarith(fits_input, '-', float(out_imstat[1]), tmp_fits)

      # make a subtracting frame with the size of the input; each channel's
      # columns are then overwritten with that channel's level.
      # NOTE(review): only columns 1..4*536 are overwritten, so this assumes
      # the frame is no wider than that -- confirm
      if os.path.exists(fits_template):
        os.remove(fits_template)
      iraf.imarith(fits_input, '*', 0.333, fits_template)

      for count in range(len(x1_imstat)):
        tmp_sum = 0.0
        for count2 in range(len(y1_imstat)):
          # imstat at edge regions
          image = '%s[%d:%d,%d:%d]' % (tmp_fits, x1_imstat[count], x2_imstat[count], y1_imstat[count2], y2_imstat[count2])
          out_imstat = iraf.imstat(image, Stdout = 1)
          tmp_sum += float(out_imstat[1])

        # fill this channel of the overscan count image with the mean level
        image = '%s[%d:%d,%d:%d]' % (fits_template, 1 + channel_width * count, channel_width * (count + 1), y_min, y_max)
        iraf.imreplace(image, tmp_sum / len(y1_imstat))

      # subtract overscan counts
      if os.path.exists(fits_output):
        os.remove(fits_output)
      iraf.imarith(tmp_fits, '-', fits_template, fits_output)
      print('%s -> %s' % (fits_input, fits_output))

  finally:
    # remove temporary files, also when stopping early on a bad input
    for tmp_name in (tmp_fits, fits_template):
      if os.path.exists(tmp_name):
        os.remove(tmp_name)

  return

### Cut overscan regions
def cut_overscan(argvs):
  """Remove the overscan columns and join the four data sections.

  argvs = ["input", "output"]; either element may be an "@list" file naming
  one FITS file per line (see input_file_list).  Columns x1[k]:x2[k] (512
  columns each) of the input are pasted side by side into columns
  1+512*k:512*(k+1) of a 2048-column output frame.  Processing stops at the
  first unreadable input.
  """
# show usage
  if (len(argvs) != 2):
    print('Usage: kools_ifu_red3.cut_overscan(["input","output"])')
    print('  In case of file lists, put "@" for file names.')
    return

# set input files
  # each argument is expanded on its own, so a plain file name and an
  # "@list" may be mixed without one of them being misread
  input_list = input_file_list(argvs[0])
  output_list = input_file_list(argvs[1])
  if input_list is None or output_list is None:
    return
  if len(input_list) != len(output_list):
    print('Warning: %d input(s) but %d output(s); extra entries are ignored.'
          % (len(input_list), len(output_list)))

# set constants
  x1 = (10, 552, 1082, 1624)
  x2 = (521, 1063, 1593, 2135)
  section_width = 512

# frame loop
  for fits_input, fits_output in zip(input_list, output_list):

    if not os.access(fits_input, os.R_OK):
      print('Cannot open %s.' % fits_input)
      return

    if os.path.exists(fits_output):
      os.remove(fits_output)

    # create the output with its final size (first 2048 columns of the
    # input); all of its pixels are overwritten section by section below
    iraf.imcopy('%s[1:%d,*]' % (fits_input, section_width * len(x1)), fits_output)

    # paste cut images
    for count, (x_start, x_end) in enumerate(zip(x1, x2)):
      image_in = '%s[%d:%d,*]' % (fits_input, x_start, x_end)
      image_out = '%s[%d:%d,*]' % (fits_output, 1 + section_width * count, section_width * (count + 1))
      iraf.imcopy(image_in, image_out)

  return

### Adjust gain
def adjust_gain(argvs):
  """Equalize the gains of the four CCD channels using a flat frame.

  argvs = ["reference_fits", "input", "output"]; "input" and "output" may
  be "@list" files (see input_file_list).

  In the reference (flat) frame, mean counts are measured in rows 401:1200
  of two-column strips on both sides of each channel boundary.  Channel 2
  keeps the fixed factor 1.55.  The other channels' factors are chosen so
  that the scaled counts match across each boundary.  A frame holding each
  channel's factor is written to 'for-gain-correction.fits' (left in place).
  Every input is multiplied by that frame.  Processing stops at the first
  unreadable input.
  """
# show usage
  if (len(argvs) != 3):
    print('Usage: kools_ifu_red3.adjust_gain(["reference_fits","input","output"])')
    print('  Reference_fits should be a flat frame.')
    print('  In case of file lists for input and output, put "@" at the head of arguments.')
    return

# set constant
  # two-column strips x1_imstat[i]:x2_imstat[i] beside each boundary:
  # 0/1 -> left/right of ch0|ch1, 2/3 -> ch1|ch2, 4/5 -> ch2|ch3
  x1_imstat = (511, 513, 1023, 1025, 1535, 1537)
  x2_imstat = (512, 514, 1024, 1026, 1536, 1538)
  y1_imstat = 401
  y2_imstat = 1200

  gain_list = [1, 1, 1.55, 1]

  fits_template = 'for-gain-correction.fits'

  fits_reference = argvs[0]
  if not os.access(fits_reference, os.R_OK):
    print('Cannot open %s.' % fits_reference)
    return

  iraf.imstat.field = 'mean'

# set input files
  input_list = input_file_list(argvs[1])
  output_list = input_file_list(argvs[2])
  if input_list is None or output_list is None:
    return
  if len(input_list) != len(output_list):
    print('Warning: %d input(s) but %d output(s); extra entries are ignored.'
          % (len(input_list), len(output_list)))

# measure count ratios at the edges of CCD channels
  # NOTE(review): imstat output is parsed assuming line 1 holds the mean
  out_imstat_list = []
  for x_start, x_end in zip(x1_imstat, x2_imstat):
    image = '%s[%d:%d,%d:%d]' % (fits_reference, x_start, x_end, y1_imstat, y2_imstat)
    ret = iraf.imstat(image, Stdout = 1)
    out_imstat_list.append(float(ret[1]))

# calculate gains at CCD channels
  # require gain_left * mean_left == gain_right * mean_right at each boundary
  gain_list[1] = gain_list[2] * (out_imstat_list[3] / out_imstat_list[2])
  gain_list[0] = gain_list[1] * (out_imstat_list[1] / out_imstat_list[0])
  gain_list[3] = gain_list[2] * (out_imstat_list[4] / out_imstat_list[5])

  print(gain_list)

# make a template fits file (a copy of the first input, so it has the
# input's size, with each 512-column channel set to its gain factor)
  if not os.access(input_list[0], os.R_OK):
    print('Cannot open %s.' % input_list[0])
    return

  if os.path.exists(fits_template):
    os.remove(fits_template)
  iraf.imcopy(input_list[0], fits_template)

  for count, gain in enumerate(gain_list):
    image = '%s[%d:%d,*]' % (fits_template, 1 + 512 * count, 512 * (count + 1))
    print(image)
    iraf.imreplace(image, gain)

# multiply frames by gain
  for fits_input, fits_output in zip(input_list, output_list):
    if not os.access(fits_input, os.R_OK):
      print('Cannot open %s.' % fits_input)
      return

    if os.path.exists(fits_output):
      os.remove(fits_output)
    iraf.imarith(fits_input, '*', fits_template, fits_output)

  return

### make rough 2D fits
def fiber_make_image(argvs):
  """Reconstruct a rough 2-D image of the IFU field from fiber spectra.

  argvs = ["input_fits", lambda_start, lambda_end].  Row i of the input is
  used as the spectrum of fiber ID i+1.  Each spectrum is summed over array
  columns lambda_start..lambda_end (inclusive, 0-based indices) and placed
  at the fiber's grid position (fiber_pos_x, fiber_pos_y).  Each fiber is
  then spread over 2x2 output pixels.  The image is written, with the input
  header, to "<input without .fits>-<lambda_start>-<lambda_end>.fits".
  """
# show usage
  if (len(argvs) != 3):
    print('Usage: kools_ifu_red3.fiber_make_image(["input_fits",lambda_start,lambda_end])')
    print('  Lambda is written in pixel.')
    return

  input_fits = argvs[0]
  # accept ints or numeric strings, as fiber_pos_to_id/fiber_id_to_pos do
  lambda_start, lambda_end = int(argvs[1]), int(argvs[2])

  if not os.access(input_fits, os.R_OK):
    print('Cannot open %s.' % input_fits)
    return

# input the fits file
  data_in, header = fits.getdata(input_fits, view=np.ndarray, header=True)

# insert sky spectra if blank
  # with fewer than n_fiber rows, zero rows are inserted at row index 78 to
  # restore n_fiber rows (presumably where the removed sky fibers belong --
  # TODO confirm)
  input_shape = np.shape(data_in)
  if input_shape[0] < n_fiber:
    insert_zero = np.zeros((n_fiber - input_shape[0], input_shape[1]))
    data_in = np.insert(data_in, 78, insert_zero, axis=0)

# make the reconstructed image
  data_out = np.zeros(np.shape(fiber_pos_2d))
  for count_fiber in range(n_fiber):
    data_out[fiber_pos_y[count_fiber] - 1, fiber_pos_x[count_fiber] - 1] = data_in[count_fiber, lambda_start:(lambda_end + 1)].sum()

# fill 2x2 pixels for a fiber
  # the grid's last row and column hold no fiber, so np.roll only wraps
  # zeros around to the opposite edge
  data_out = (data_out
              + np.roll(data_out, 1, axis=1)
              + np.roll(data_out, 1, axis=0)
              + np.roll(data_out, (1, 1), axis=(0, 1)))

# output the fits file
  # strip only a trailing ".fits", not every occurrence in the path
  base = input_fits[:-len('.fits')] if input_fits.endswith('.fits') else input_fits
  output_fits = '%s-%d-%d.fits' % (base, lambda_start, lambda_end)
  fits.writeto(output_fits, data_out, header, overwrite=True)
  print(input_fits, '->', output_fits)

  return

### make rough 3D fits
def fiber_make_3d_image(argvs):
  """Reconstruct a rough 3-D cube (spectral, y, x) from fiber spectra.

  argvs = ["input_fits"].  Row i of the input is used as the spectrum of
  fiber ID i+1.  It is copied to the fiber's grid position (fiber_pos_x,
  fiber_pos_y) in every spectral plane, and each fiber is then spread over
  2x2 spatial pixels.  The cube is written to "<input without .fits>-3d.fits".
  The input's axis-1 CRVAL/CDELT/CD/CRPIX values are carried over to axis 3,
  and CRVAL1/CDELT1/CD1_1 are set to 1.
  """
# show usage
  if (len(argvs) != 1):
    print('Usage: kools_ifu_red3.fiber_make_3d_image(["input_fits"])')
    return

# input file
  input_fits = argvs[0]
  if not os.access(input_fits, os.R_OK):
    print('Cannot open %s.' % input_fits)
    return

  data_in, header_in = fits.getdata(input_fits, view=np.ndarray, header=True)

# insert sky spectra if blank
  # with fewer than n_fiber rows, zero rows are inserted at row index 78 to
  # restore n_fiber rows (presumably where the removed sky fibers belong --
  # TODO confirm)
  input_shape = np.shape(data_in)
  if input_shape[0] < n_fiber:
    insert_zero = np.zeros((n_fiber - input_shape[0], input_shape[1]))
    data_in = np.insert(data_in, 78, insert_zero, axis=0)

# prepare for an output file
  tmp_size = np.shape(fiber_pos_2d)
  out_spec_size = int(header_in['NAXIS1'])
  # replace only a trailing ".fits": str.replace would leave a name without
  # ".fits" unchanged, and overwrite=True below would then clobber the input
  base = input_fits[:-len('.fits')] if input_fits.endswith('.fits') else input_fits
  output_fits = base + '-3d.fits'

# make output data
  data_out = np.zeros((out_spec_size, tmp_size[0], tmp_size[1] + 1))
  for count in range(n_fiber):
    data_out[:, fiber_pos_y[count] - 1, fiber_pos_x[count] - 1] = data_in[count, :]

# fill 2x2 pixels for a fiber
  data_out = (data_out
              + np.roll(data_out, 1, axis=2)
              + np.roll(data_out, 1, axis=1)
              + np.roll(data_out, (1, 1), axis=(1, 2)))

# make output header
  # work on a copy so that the axis-1 values read for axis 3 below are the
  # input's, not the 1s assigned here
  header_out = header_in.copy()
  header_out['CRVAL1'] = 1
  header_out['CDELT1'] = 1
  header_out['CD1_1'] = 1

  header_out['CRVAL3'] = header_in.get('CRVAL1', 1)
  header_out['CDELT3'] = header_in.get('CDELT1', header_in.get('CD1_1', 1))
  header_out['CD3_3'] = header_in.get('CD1_1', header_in.get('CDELT1', 1))
  # the spectral pixel grid is unchanged, so its reference pixel carries over
  if 'CRPIX1' in header_in:
    header_out['CRPIX3'] = header_in['CRPIX1']

# output fits
  fits.writeto(output_fits, data_out, header_out, overwrite=True)
  print(input_fits, '->', output_fits)

  return

### Output fiber ID from position 
def fiber_pos_to_id(argvs):
  """Look up the fiber ID at a grid position.

  argvs = [pos_x, pos_y]: 1-based grid coordinates (anything int() accepts).
  Prints the lookup and returns the ID stored in fiber_pos_2d, where 0 means
  no fiber sits at that grid point.  Returns None when the argument count
  is wrong or the position lies outside the grid.
  """
  # show usage
  if len(argvs) != 2:
    print('The number of argments does not match.')
    print('Usage: kools_ifu_red3.fiber_pos_to_id([pos_x,pos_y])')
    return None

  x, y = (int(value) for value in argvs)
  n_rows = len(fiber_pos_2d)
  n_cols = len(fiber_pos_2d[0])

  inside_grid = 1 <= x <= n_cols and 1 <= y <= n_rows
  if not inside_grid:
    print("Fiber position is out of range.")
    return None

  fiber_id = fiber_pos_2d[y - 1][x - 1]
  if fiber_id:
    print("position(x, y) = (%d, %d)  ID = %d" % (x, y, fiber_id))
  else:
    # grid point without a fiber
    print("Fiber position is out of range.")
  return fiber_id

### Output fiber position from ID
def fiber_id_to_pos(argvs):
  """Look up the grid position of a fiber ID.

  argvs = [fiberID]: a 1-based fiber ID (anything int() accepts).
  Prints and returns (pos_x, pos_y), the 1-based grid coordinates from
  fiber_pos_x / fiber_pos_y.  Returns None when the argument count is wrong
  or the ID is outside 1..n_fiber.
  """
# show usage
  if (len(argvs) != 1):
    print("The number of argments does not match.")
    print('Usage: kools_ifu_red3.fiber_id_to_pos([fiberID])')
    return

  ID = int(argvs[0])
  if not 1 <= ID <= n_fiber:
    print("ID must be between 1 and {}.".format(n_fiber))
    # stop here: indexing with ID - 1 would wrap around to another fiber
    # for ID <= 0 and raise IndexError for ID > n_fiber
    return None

  pos = (fiber_pos_x[ID - 1], fiber_pos_y[ID - 1])
  print("ID = %d  position(x, y) = (%d, %d)" % (ID, pos[0], pos[1]))

  return pos

### Subtract sky emission
def subtract_sky(argvs):
  """Subtract a median sky spectrum from every fiber spectrum.

  argvs = ["input", "output", sky_y].  sky_y is a list of 1-based row
  (fiber) numbers whose spectra are median-combined into the sky spectrum.
  A single int is also accepted.  Rows outside 1..min(n_fiber, number of
  rows) are ignored.  The sky spectrum is subtracted from the first n_fiber
  rows, and the result is written to "output" with the input header.
  Non-floating input data are converted to float32 first.
  """
# show usage
  if (len(argvs) != 3):
    print('Usage: kools_ifu_red3.subtract_sky(["input","output",[sky_y]])')
    print('  sky_y must be a list.')
    print('  sky_y example: [17,18,19,20,21,22,95,96,97,98,99,100,101,102,103,104,105]')
    print('  sky_y example: [95]')
    return

# set constants
  input_name = argvs[0]
  output_name = argvs[1]
  sky_y_arg = argvs[2]
  if isinstance(sky_y_arg, (int, np.integer)):
    sky_y_arg = [sky_y_arg]
  sky_y_list = sorted(set(sky_y_arg))

# input file
  if not os.access(input_name, os.R_OK):
    print('Cannot open %s.' % input_name)
    return

  with fits.open(input_name) as input_fits:
    # copy the data into memory so it stays valid after the file is closed
    data = np.array(input_fits[0].data)
    input_header = input_fits[0].header

# make sky spectrum
  # rows beyond the actual frame height cannot be sky rows
  max_row = min(n_fiber, data.shape[0])
  sky_data = [data[sky_y - 1, :] for sky_y in sky_y_list if 1 <= sky_y <= max_row]

  if len(sky_data) == 0:
    print('Input sky_y is invalid.')
    print('No file is created.')
    return

  sky_1d = np.median(sky_data, axis=0)

# subtract sky spectrum from data
  # an in-place float subtraction is not allowed on integer arrays
  if not np.issubdtype(data.dtype, np.floating):
    data = data.astype(np.float32)
  # slicing also handles frames with fewer than n_fiber rows
  data[:n_fiber, :] -= sky_1d

# output fits
  hdu = fits.PrimaryHDU(data, header = input_header)
  hdu.writeto(output_name, overwrite = True)

  print(input_name, '->', output_name)

  return

