
module ctem

  use shr_kind_mod, only: r8 => shr_kind_r8
  use pmgrid,       only: plon, plev, plevp
  use cam_history
  use cam_logfile,  only: iulog

  implicit none

  private
  
  public :: ctem_inti
  public :: ctem_driver
  public :: ctem_output

  real(r8) :: rplon                      ! 1/plon, set in ctem_inti; scales longitude sums into zonal means
  real(r8) :: iref_p(plevp)              ! interface reference pressure for vertical interpolation
  integer  :: ip_b                       ! level index where hybrid levels become purely pressure
  integer  :: zm_limit                   ! minimum number of unblocked longitudes for a valid zonal mean (plon/3)
  logical  :: twod_output                ! .true. when plon >= plevp; enables the 2D projected output fields

contains

!================================================================================

  subroutine ctem_driver( u3, v3, omga, pt, h2o, &
			  vth, wth, uv, uw, uzm, u2d, &
                          v2d, th2d, w2d, ip_gm1, thig, &
                          ps, pe, beglat, endlat, beglon, &
                          endlon, grid )
!-------------------------------------------------------------
! Purpose: compute zonal means and zonal-mean eddy fluxes on the
! reference pressure interfaces iref_p.
!
! Steps, per latitude of the local subdomain:
!   1. build mid-point pressure, potential temperature and a vertical
!      velocity w = -hscale*omga/pm on model levels
!   2. record, per column, the deepest reference interface that is not
!      below the surface (ip_gm1) and decide which (level,latitude)
!      points have enough unblocked longitudes for a zonal mean (has_zm)
!   3. interpolate u, v, w, theta to the reference interfaces
!   4. form zonal means, zonal deviations and the products
!      v*theta, w*theta, u*v, u*w of the deviations, then average them
!
! Points without a valid zonal mean are filled with navp.
! Module variables rplon, iref_p, ip_b and zm_limit must have been set
! by ctem_inti before this routine is called.
!-------------------------------------------------------------

    use physconst, only          : zvir, cappa
    use spmd_utils, only         : iam
    use abortutils, only         : endrun
    use dynamics_vars, only      : T_FVDYCORE_GRID
    use hycoef, only             : ps0
    use interpolate_data, only   : vertinterp
#ifdef SPMD
    use mpishorthand,       only : mpilog, mpiint
    use parutilitiesmodule, only : pargatherint
#endif

!-------------------------------------------------------------
!	... dummy arguments
!-------------------------------------------------------------
    integer, intent(in)   :: beglat, endlat                          ! begin,end latitude indices
    integer, intent(in)   :: beglon, endlon                          ! begin,end longitude indices
    integer, intent(out)  :: ip_gm1(beglon:endlon,beglat:endlat)     ! largest interface index k with iref_p(k) <= ps (plevp if none found)
    real(r8), intent(in)  :: ps(beglon:endlon,beglat:endlat)         ! surface pressure (pa)
    real(r8), intent(in)  :: u3(beglon:endlon,plev,beglat:endlat)    ! zonal velocity (m/s)
    real(r8), intent(in)  :: v3(beglon:endlon,plev,beglat:endlat)    ! meridional velocity (m/s)
    real(r8), intent(in)  :: omga(beglon:endlon,plev,beglat:endlat)  ! pressure velocity
    real(r8), intent(in)  :: pe(beglon:endlon,plevp,beglat:endlat)   ! interface pressure (pa)
    real(r8), intent(in)  :: pt(beglon:endlon,beglat:endlat,plev)    ! virtual temperature (note index order: lon,lat,lev)
    real(r8), intent(in)  :: h2o(beglon:endlon,beglat:endlat,plev)   ! water constituent (kg/kg) (note index order: lon,lat,lev)
    real(r8), intent(out) :: vth(plevp,beglat:endlat)                ! VTH flux
    real(r8), intent(out) :: uv(plevp,beglat:endlat)                 ! UV flux
    real(r8), intent(out) :: wth(plevp,beglat:endlat)                ! WTH flux
    real(r8), intent(out) :: uw(plevp,beglat:endlat)                 ! UW flux
    real(r8), intent(out) :: uzm(plev,beglat:endlat)                 ! zonally averaged U on model mid-point levels
    real(r8), intent(out) :: u2d(plevp,beglat:endlat)                ! zonally averaged U on reference interfaces
    real(r8), intent(out) :: v2d(plevp,beglat:endlat)                ! zonally averaged V on reference interfaces
    real(r8), intent(out) :: th2d(plevp,beglat:endlat)               ! zonally averaged TH on reference interfaces
    real(r8), intent(out) :: w2d(plevp,beglat:endlat)                ! zonally averaged W on reference interfaces
    real(r8), intent(out) :: thig(beglon:endlon,plevp,beglat:endlat) ! interpolated pot. temperature
    type(T_FVDYCORE_GRID), intent(in) :: grid                        ! FV Dynamics grid

