riskreg.hpp
Go to the documentation of this file.
1 
10 #pragma once
11 
#include <complex>
#include <functional>  // std::bind / std::function used with target::cx_func
#include <memory>      // smart pointers (unique_ptr)
#include <stdexcept>   // std::invalid_argument, std::logic_error
#include <string>
#include <utility>     // std::move

#include "target.hpp"
#include "glm.hpp"
#include "utils.hpp"
18 
19 
/// Selects which data component RiskReg::operator() returns.
/// Kept as an unscoped enum so callers can keep writing `Y`, `A`, ... unqualified.
enum Data {
  Y = 0,   ///< selects model->Y()
  A = 1,   ///< selects model->A()
  X1 = 2,  ///< selects model->X1()
  X2 = 3,  ///< selects model->X2()
  X3 = 4,  ///< selects model->X3()
  W = 5    ///< selects model->weights()
};
21 
22 class RiskReg {
23 public:
24  using cx_dbl = target::cx_dbl;
25  using cx_func = target::cx_func;
26 
27  RiskReg() {}
28  RiskReg(const arma::vec &y, const arma::vec &a,
29  const arma::mat &x1, const arma::mat &x2, const arma::mat &x3,
30  const arma::vec &weights, std::string model) {
31  this->type = model;
32  setData(y, a, x1, x2, x3, weights);
33  }
34  ~RiskReg() {}
35 
36  void setData(const arma::vec &y, const arma::vec &a,
37  const arma::mat &x1, const arma::mat &x2, const arma::mat &x3,
38  const arma::vec &weights) {
39  theta = arma::zeros(x1.n_cols + x2.n_cols + x3.n_cols);
40  if (this->type.compare("rr") == 0) {
41  this->model.reset(new target::RR<double>(y, a,
42  x1, x2, x3,
43  theta, weights));
44  } else {
45  this->model.reset(new target::RD<double>(y, a,
46  x1, x2, x3,
47  theta, weights));
48  }
49  }
50 
51  void weights(const arma::vec &weights) {
52  model->weights(weights);
53  }
54 
55  void update(arma::vec &par) {
56  for (unsigned i=0; i < par.n_elem; i++)
57  this->theta(i) = par(i);
58  model->update_par(par);
59  model->calculate(true, true, true);
60  }
61  double logl() {
62  return model->loglik(false)[0];
63  }
64  arma::mat dlogl(bool indiv = false) {
65  return model->score(indiv);
66  }
67  arma::mat pr() {
68  return model->pa();
69  }
70 
71  arma::cx_mat score(arma::cx_vec theta) {
72  model_c->update_par(theta);
73  model_c->calculate(true, true, false);
74  return model_c->score(false);
75  }
76 
77  arma::mat esteq(arma::vec &alpha, arma::vec &pr) {
78  arma::mat res = model->est(alpha, pr);
79  return res;
80  }
81 
82  arma::mat hessian() {
83  arma::cx_vec Yc = arma::conv_to<arma::cx_vec>::from(model->Y());
84  arma::cx_vec Ac = arma::conv_to<arma::cx_vec>::from(model->A());
85  arma::cx_mat X1c = arma::conv_to<arma::cx_mat>::from(model->X1());
86  arma::cx_mat X2c = arma::conv_to<arma::cx_mat>::from(model->X2());
87  arma::cx_mat X3c = arma::conv_to<arma::cx_mat>::from(model->X3());
88  arma::cx_vec thetac = arma::conv_to<arma::cx_vec>::from(theta);
89  arma::cx_vec Wc = arma::conv_to<arma::cx_vec>::from(model->weights());
90  if (this->type.compare("rr") == 0) {
91  model_c.reset(new target::RR<cx_dbl>(Yc, Ac, X1c, X2c, X3c, thetac, Wc));
92  } else {
93  model_c.reset(new target::RD<cx_dbl>(Yc, Ac, X1c, X2c, X3c, thetac, Wc));
94  }
95  arma::mat res = target::deriv(std::bind(&RiskReg::score,
96  this,
97  std::placeholders::_1), theta);
98  return res;
99  }
100 
101  arma::mat operator()(Data index) const {
102  switch (index) {
103  case Y:
104  return model->Y();
105  case A:
106  return model->A();
107  case X1:
108  return model->X1();
109  case X2:
110  return model->X2();
111  case X3:
112  return model->X3();
113  case W:
114  break;
115  }
116  return model->weights();
117  }
118 
119 protected:
120  std::unique_ptr< target::TargetBinary<double> > model;
121  std::unique_ptr< target::TargetBinary<cx_dbl> > model_c;
122  arma::vec theta;
123  std::string type;
124 };
target.hpp: Classes for targeted inference models.
glm.hpp: Utility functions for generalized linear models.
utils.hpp: Various utility functions and constants.