Skip to content

Commit e8332b4

Browse files
refactor: separate the print_psi from wfc_2d_to_grid and remove the invalid call (#4268)
* refactor: seperate the print_psi from wfc_2d_to_grid and remove the invalid call * correct fn * transfer pointer to vector and fix compiling bug * use Cpxgemr2d to do the gather of wfc * fix the compiling error and ut error * fix bug in diago_cusolver * fix integrate test * use orb_con.ParaV instead of LOWF.ParaV; add unit test * move write_wfc_nao to write_wfc_lcao; and correct some description in doc. * fix bug * fxi bug in read_wfc_nao_test * solve the conflict of CURRENT_SPIN * fix bug in integrate test * fix bug in integrate test * fix bug in integrate * rename wfc_lcao to wfc_nao * fix bug in Makefile.Objiects; and add comments for Cpxgemr2d --------- Co-authored-by: Mohan Chen <mohan.chen.chen.mohan@gmail.com>
1 parent 854d49c commit e8332b4

39 files changed

Lines changed: 804 additions & 635 deletions

File tree

docs/advanced/elec_properties/wfc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ABACUS is able to output electron wave functions in both PW and LCAO basis calcu
55
## wave function in G space
66
For the wave function in G space, one only needs to do a ground-state energy calculation with one additional keyword in the INPUT file: '***[out_wfc_pw](https://abacus-rtd.readthedocs.io/en/latest/advanced/input_files/input-main.html#out-wfc-pw)***' for PW basis calculation, and '***[out_wfc_lcao](https://abacus-rtd.readthedocs.io/en/latest/advanced/input_files/input-main.html#out-wfc-lcao)***' for LCAO basis calculation.
77
In the PW basis case, the wave function is output in a file called `WAVEFUNC${k}.txt`, where `${k}` is the index of K point. \
8-
In the LCAO basis case, several `LOWF_K_${k}.dat` files will be output in multi-k calculation and `LOWF_GAMMA_S1.dat` in gamma-only calculation.
8+
In the LCAO basis case, several `WFC_NAO_K${k}.dat` files will be output in multi-k calculation and `WFC_NAO_GAMMA1.dat` in gamma-only calculation.
99

1010
## wave function in real space
1111

docs/advanced/input_files/input-main.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,14 +1541,14 @@ These variables are used to control the output of properties.
15411541

15421542
- **Type**: Integer
15431543
- **Availability**: Numerical atomic orbital basis
1544-
- **Description**: Whether to output the wavefunction coefficients into files in the folder `OUT.${suffix}`. The files are named as:
1544+
- **Description**: Whether to output the wavefunction coefficients into files in the folder `OUT.${suffix}`. The files are named as `WFC_{GAMMA|K}{index of K point}`, and if [out_app_flag](#out_app_flag) is false, the file name will also contains `_ION{ION step}`, where `ION step` is the index of ionic step:
15451545
- 0: no output
1546-
- 1: (txt format)
1547-
- gamma-only: `LOWF_GAMMA_S1.txt`;
1548-
- non-gamma-only: `LOWF_K_${k}.txt`, where `${k}` is the index of k points.
1546+
- 1: (txt format)
1547+
- gamma-only: `WFC_NAO_GAMMA1_ION1.txt` or `WFC_NAO_GAMMA1.txt`, ...;
1548+
- non-gamma-only: `WFC_NAO_K1_ION1.txt` or `WFC_NAO_K1.txt`, ...;
15491549
- 2: (binary format)
1550-
- gamma-only: `LOWF_GAMMA_S1.dat`;
1551-
- non-gamma-only: `LOWF_K_${k}.dat`, where `${k}` is the index of k points.
1550+
- gamma-only: `WFC_NAO_GAMMA1_ION1.dat` or `WFC_NAO_GAMMA1.dat`, ...;
1551+
- non-gamma-only: `WFC_NAO_K1_ION1.dat` or `WFC_NAO_K1.dat`, ....
15521552

15531553
The corresponding sequence of the orbitals can be seen in [Basis Set](../pp_orb.md#basis-set).
15541554

@@ -2414,7 +2414,7 @@ These variables are used to control molecular dynamics calculations. For more in
24142414

24152415
- **Type**: Boolean
24162416
- **Description**: Control whether to restart molecular dynamics calculations and time-dependent density functional theory calculations.
2417-
- True: ABACUS will read in `${read_file_dir}/Restart_md.dat` to determine the current step `${md_step}`, then read in the corresponding `STRU_MD_${md_step}` in the folder `OUT.$suffix/STRU/` automatically. For tddft, ABACUS will also read in `LOWF_K_${kpoint}` of the last step (You need to set out_wfc_lcao=1 and out_app_flag=0 to obtain this file).
2417+
- True: ABACUS will read in `${read_file_dir}/Restart_md.dat` to determine the current step `${md_step}`, then read in the corresponding `STRU_MD_${md_step}` in the folder `OUT.$suffix/STRU/` automatically. For tddft, ABACUS will also read in `WFC_NAO_K${kpoint}` of the last step (You need to set out_wfc_lcao=1 and out_app_flag=0 to obtain this file).
24182418
- False: ABACUS will start molecular dynamics calculations normally from the first step.
24192419
- **Default**: False
24202420

examples/wfc/lcao_ienvelope_Si2/run.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ if [[ ! -f scf.output ]] ||
1515
[[ ! -f get_wf.output ]] ||
1616
[[ ! -f OUT.ABACUS/running_scf.log ]] ||
1717
[[ ! -f OUT.ABACUS/running_get_wf.log ]] ||
18+
[[ ! -f OUT.ABACUS/WFC_NAO_K1.txt ]] ||
19+
[[ ! -f OUT.ABACUS/WFC_NAO_K36.txt ]] ||
1820
[[ ! ( "$(tail -1 OUT.ABACUS/running_scf.log)" == " Total Time :"* ) ]] ||
1921
[[ ! ( "$(tail -1 OUT.ABACUS/running_get_wf.log)" == " Total Time :"* ) ]]
2022
then

examples/wfc/lcao_scf_Si2/run.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ OMP_NUM_THREADS=${ABACUS_THREADS} mpirun -np ${ABACUS_NPROCS} ${ABACUS_PATH} | t
88

99
if [[ ! -f output ]] ||
1010
[[ ! -f OUT.ABACUS/running_scf.log ]] ||
11+
[[ ! -f OUT.ABACUS/WFC_NAO_K1.txt ]] ||
12+
[[ ! -f OUT.ABACUS/WFC_NAO_K36.txt ]] ||
1113
[[ ! ( "$(tail -1 OUT.ABACUS/running_scf.log)" == " Total Time :"* ) ]]
1214
then
1315
echo "job is failed!"
1416
exit 1
1517
else
1618
echo "job is successed!"
1719
exit 0
18-
fi
20+
fi

source/module_base/scalapack_connector.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,56 @@ extern "C"
115115
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
116116
std::complex<double> *B, const int *IB, const int *JB, const int *DESCB,
117117
const int *ICTXT);
118+
119+
// Scalapack wrappers to copy 2D blocks of data
120+
// more info:
121+
// https://netlib.org/scalapack/explore-html/da/db5/pigemr_8c.html
122+
// https://netlib.org/scalapack/explore-html/dd/dcd/pdgemr_8c.html
123+
// https://netlib.org/scalapack/explore-html/d5/dd4/pzgemr_8c.html
124+
// https://netlib.org/scalapack/explore-html/d5/deb/psgemr_8c.html
125+
// https://netlib.org/scalapack/explore-html/d4/dad/pcgemr_8c.html
126+
void Cpigemr2d (int m, int n, int *ptrmyblock, int ia, int ja, int *ma, int *ptrmynewblock, int ib, int jb, int *mb, int globcontext);
127+
void Cpdgemr2d (int m, int n, double *ptrmyblock, int ia, int ja, int *ma, double *ptrmynewblock, int ib, int jb, int *mb, int globcontext);
128+
void Cpzgemr2d (int m, int n, std::complex<double> *ptrmyblock, int ia, int ja, int *ma, std::complex<double> *ptrmynewblock, int ib, int jb, int *mb, int globcontext);
129+
void Cpsgemr2d (int m, int n, float *ptrmyblock, int ia, int ja, int *ma, float *ptrmynewblock, int ib, int jb, int *mb, int globcontext);
130+
void Cpcgemr2d (int m, int n, std::complex<float> *ptrmyblock, int ia, int ja, int *ma, std::complex<float> *ptrmynewblock, int ib, int jb, int *mb, int globcontext);
118131
}
119132

133+
template <typename T>
134+
struct block2d_data_type
135+
{
136+
constexpr static bool value = std::is_same<T, double>::value || std::is_same<T, std::complex<double>>::value || std::is_same<T, float>::value || std::is_same<T, std::complex<float>>::value || std::is_same<T, int>::value;
137+
};
138+
139+
140+
/**
141+
* Copies a 2D block of data from matrix A to matrix B using the Scalapack library.
142+
* This function supports different data types: double, std::complex<double>, float, std::complex<float>, and int.
143+
*
144+
* @tparam T The data type of the matrices A and B.
145+
* @param M The number of rows of matrix A.
146+
* @param N The number of columns of matrix A.
147+
* @param A Pointer to the source matrix A.
148+
* @param IA The starting row index of the block in matrix A.
149+
* @param JA The starting column index of the block in matrix A.
150+
* @param DESCA Descriptor array for matrix A.
151+
* @param B Pointer to the destination matrix B.
152+
* @param IB The starting row index of the block in matrix B.
153+
* @param JB The starting column index of the block in matrix B.
154+
* @param DESCB Descriptor array for matrix B.
155+
* @param ICTXT The context identifier.
156+
*/
157+
template <typename T>
158+
typename std::enable_if<block2d_data_type<T>::value,void>::type Cpxgemr2d(int M, int N, T *A, int IA, int JA, int *DESCA, T *B, int IB, int JB, int *DESCB, int ICTXT)
159+
{
160+
if (std::is_same<T,double>::value) Cpdgemr2d(M, N, reinterpret_cast<double*>(A),IA, JA, DESCA,reinterpret_cast<double*>(B),IB,JB, DESCB,ICTXT);
161+
if (std::is_same<T,std::complex<double>>::value) Cpzgemr2d(M, N, reinterpret_cast<std::complex<double>*>(A),IA, JA, DESCA,reinterpret_cast<std::complex<double>*>(B),IB,JB, DESCB,ICTXT);
162+
if (std::is_same<T,float>::value) Cpsgemr2d(M, N, reinterpret_cast<float*>(A),IA, JA, DESCA,reinterpret_cast<float*>(B),IB,JB, DESCB,ICTXT);
163+
if (std::is_same<T,std::complex<float>>::value) Cpcgemr2d(M, N, reinterpret_cast<std::complex<float>*>(A),IA, JA, DESCA,reinterpret_cast<std::complex<float>*>(B),IB,JB, DESCB,ICTXT);
164+
if (std::is_same<T,int>::value) Cpigemr2d(M, N, reinterpret_cast<int*>(A),IA, JA, DESCA,reinterpret_cast<int*>(B),IB,JB, DESCB,ICTXT);
165+
};
166+
167+
120168
class ScalapackConnector
121169
{
122170
public:

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 0 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,81 +12,6 @@
1212
namespace elecstate
1313
{
1414

15-
template <>
16-
void ElecStateLCAO<double>::print_psi(const psi::Psi<double>& psi_in, const int istep)
17-
{
18-
if (!ElecStateLCAO<double>::out_wfc_lcao)
19-
{
20-
return;
21-
}
22-
23-
// output but not do "2d-to-grid" conversion
24-
double** wfc_grid = nullptr;
25-
const int ik = psi_in.get_current_k();
26-
#ifdef __MPI
27-
this->lowf->wfc_2d_to_grid(istep, out_wfc_flag, psi_in.get_pointer(), wfc_grid, ik, this->ekb, this->wg);
28-
#endif
29-
return;
30-
}
31-
32-
template <>
33-
void ElecStateLCAO<std::complex<double>>::print_psi(const psi::Psi<std::complex<double>>& psi_in, const int istep)
34-
{
35-
if (!ElecStateLCAO<std::complex<double>>::out_wfc_lcao
36-
&& !ElecStateLCAO<std::complex<double>>::need_psi_grid)
37-
{
38-
return;
39-
}
40-
41-
// output but not do "2d-to-grid" conversion
42-
std::complex<double>** wfc_grid = nullptr;
43-
int ik = psi_in.get_current_k();
44-
if (ElecStateLCAO<std::complex<double>>::need_psi_grid)
45-
{
46-
wfc_grid = this->lowf->wfc_k_grid[ik];
47-
}
48-
49-
#ifdef __MPI
50-
this->lowf->wfc_2d_to_grid(istep,
51-
ElecStateLCAO<std::complex<double>>::out_wfc_flag,
52-
psi_in.get_pointer(),
53-
wfc_grid,
54-
ik,
55-
this->ekb,
56-
this->wg,
57-
this->klist->kvec_c);
58-
#else
59-
for (int ib = 0; ib < GlobalV::NBANDS; ib++)
60-
{
61-
for (int iw = 0; iw < GlobalV::NLOCAL; iw++)
62-
{
63-
this->lowf->wfc_k_grid[ik][ib][iw] = psi_in(ib, iw);
64-
}
65-
}
66-
#endif
67-
68-
// added by zhengdy-soc, rearrange the wfc_k_grid from [up,down,up,down...] to [up,up...down,down...],
69-
if (ElecStateLCAO<std::complex<double>>::need_psi_grid && GlobalV::NSPIN == 4)
70-
{
71-
int row = this->lowf->gridt->lgd;
72-
std::vector<std::complex<double>> tmp(row);
73-
for (int ib = 0; ib < GlobalV::NBANDS; ib++)
74-
{
75-
for (int iw = 0; iw < row / GlobalV::NPOL; iw++)
76-
{
77-
tmp[iw] = this->lowf->wfc_k_grid[ik][ib][iw * GlobalV::NPOL];
78-
tmp[iw + row / GlobalV::NPOL] = this->lowf->wfc_k_grid[ik][ib][iw * GlobalV::NPOL + 1];
79-
}
80-
for (int iw = 0; iw < row; iw++)
81-
{
82-
this->lowf->wfc_k_grid[ik][ib][iw] = tmp[iw];
83-
}
84-
}
85-
}
86-
87-
return;
88-
}
89-
9015
// multi-k case
9116
template <>
9217
void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<double>>& psi)
@@ -129,15 +54,6 @@ if(!GlobalV::dm_to_rho)
12954
#endif
13055

13156
}
132-
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx" || GlobalV::KS_SOLVER == "lapack"
133-
|| GlobalV::KS_SOLVER == "cusolver" || GlobalV::KS_SOLVER == "cg_in_lcao" || GlobalV::KS_SOLVER == "pexsi")
134-
{
135-
for (int ik = 0; ik < psi.get_nk(); ik++)
136-
{
137-
psi.fix_k(ik);
138-
this->print_psi(psi);
139-
}
140-
}
14157
}
14258
// old 2D-to-Grid conversion has been replaced by new Gint Refactor 2023/09/25
14359
//this->loc->cal_dk_k(*this->lowf->gridt, this->wg, (*this->klist));
@@ -202,12 +118,6 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
202118
for (int ik = 0; ik < psi.get_nk(); ++ik)
203119
{
204120
// for gamma_only case, no convertion occured, just for print.
205-
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
206-
|| GlobalV::KS_SOLVER == "cusolver" || GlobalV::KS_SOLVER == "cg_in_lcao")
207-
{
208-
psi.fix_k(ik);
209-
this->print_psi(psi);
210-
}
211121
// old 2D-to-Grid conversion has been replaced by new Gint Refactor 2023/09/25
212122
if (this->loc->out_dm) // keep interface for old Output_DM until new one is ready
213123
{

source/module_elecstate/elecstate_lcao.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@ class ElecStateLCAO : public ElecState
6363
// update charge density for next scf step
6464
// void getNewRho() override;
6565

66-
virtual void print_psi(const psi::Psi<TK>& psi_in, const int istep = -1) override;
67-
//virtual void print_psi(const psi::Psi<std::complex<double>>& psi_in, const int istep = -1) override;
68-
6966
// initial density matrix
7067
void init_DM(const K_Vectors* kv, const Parallel_Orbitals* paraV, const int nspin);
7168
DensityMatrix<TK,double>* get_DM() const { return const_cast<DensityMatrix<TK,double>*>(this->DM); }

source/module_elecstate/elecstate_lcao_tddft.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,6 @@ void ElecStateLCAO_TDDFT::psiToRho_td(const psi::Psi<std::complex<double>>& psi)
4242
#endif
4343
}
4444