!-------------------------------------------------------------
!	... local variables
!-------------------------------------------------------------
    real(r8), parameter :: hscale = 7000._r8          ! pressure scale height
    real(r8), parameter :: navp   = 1.e35_r8          ! fill value written where no valid zonal mean exists
    
    real(r8) :: pinterp                                      ! target pressure of the current interpolation
    real(r8) :: w(beglon:endlon,plev,beglat:endlat)          ! vertical velocity
    real(r8) :: th(beglon:endlon,plev,beglat:endlat)         ! pot. temperature

    real(r8) :: pm(beglon:endlon,plev,beglat:endlat)         ! mid-point pressure
    real(r8) :: pexf                                         ! Exner function
    real(r8) :: psurf                                        ! surface pressure of the current column

    real(r8) :: ui(beglon:endlon,plevp)                      ! interpolated zonal velocity
    real(r8) :: vi(beglon:endlon,plevp)                      ! interpolated meridional velocity
    real(r8) :: wi(beglon:endlon,plevp)                      ! interpolated vertical velocity
    real(r8) :: thi(beglon:endlon,plevp)                     ! interpolated pot. temperature
    
    real(r8) :: um(plevp)                                    ! zonal mean zonal velocity
    real(r8) :: vm(plevp)                                    ! zonal mean meridional velocity
    real(r8) :: wm(plevp)                                    ! zonal mean vertical velocity
    real(r8) :: thm(plevp)                                   ! zonal mean pot. temperature

    real(r8) :: ud(beglon:endlon,plevp)                      ! zonal deviation of zonal velocity
    real(r8) :: vd(beglon:endlon,plevp)                      ! zonal deviation of meridional velocity
    real(r8) :: wd(beglon:endlon,plevp)                      ! zonal deviation of vertical velocity
    real(r8) :: thd(beglon:endlon,plevp)                     ! zonal deviation of pot. temperature

    real(r8) :: vthp(beglon:endlon,plevp)                    ! pointwise product vd*thd
    real(r8) :: wthp(beglon:endlon,plevp)                    ! pointwise product wd*thd
    real(r8) :: uvp(beglon:endlon,plevp)                     ! pointwise product vd*ud
    real(r8) :: uwp(beglon:endlon,plevp)                     ! pointwise product wd*ud

    real(r8) :: dummy(plon,plevp)                            ! NOTE(review): not referenced in this routine
    real(r8) :: dum2(plon)                                   ! NOTE(review): not referenced in this routine
    real(r8) :: rdiv(plevp)                                  ! 1/(number of unblocked longitudes) per masked level
    
    integer  :: ip_gm1g(plon,beglat:endlat)                  ! ip_gm1 over the full longitude range
    integer  :: zm_cnt(plevp)                                ! number of unblocked longitudes per level
    integer  :: i,j,k
    integer  :: nlons                                        ! number of local longitudes
    integer  :: astat                                        ! NOTE(review): not referenced in this routine
    integer  :: t, dest, src                                 ! NOTE(review): not referenced in this routine

    logical  :: has_zm(plevp,beglat:endlat)                   ! .true. where the (z,y) point has a valid zonal mean

