MultiAgentDecisionProcess  Release 0.2.1
MDPSolver.cpp
Go to the documentation of this file.
1 
28 #include "MDPSolver.h"
29 #include <float.h>
30 #include <fstream>
31 #include <limits.h>
33 #include "JointBeliefInterface.h"
34 #include "JointAction.h"
35 #include "State.h"
36 #include "BeliefIteratorGeneric.h"
37 
38 using namespace std;
39 
40 //Destructor
42 {
43 }
44 
45 double MDPSolver::GetQ(Index time_step, const JointBeliefInterface& jb,
46  Index jaI) const
47 {
48  double Q = 0.0;
49 #if USE_BeliefIteratorGeneric
51  do Q+=it.GetProbability() * GetQ(time_step,it.GetStateIndex(),jaI);
52  while(it.Next());
53 #else
54  for(Index sI=0; sI < jb.Size(); sI++)
55  Q += jb.Get(sI) * GetQ(time_step,sI,jaI);
56 #endif
57  return(Q);
58 }
59 
61  Index jaI) const
62 {
63  double Q = 0.0;
64 #if USE_BeliefIteratorGeneric
66  do Q+=it.GetProbability() * GetQ(it.GetStateIndex(),jaI);
67  while(it.Next());
68 #else
69  for(Index sI=0; sI < jb.Size(); sI++)
70  Q += jb.Get(sI) * GetQ(sI,jaI);
71 #endif
72 
73  return(Q);
74 }
75 
76 void MDPSolver::Print() const
77 {
78  size_t horizon = GetPU()->GetHorizon();
79  size_t nrS = GetPU()->GetNrStates();
80  size_t nrJA = GetPU()->GetNrJointActions();
81 
82  cout << "States: ";
83  for(Index sI = 0; sI < nrS; sI++)
84  cout << _m_pu->GetState(sI)->SoftPrintBrief() << " ";
85  cout << endl;
86 
87  if(horizon!=MAXHORIZON)
88  {
89  for(size_t t = 0; t!=horizon; t++)
90  for(Index jaI = 0; jaI < nrJA; jaI++)
91  {
92  cout << "Q(t=" << t << ",:," << jaI << ") =\t";
93  for(Index sI = 0; sI < nrS; sI++)
94  cout << " " << GetQ(t,sI,jaI);
95  cout << " " << _m_pu->GetJointAction(jaI)->SoftPrintBrief();
96  cout << endl;
97  }
98  }
99  else
100  {
101  for(Index jaI = 0; jaI < nrJA; jaI++)
102  {
103  cout << "Q(:," << jaI << ") =\t";
104  for(Index sI = 0; sI < nrS; sI++)
105  cout << " " << GetQ(sI,jaI);
106  cout << " " << _m_pu->GetJointAction(jaI)->SoftPrintBrief();
107  cout << endl;
108  }
109 
110  cout << "Policy: ";
111  for(Index sI = 0; sI < nrS; sI++)
112  {
113  cout << _m_pu->GetState(sI)->SoftPrintBrief() << "->";
114  double q,v=-DBL_MAX;
115  Index aMax=INT_MAX;
116 
117  for(Index jaI = 0; jaI < nrJA; jaI++)
118  {
119  q=GetQ(sI,jaI);
120  if(q>v)
121  {
122  v=q;
123  aMax=jaI;
124  }
125  }
126  cout << _m_pu->GetJointAction(aMax)->SoftPrintBrief() << " ";
127  }
128  cout << endl;
129  }
130 }
131 
133 {
134  double q,v=-DBL_MAX;
135  Index aMax=INT_MAX;
136 
137  for(size_t a=0;a!=GetPU()->GetNrJointActions();++a)
138  {
139  q=GetQ(time_step,sI,a);
140  if(q>v)
141  {
142  v=q;
143  aMax=a;
144  }
145  }
146 
147  return(aMax);
148 }
149 
151 {
152  return(LoadQTable(filename,
153  GetPU()->GetNrStates(),
154  GetPU()->GetNrJointActions()));
155 }
156 
158  unsigned int nrS,
159  unsigned int nrA)
160 {
161  const int bufsize=65536;
162  char buffer[bufsize];
163 
164  ifstream fp(filename.c_str());
165  if(!fp)
166  {
167  cerr << "MDPSolver::LoadQTable: failed to "
168  << "open file " << filename << endl;
169  }
170 
171  size_t a,s;
172  double q;
173 
174  QTable Q(nrS,nrA);
175 
176  s=0;
177  while(!fp.getline(buffer,bufsize).eof())
178  {
179  istringstream is(buffer);
180  a=0;
181  while(is >> q)
182  Q(s,a++)=q;
183 
184  if(a!=nrA)
185  throw(E("MDPSolver::LoadQTable wrong number of actions"));
186 
187  s++;
188  }
189 
190  if(s!=nrS)
191  throw(E("MDPSolver::LoadQTable wrong number of states"));
192 
193  return(Q);
194 }
195 
196 QTables MDPSolver::LoadQTables(string filename, int nrTables)
197 {
198  return(LoadQTables(filename,
199  GetPU()->GetNrStates(),
200  GetPU()->GetNrJointActions(),
201  nrTables));
202 }
203 
205  unsigned int nrS,
206  unsigned int nrA,
207  unsigned int nrTables)
208 {
209  const int bufsize=65536;
210  char buffer[bufsize];
211 
212  ifstream fp(filename.c_str());
213  if(!fp)
214  {
215  cerr << "MDPSolver::LoadQTables: failed to "
216  << "open file " << filename << endl;
217  }
218 
219  size_t a,s,i;
220  double q;
221 
222  QTable Q(nrS,nrA);
223  QTables Qs;
224  for(i=0;i!=nrTables;i++)
225  Qs.push_back(Q);
226 
227  s=0;
228  i=0;
229  while(!fp.getline(buffer,bufsize).eof())
230  {
231  istringstream is(buffer);
232  a=0;
233  while(is >> q)
234  Qs[i](s,a++)=q;
235 
236  if(a!=nrA)
237  throw(E("MDPSolver::LoadQTables wrong number of actions"));
238 
239  s++;
240  if(s==nrS)
241  {
242  i++;
243  s=0;
244  }
245  }
246 
247  if(i!=nrTables)
248  throw(E("MDPSolver::LoadQTables wrong number of tables"));
249 
250  return(Qs);
251 }
252 
253 void MDPSolver::SaveQTable(const QTable &Q, string filename)
254 {
255  ofstream fp(filename.c_str());
256  if(!fp)
257  {
258  stringstream ss;
259  ss << "MDPSolver::SaveQTable: failed to open file " << filename << endl;
260  throw E(ss.str());
261  }
262 
263  fp.precision(16);
264 
265  unsigned int nrS=Q.size1(),
266  nrA=Q.size2();
267 
268  for(unsigned int s=0;s!=nrS;++s)
269  {
270  for(unsigned int a=0;a!=nrA;++a)
271  {
272  fp << Q(s,a);
273  if(a!=nrA-1)
274  fp << " ";
275  }
276  fp << endl;
277  }
278 }
279 
280 void MDPSolver::SaveQTables(const QTables &Qs, string filename)
281 {
282  ofstream fp(filename.c_str());
283  if(!fp)
284  {
285  stringstream ss;
286  ss << "MDPSolver::SaveQTables: failed to open file " << filename << endl;
287  throw E(ss);
288  }
289 
290  fp.precision(16);
291 
292  unsigned int nrS=Qs[0].size1(),
293  nrA=Qs[0].size2(),
294  h=Qs.size();
295 
296  for(unsigned int k=0;k!=h;++k)
297  for(unsigned int s=0;s!=nrS;++s)
298  {
299  for(unsigned int a=0;a!=nrA;++a)
300  {
301  fp << Qs[k](s,a);
302  if(a!=nrA-1)
303  fp << " ";
304  }
305  fp << endl;
306  }
307 }