45-
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx" || GlobalV::KS_SOLVER == "lapack")
46-
{
47-
for (int ik = 0; ik < psi.get_nk(); ik++)
48-
{
49-
psi.fix_k(ik);
50-
this->print_psi(psi);
51-
}
52-
}
5345

5446
for (int is = 0; is < GlobalV::NSPIN; is++)
5547
{

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h"
4444

45+
#include "module_io/write_wfc_nao.h"
46+
4547
namespace ModuleESolver
4648
{
4749

@@ -980,25 +982,17 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
980982
}
981983

982984
// 2) print wavefunctions
983-
if (this->conv_elec)
984-
{
985-
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao)
986-
{
987-
elecstate::ElecStateLCAO<TK>::out_wfc_flag = elecstate::ElecStateLCAO<TK>::out_wfc_lcao;
988-
}
989-
990-
for (int ik = 0; ik < this->kv.get_nks(); ik++)
991-
{
992-
if (istep % GlobalV::out_interval == 0)
993-
{
994-
this->psi[0].fix_k(ik);
995-
this->pelec->print_psi(this->psi[0], istep);
996-
}
997-
}
998-
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao)
999-
{
1000-
elecstate::ElecStateLCAO<TK>::out_wfc_flag = 0;
1001-
}
985+
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao &&
986+
(this->conv_elec || iter == GlobalV::SCF_NMAX) &&
987+
(istep % GlobalV::out_interval == 0))
988+
{
989+
ModuleIO::write_wfc_nao(elecstate::ElecStateLCAO<TK>::out_wfc_lcao,
990+
this->psi[0],
991+
this->pelec->ekb,
992+
this->pelec->wg,
993+
this->pelec->klist->kvec_c,
994+
this->orb_con.ParaV,
995+
istep);
1002996
}
1003997