! NOTE(review): the next line lacks the dollar-sign OpenMP sentinel, so it
! is an ordinary comment and the loop runs serially.
!omp parallel do private (i,j,k,pexf,psurf)
lat_loop1 : &
    do j = beglat, endlat
       do k = 1, plev
          do i = beglon, endlon
!-------------------------------------------------------------
! Calculate pressure and Exner function
!-------------------------------------------------------------
             pm(i,k,j) = 0.5 * ( pe(i,k,j) + pe(i,k+1,j) )
             pexf      = (ps0 / pm(i,k,j))**cappa
!-------------------------------------------------------------
! Convert virtual temperature to temperature and calculate potential temperature
!-------------------------------------------------------------
             th(i,k,j) = pt(i,j,k) / (1. + zvir*h2o(i,j,k)) 
             th(i,k,j) = th(i,k,j) * pexf
!-------------------------------------------------------------
! Calculate vertical velocity
!-------------------------------------------------------------
             w(i,k,j)  = - hscale * omga(i,k,j) / pm(i,k,j)
          end do
       end do
!-------------------------------------------------------------
! Keep track of where the bottom is in each column 
! (i.e., largest index for which P(k) <= PS)
! NOTE(review): if no level in ip_b+1..plevp satisfies the test, the
! entry keeps its initial value plevp, i.e. the column is treated as
! unblocked at every level - confirm this is intended.
!-------------------------------------------------------------
       ip_gm1(:,j) = plevp
       do i = beglon, endlon
          psurf = ps(i,j)
          do k = ip_b+1, plevp
             if( iref_p(k) <= psurf ) then
                ip_gm1(i,j) = k
             end if
          end do
       end do
    end do lat_loop1

    nlons = endlon - beglon + 1

!-------------------------------------------------------------
! Build ip_gm1g, the mask index over all plon longitudes.
! With a 2D decomposition the local strips are collected by
! pargatherint (presumably onto task 0 of grid%commxy_x - confirm
! against parutilitiesmodule); otherwise the local array already
! spans all longitudes and is copied.
!-------------------------------------------------------------
#ifdef SPMD    
    if( grid%twod_decomp == 1 ) then
       if (grid%iam .lt. grid%npes_xy) then
          call pargatherint( grid%commxy_x, 0, ip_gm1, grid%strip2dx, ip_gm1g )
       endif
    else
       ip_gm1g(:,:) = ip_gm1(:,:)
    end if
#else
    ip_gm1g(:,:) = ip_gm1(:,:)
#endif
! NOTE(review): the if-block opened in either branch below is closed by
! the end if that follows lat_loop2, which is itself inside an SPMD
! conditional. With CTEM_DIAGS defined and SPMD undefined the block
! appears to be left unclosed - confirm.
#ifdef CTEM_DIAGS
    write(iulog,*) '===================================================='
    do j = beglat,endlat
       write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
       write(iulog,'(20i3)') ip_gm1(:,j)
    end do
    if( grid%myidxy_x == 0 ) then
       do j = beglat,endlat
          write(iulog,*) '===================================================='
          write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
          write(iulog,'(20i3)') ip_gm1g(:,j)
       end do
       write(iulog,*) '===================================================='
#else
#ifdef SPMD    
    if( grid%myidxy_x == 0 ) then
#endif
#endif
!-------------------------------------------------------------
! Decide which (level,latitude) points get a zonal mean: levels
! 1..ip_b always do; below that at least zm_limit longitudes must
! have ip_gm1g >= k. Under SPMD only tasks with myidxy_x == 0
! evaluate this; the result is broadcast afterwards.
!-------------------------------------------------------------
lat_loop2 : &
       do j = beglat, endlat
          zm_cnt(:ip_b) = plon
          do k = ip_b+1, plevp
             zm_cnt(k) = count( ip_gm1g(:,j) >= k )
          end do
          has_zm(:ip_b,j) = .true.
          do k = ip_b+1, plevp
             has_zm(k,j) = zm_cnt(k) >= zm_limit
          end do
       end do lat_loop2
