save.py 373 B

12345678910111213
  1. def save(model, filename):
  2. def traverse(tensor, f):
  3. for i in tensor:
  4. if len(i.shape) == 0:
  5. f.write(struct.pack(">f", (float)(i.data)))
  6. else:
  7. traverse(i, f)
  8. f = open(filename, 'wb')
  9. for _,param in enumerate(model.named_parameters()):
  10. traverse(param[1], f)
  11. print(str(param[0]))