1004998
// 3) print potential

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "module_io/td_current_io.h"
88
#include "module_io/write_HS.h"
99
#include "module_io/write_HS_R.h"
10+
#include "module_io/write_wfc_nao.h"
1011

1112
//--------------temporary----------------------------
1213
#include "module_base/blas_connector.h"
@@ -284,22 +285,17 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
284285
}
285286
}
286287

287-
if (this->conv_elec)
288+
if (elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_lcao &&
289+
(this->conv_elec || iter == GlobalV::SCF_NMAX) &&
290+
(istep % GlobalV::out_interval == 0) )
288291
{
289-
if (elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_lcao)
290-
{
291-
elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_flag
292-
= elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_lcao;
293-
}
294-
for (int ik = 0; ik < kv.get_nks(); ik++)
295-
{
296-
if (istep % GlobalV::out_interval == 0)
297-
{
298-
this->psi[0].fix_k(ik);
299-
this->pelec->print_psi(this->psi[0], istep);
300-
}
301-
}
302-
elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_flag = 0;
292+
ModuleIO::write_wfc_nao(elecstate::ElecStateLCAO<std::complex<double>>::out_wfc_lcao,
293+
this->psi[0],
294+
this->pelec->ekb,
295+
this->pelec->wg,
296+
this->pelec->klist->kvec_c,
297+
this->orb_con.ParaV,
298+
istep);
303299
}
304300

305301
// Calculate new potential according to new Charge Density

0 commit comments

Comments
 (0)