#ifdef SPMD    
    end if
! share the mask information along the x communicator (root is task 0)
    if( grid%twod_decomp == 1 ) then
       call mpibcast( has_zm, plevp*(endlat-beglat+1), mpilog, 0, grid%commxy_x )
       call mpibcast( ip_gm1g, plon*(endlat-beglat+1), mpiint, 0, grid%commxy_x )
    end if
#endif

#ifdef CTEM_DIAGS
    if( grid%myidxy_y == 12 ) then
       write(iulog,*) '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'
       write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,beglat
       write(iulog,*) 'has_zm'
       write(iulog,'(20l2)') has_zm(:,beglat)
       write(iulog,*) 'ip_gm1g'
       write(iulog,'(20i4)') ip_gm1g(:,beglat)
       write(iulog,*) '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'
    end if
#endif

lat_loop3 : &
    do j = beglat, endlat
!-------------------------------------------------------------
! Vertical interpolation
! Each call handles all local longitudes for one target pressure
! iref_p(k); the source profile is on the mid-point pressures pm.
!-------------------------------------------------------------
       do k = 1, plevp
          pinterp = iref_p(k)
!-------------------------------------------------------------
!      Zonal velocity
!-------------------------------------------------------------
          call vertinterp( nlons, nlons, plev, pm(beglon,1,j), pinterp, &
                           u3(beglon,1,j), ui(beglon,k) )
!-------------------------------------------------------------
!      Meridional velocity
!-------------------------------------------------------------
          call vertinterp( nlons, nlons, plev, pm(beglon,1,j), pinterp, &
                           v3(beglon,1,j), vi(beglon,k) )
!-------------------------------------------------------------
!      Vertical velocity
!-------------------------------------------------------------
          call vertinterp( nlons, nlons, plev, pm(beglon,1,j), pinterp, &
                           w(beglon,1,j), wi(beglon,k) )
!-------------------------------------------------------------
!      Pot. Temperature
!-------------------------------------------------------------
          call vertinterp( nlons, nlons, plev, pm(beglon,1,j), pinterp, &
                           th(beglon,1,j), thi(beglon,k) )
       end do
#ifdef CTEM_DIAGS
       if( j == endlat ) then
       write(iulog,*) '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'
       write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
       write(iulog,*) 'iref_p'
       write(iulog,'(5g15.7)') iref_p(:)
       write(iulog,'(''pm(endlon,:,'',i2,'')'')') j
       write(iulog,'(5g15.7)') pm(endlon,:,j)
       write(iulog,'(''u3(endlon,:,'',i2,'')'')') j
       write(iulog,'(5g15.7)') u3(endlon,:,j)
       write(iulog,*) 'ui(endlon,:)'
       write(iulog,'(5g15.7)') ui(endlon,:)
       write(iulog,*) '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'
       end if
#endif

!-------------------------------------------------------------
! Calculate zonal averages
! Blocked points (column bottom above level k) are zeroed first so
! that they add nothing to the longitude sums.
!-------------------------------------------------------------
       do k = ip_b+1, plevp
          if( has_zm(k,j) ) then
             where( ip_gm1(beglon:endlon,j) < k )
                ui(beglon:endlon,k)  = 0._r8
                vi(beglon:endlon,k)  = 0._r8
                wi(beglon:endlon,k)  = 0._r8
                thi(beglon:endlon,k) = 0._r8
             endwhere
          end if
       end do

! par_xsum presumably returns the sum over all longitudes of the
! latitude circle - confirm against its definition
       call par_xsum( grid, u3(beglon,1,j), plev, uzm(1,j) )
       call par_xsum( grid, ui, plevp, um )
       call par_xsum( grid, vi, plevp, vm )
       call par_xsum( grid, wi, plevp, wm )
       call par_xsum( grid, thi, plevp, thm )
       do k = 1,plev
          uzm(k,j) = uzm(k,j) * rplon
       end do
#ifdef CTEM_DIAGS
       if( j == endlat .and. grid%myidxy_y == 12 ) then
          write(iulog,*) '$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$'
          write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
          write(iulog,*) 'um after par_xsum'
          write(iulog,'(5g15.7)') um(:)
          write(iulog,*) '$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$'
       end if
#endif
! levels 1..ip_b: every longitude contributes, so divide by plon
       do k = 1, ip_b
          um(k)     = um(k) * rplon
          vm(k)     = vm(k) * rplon
          wm(k)     = wm(k) * rplon
          thm(k)    = thm(k) * rplon
          u2d(k,j)  = um(k)
          v2d(k,j)  = vm(k)
          th2d(k,j) = thm(k)
          w2d(k,j)  = wm(k)
       end do
! levels below ip_b: divide by the number of unblocked longitudes, or
! write the fill value when there are too few of them
       do k = ip_b+1, plevp
          if( has_zm(k,j) ) then
             rdiv(k)   = 1._r8/count( ip_gm1g(:,j) >= k )
             um(k)     = um(k) * rdiv(k)
             vm(k)     = vm(k) * rdiv(k)
             wm(k)     = wm(k) * rdiv(k)
             thm(k)    = thm(k) * rdiv(k)
             u2d(k,j)  = um(k)
             v2d(k,j)  = vm(k)
             th2d(k,j) = thm(k)
             w2d(k,j)  = wm(k)
          else
             u2d(k,j)  = navp
             v2d(k,j)  = navp
             th2d(k,j) = navp
             w2d(k,j)  = navp
          end if
       end do

!-------------------------------------------------------------
! Calculate zonal deviations
!-------------------------------------------------------------
       do k = 1, ip_b
          ud(beglon:endlon,k)  = ui(beglon:endlon,k)  - um(k)
          vd(beglon:endlon,k)  = vi(beglon:endlon,k)  - vm(k)
          wd(beglon:endlon,k)  = wi(beglon:endlon,k)  - wm(k)
          thd(beglon:endlon,k) = thi(beglon:endlon,k) - thm(k)
       end do

! below ip_b the deviation is zero at blocked points; levels without a
! valid zonal mean are left unset here and get zero products below
       do k = ip_b+1, plevp
          if( has_zm(k,j) ) then
             where( ip_gm1g(beglon:endlon,j) >= k )
                ud(beglon:endlon,k)  = ui(beglon:endlon,k) - um(k)
                vd(beglon:endlon,k)  = vi(beglon:endlon,k) - vm(k)
                wd(beglon:endlon,k)  = wi(beglon:endlon,k) - wm(k)
                thd(beglon:endlon,k) = thi(beglon:endlon,k) - thm(k)
             elsewhere
                ud(beglon:endlon,k)  = 0._r8
                vd(beglon:endlon,k)  = 0._r8
                wd(beglon:endlon,k)  = 0._r8
                thd(beglon:endlon,k) = 0._r8
             endwhere
          end if
       end do

!-------------------------------------------------------------
! Calculate fluxes
! (pointwise products of the deviations; averaged after the sums below)
!-------------------------------------------------------------
       do k = 1, ip_b
          vthp(:,k) = vd(:,k) * thd(:,k)
          wthp(:,k) = wd(:,k) * thd(:,k)
          uwp(:,k)  = wd(:,k) * ud(:,k)
          uvp(:,k)  = vd(:,k) * ud(:,k)
       end do

       do k = ip_b+1, plevp
          if( has_zm(k,j) ) then
             vthp(:,k) = vd(:,k) * thd(:,k)
             wthp(:,k) = wd(:,k) * thd(:,k)
             uwp(:,k)  = wd(:,k) * ud(:,k)
             uvp(:,k)  = vd(:,k) * ud(:,k)
          else
             vthp(:,k) = 0._r8
             wthp(:,k) = 0._r8
             uwp(:,k)  = 0._r8
             uvp(:,k)  = 0._r8
          end if
       end do

! NOTE(review): in the diagnostic block below uw is referenced with a
! single subscript although it is declared rank 2, and it has not been
! assigned yet at this point - confirm this block compiles and prints
! what was intended.
#ifdef CTEM_DIAGS
       if( j == endlat .and. grid%myidxy_y == 12 ) then
          write(iulog,*) '#################################################'
          write(iulog,*) 'DIAGNOSTICS before par_xsum'
          write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
          write(iulog,*) 'has_zm'
          write(iulog,*) has_zm(:,j)
          write(iulog,*) 'rdiv'
          write(iulog,'(5g15.7)') rdiv(:)
          write(iulog,*) 'wm'
          write(iulog,'(5g15.7)') wm(:)
          write(iulog,*) 'um'
          write(iulog,'(5g15.7)') um(:)
          write(iulog,*) 'uw'
          write(iulog,'(5g15.7)') uw(:)
          write(iulog,*) '#################################################'
       end if
#endif
       call par_xsum( grid, vthp, plevp, vth(1,j) )
       call par_xsum( grid, wthp, plevp, wth(1,j) )
       call par_xsum( grid, uvp, plevp, uv(1,j) )
       call par_xsum( grid, uwp, plevp, uw(1,j) )
#ifdef CTEM_DIAGS
       if( j == endlat .and. grid%myidxy_y == 12 ) then
          write(iulog,*) '#################################################'
          write(iulog,'(''iam,myidxy_x,myidxy_y,j = '',4i4)') iam,grid%myidxy_x,grid%myidxy_y,j
          write(iulog,*) 'uw after par_xsum'
          write(iulog,'(5g15.7)') uw(:,j)
          write(iulog,*) '#################################################'
       end if
#endif
! turn the flux sums into means, using the same divisors as the zonal
! means above
       do k = 1, ip_b
          vth(k,j) = vth(k,j) * rplon
          wth(k,j) = wth(k,j) * rplon
          uw(k,j)  = uw(k,j) * rplon
          uv(k,j)  = uv(k,j) * rplon
       end do
       do k = ip_b+1, plevp
          if( has_zm(k,j) ) then
             vth(k,j) = vth(k,j) * rdiv(k)
             wth(k,j) = wth(k,j) * rdiv(k)
             uw(k,j)  = uw(k,j) * rdiv(k)
             uv(k,j)  = uv(k,j) * rdiv(k)
          else
             vth(k,j) = navp
             wth(k,j) = navp
             uw(k,j)  = navp
             uv(k,j)  = navp
          end if
       end do

! keep the interpolated potential temperature (blocked points zeroed
! above where a zonal mean exists) for output
       thig(:,:,j) = thi(:,:)
    end do lat_loop3

  end subroutine ctem_driver

  subroutine ctem_output( lchnk, ncol, lons, ip_gm1, vth, &
                          wth, uv, uw, u2d, v2d,  &
                          th2d, w2d, thi ) 
!-------------------------------------------------------------
! Purpose: hand the TEM diagnostics of one chunk to the history
! buffer through outfld.
!
! When the module flag twod_output is set, each (column,level) field
! is first reduced to one value per column: column i receives the
! value at level lons(i), or zero when lons(i) exceeds plevp.
! (presumably lons holds the longitude index of each column, so that
! the level axis is laid out along longitude - confirm with the caller)
! The mask index ip_gm1 and the full (pcols,plevp) fields are always
! written.
!-------------------------------------------------------------

  use ppgrid, only : pcols

!-------------------------------------------------------------
!	... dummy arguments
!-------------------------------------------------------------
    integer, intent(in)   :: lchnk               ! chunk identifier passed on to outfld
    integer, intent(in)   :: ncol                ! number of active columns in the chunk
    integer, intent(in)   :: lons(pcols)         ! per-column index used as the level selector
    integer, intent(in)   :: ip_gm1(pcols)       ! per-column mask index, written as MSKtem
    real(r8), intent(in)  :: vth(pcols,plevp)    ! written as VTH3d (and VTH2d)
    real(r8), intent(in)  :: wth(pcols,plevp)    ! written as WTH3d (and WTH2d)
    real(r8), intent(in)  :: uv(pcols,plevp)     ! written as UV3d (and UV2d)
    real(r8), intent(in)  :: uw(pcols,plevp)     ! written as UW3d (and UW2d)
    real(r8), intent(in)  :: u2d(pcols,plevp)    ! written as U2d
    real(r8), intent(in)  :: v2d(pcols,plevp)    ! written as V2d
    real(r8), intent(in)  :: th2d(pcols,plevp)   ! written as TH2d
    real(r8), intent(in)  :: w2d(pcols,plevp)    ! written as W2d
    real(r8), intent(in)  :: thi(pcols,plevp)    ! written as TH

!-------------------------------------------------------------
!	... local variables
!-------------------------------------------------------------
    real(r8) :: mask_r8(ncol)                    ! ip_gm1 converted to real for outfld

!-------------------------------------------------------------
! Do the 2D output
!-------------------------------------------------------------
    if( twod_output ) then
       call write_projection( 'VTH2d', vth )
       call write_projection( 'WTH2d', wth )
       call write_projection( 'UV2d', uv )
       call write_projection( 'UW2d', uw )
       call write_projection( 'U2d', u2d )
       call write_projection( 'V2d', v2d )
       call write_projection( 'TH2d', th2d )
       call write_projection( 'W2d', w2d )
    end if

    mask_r8(:ncol) = ip_gm1(:ncol)
    call outfld( 'MSKtem', mask_r8, ncol, lchnk )
!-------------------------------------------------------------
! 3D output
!-------------------------------------------------------------
    call outfld( 'VTH3d', vth, pcols, lchnk )
    call outfld( 'WTH3d', wth, pcols, lchnk )
    call outfld( 'UV3d', uv, pcols, lchnk )
    call outfld( 'UW3d', uw, pcols, lchnk )
    call outfld( 'TH', thi, pcols, lchnk )

  contains

!-------------------------------------------------------------
! Pick fld(i,lons(i)) for every active column (zero when lons(i)
! is beyond plevp) and write the result under the name fname.
! ncol, lons and lchnk come from the enclosing routine.
!-------------------------------------------------------------
    subroutine write_projection( fname, fld )

      character(len=*), intent(in) :: fname              ! history field name
      real(r8), intent(in)         :: fld(pcols,plevp)   ! field to be projected

      real(r8) :: proj(ncol)                             ! one value per active column
      integer  :: icol

      do icol = 1, ncol
         if( lons(icol) <= plevp ) then
            proj(icol) = fld(icol,lons(icol))
         else
            proj(icol) = 0._r8
         end if
      end do
      call outfld( fname, proj, ncol, lchnk )

    end subroutine write_projection

  end subroutine ctem_output

!=================================================================================

  subroutine ctem_inti
!-------------------------------------------------------------
! Purpose: one-time setup of the TEM diagnostics.
!   - sets the module variables twod_output, rplon, zm_limit,
!     iref_p and ip_b used by the other routines of this module
!   - registers the output fields with the history module and
!     puts them on the default tapes
!-------------------------------------------------------------

  use spmd_utils, only : masterproc
  use hycoef, only     : hyai, hybi, ps0

!-------------------------------------------------------------
!	... local variables
!-------------------------------------------------------------
    integer :: k

!-------------------------------------------------------------
! The 2D fields are only produced when there are at least as many
! longitudes as interface levels
!-------------------------------------------------------------
    twod_output = plon >= plevp
    if( masterproc ) then
       if( .not. twod_output ) then
          write(iulog,*) 'At this resolution, no TEM diagnostic is provided in the seconday tapes.'
       end if
    end if

    rplon    = 1._r8/plon          ! factor turning a longitude sum into a mean
    zm_limit = plon/3              ! integer division: fewest unblocked longitudes for a valid zonal mean

!-------------------------------------------------------------
! Calculate reference pressure
!-------------------------------------------------------------
    do k = 1, plevp
       iref_p(k) = (hyai(k) + hybi(k)) * ps0
    end do
    if( masterproc ) then
       write(iulog,*) 'ctem_inti: iref_p'
       write(iulog,'(1p5g15.7)') iref_p(:)
    end if

!-------------------------------------------------------------
! Find level where hybrid levels become purely pressure 
! ip_b ends up as the last k in 1..plev with hybi(k) exactly zero.
! It starts at 0, not at a negative sentinel: ip_b is used directly as
! a loop bound and section bound (1..ip_b and ip_b+1..plevp), so 0
! yields an empty pure-pressure range and keeps every index >= 1 when
! no such level exists.
!-------------------------------------------------------------
    ip_b = 0
    do k = 1,plev
       if( hybi(k) == 0._r8 ) ip_b = k
    end do

!-------------------------------------------------------------
! Initialize output buffer
!-------------------------------------------------------------
    call addfld ('VTH3d ','MK/S    ',plevp, 'A','Meridional Heat Flux: 3D zon. mean', phys_decomp )
    call addfld ('WTH3d ','MK/S    ',plevp, 'A','Vertical Heat Flux: 3D zon. mean', phys_decomp )
    call addfld ('UV3d  ','M2/S2   ',plevp, 'A','Meridional Flux of Zonal Momentum: 3D zon. mean', phys_decomp )
    call addfld ('UW3d  ','M2/S2   ',plevp, 'A','Vertical Flux of Zonal Momentum: 3D zon. mean', phys_decomp )
    if( twod_output ) then
       call addfld ('VTH2d ','MK/S    ',1, 'A','Meridional Heat Flux: 2D prj of zon. mean',phys_decomp )
       call addfld ('WTH2d ','MK/S    ',1, 'A','Vertical Heat Flux: 2D prj of zon. mean',phys_decomp )
       call addfld ('UV2d  ','M2/S2   ',1, 'A','Meridional Flux of Zonal Momentum: 2D prj of zon. mean',phys_decomp )
       call addfld ('UW2d  ','M2/S2   ',1, 'A','Vertical Flux of Zonal Momentum; 2D prj of zon. mean',phys_decomp )
       call addfld ('U2d   ','M/S     ',1, 'A','Zonal-Mean zonal wind',phys_decomp )
       call addfld ('V2d   ','M/S     ',1, 'A','Zonal-Mean meridional wind',phys_decomp )
       call addfld ('W2d   ','M/S     ',1, 'A','Zonal-Mean vertical wind',phys_decomp )
       call addfld ('TH2d  ','K       ',1, 'A','Zonal-Mean potential temp',phys_decomp )
    end if
    call addfld ('TH    ','K       ',plevp, 'A','Potential Temperature', phys_decomp )
    call addfld ('MSKtem','unitless',1    , 'A','TEM mask', phys_decomp )
    
!-------------------------------------------------------------
! primary tapes: 3D fields
!-------------------------------------------------------------
    call add_default ('VTH3d', 1, ' ')
    call add_default ('WTH3d', 1, ' ')
    call add_default ('UV3d' , 1, ' ')
    call add_default ('UW3d' , 1, ' ')
    call add_default ('TH' , 1, ' ')
    call add_default ('MSKtem',1, ' ')

!-------------------------------------------------------------
! secondary tapes: 2D fields
! (only the four flux projections plus TH and MSKtem are defaulted
! here; U2d, V2d, W2d and TH2d are registered above but not added)
!-------------------------------------------------------------
    if( twod_output ) then
       call add_default ('VTH2d', 2, ' ')
       call add_default ('WTH2d', 2, ' ')
       call add_default ('UV2d' , 2, ' ')
       call add_default ('UW2d' , 2, ' ')
       call add_default ('TH' , 2, ' ')
       call add_default ('MSKtem',2, ' ')
    end if

  end subroutine ctem_inti

end module ctem
