struct MyModel {
    l1: Linear,
    l2: Linear,
}
impl MyModel {
    fn new (mem: &mut Memory) -> MyModel {
        let l1 = Linear::new(mem, 784, 128);
        let l2 = Linear::new(mem, 128, 10);
        Self { l1, l2 }
    }
}
impl Compute for MyModel {
    fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor {
        let mut o = self.l1.forward(mem, input);
        o = o.relu();
        o = self.l2.forward(mem, &o);
        o
    }
}

fn main() {
    let (x, y) = load_mnist();
    let mut m = Memory::new();
    let mymodel = MyModel::new(&mut m);
    train(&mut m, &x, &y, &mymodel, 100, 128, cross_entropy, 0.3);
    let out = mymodel.forward(&m, &x);
    println!("Training Accuracy: {}", accuracy(&y, &out));
}

MyModel implements the Compute trait, which defines the forward method. In the main function we load the MNIST dataset, initialize the memory, instantiate MyModel, and then train it for 100 epochs with a batch size of 128, the cross-entropy loss, and a learning rate of 0.3. Finally, the trained model is run over the training data and its accuracy is printed.
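The accuracy above is measured on the training set. As a minimal sketch (the helper below is hypothetical, not part of the original listing), the same forward and accuracy functions can be pointed at the test split that tch's MNIST loader also returns:

fn load_mnist_test() -> (Tensor, Tensor) {
    // load_dir returns the test split alongside the training split.
    let m = vision::mnist::load_dir("data").unwrap();
    (m.test_images, m.test_labels)
}

// Hypothetical usage after training:
// let (test_x, test_y) = load_mnist_test();
// let test_out = mymodel.forward(&m, &test_x);
// println!("Test Accuracy: {}", accuracy(&test_y, &test_out));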
The building block of the network is a fully connected (linear) layer, which computes

$$y = xW + b$$

where $x$ is a batch of inputs of shape $[n, \text{ninputs}]$, $W$ is a $[\text{ninputs}, \text{noutputs}]$ weight matrix and $b$ a $[1, \text{noutputs}]$ bias.
trait Compute {
    fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor;
}
struct Linear {
    params: HashMap<String, usize>, // parameter name -> address of the tensor in Memory
}
impl Linear {
    fn new(mem: &mut Memory, ninputs: i64, noutputs: i64) -> Self {
        let mut p = HashMap::new();
        p.insert("W".to_string(), mem.new_push(&[ninputs, noutputs], true));
        p.insert("b".to_string(), mem.new_push(&[1, noutputs], true));
        Self { params: p }
    }
}
impl Compute for Linear {
    fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor {
        let w = mem.get(self.params.get(&"W".to_string()).unwrap());
        let b = mem.get(self.params.get(&"b".to_string()).unwrap());
        input.matmul(w) + b
    }
}
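As a quick illustrative sketch (the helper name is hypothetical, not part of the original listing), a Linear layer can be constructed on its own and checked for the expected output shape:

fn linear_shape_check() {
    let mut mem = Memory::new();
    let layer = Linear::new(&mut mem, 784, 128);
    let input = Tensor::randn(&[2, 784], (Kind::Float, Device::Cpu));
    let output = layer.forward(&mem, &input);
    assert_eq!(output.size(), vec![2, 128]); // [batch, noutputs]
}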
fn mse(target: &Tensor, pred: &Tensor) -> Tensor {
    (target - pred).square().mean(Kind::Float)
}
fn cross_entropy(target: &Tensor, pred: &Tensor) -> Tensor {
    // target holds Int64 class indices; pred holds the raw model outputs (logits).
    pred.log_softmax(-1, Kind::Float).nll_loss(target)
}
These correspond to the mean squared error and the cross-entropy loss:

$$\mathrm{MSE}(y, \hat{y}) = \frac{1}{n}\sum_{i=1}^{n}\bigl(y_i - \hat{y}_i\bigr)^2, \qquad \mathrm{CE}(y, \hat{y}) = -\frac{1}{n}\sum_{i=1}^{n}\log \hat{p}_{i,\,y_i}$$

where $\hat{p}_{i,y_i}$ is the softmax probability the model assigns to the correct class $y_i$ of sample $i$.
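As an illustrative check (the helper name is hypothetical, not part of the original listing), cross_entropy takes raw logits as the prediction and Int64 class indices as the target, one index per row:

fn loss_check() {
    // One sample, three classes; the correct class has index 2.
    let pred = Tensor::from_slice(&[0.1f32, 0.2, 0.7]).unsqueeze(0);
    let target = Tensor::from_slice(&[2i64]);
    let loss = cross_entropy(&target, &pred);
    loss.print();
}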

The optimizer below implements plain SGD, $\theta \leftarrow \theta - \eta\, g$, and SGD with momentum:

$$v \leftarrow \beta\, v + (1 - \beta)\, g, \qquad \theta \leftarrow \theta - \eta\, v$$

with $\beta = 0.9$, learning rate $\eta$ and gradient $g$ of the loss with respect to the parameter $\theta$.
struct Memory {
    size: usize,
    values: Vec<Tensor>,
}
impl Memory {
    fn new() -> Self {
        Self { size: 0, values: Vec::new() }
    }

    // Store a tensor and return its address (index) in the memory.
    fn push(&mut self, value: Tensor) -> usize {
        self.values.push(value);
        self.size += 1;
        self.size - 1
    }

    // Create a new random tensor with the given shape and store it.
    fn new_push(&mut self, size: &[i64], requires_grad: bool) -> usize {
        let t = Tensor::randn(size, (Kind::Float, Device::Cpu)).requires_grad_(requires_grad);
        self.push(t)
    }

    fn get(&self, addr: &usize) -> &Tensor {
        &self.values[*addr]
    }
    // Plain SGD: every trainable tensor is updated as theta <- theta - lr * grad.
    fn apply_grads_sgd(&mut self, learning_rate: f32) {
        let mut g = Tensor::new();
        self.values.iter_mut().for_each(|t| {
            if t.requires_grad() {
                g = t.grad();
                t.set_data(&(t.data() - learning_rate * &g));
                t.zero_grad();
            }
        });
    }
    // SGD with momentum: v <- beta * v + (1 - beta) * grad, theta <- theta - lr * v.
    // Note that velocity is a local buffer re-created on every call, so the running
    // average does not persist across batches.
    fn apply_grads_sgd_momentum(&mut self, learning_rate: f32) {
        let mut g: Tensor = Tensor::new();
        let mut velocity: Vec<Tensor> = Tensor::zeros(&[self.size as i64], (Kind::Float, Device::Cpu)).split(1, 0);
        let mut vcounter = 0;
        const BETA: f32 = 0.9;
        self.values.iter_mut().for_each(|t| {
            if t.requires_grad() {
                g = t.grad();
                velocity[vcounter] = BETA * &velocity[vcounter] + (1.0 - BETA) * &g;
                t.set_data(&(t.data() - learning_rate * &velocity[vcounter]));
                t.zero_grad();
            }
            vcounter += 1;
        });
    }
}
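The Memory struct is essentially a flat tensor store: every parameter is pushed into one Vec and referred to by its index, which is what lets the optimizer methods above walk over all trainable tensors in a single place. A minimal illustrative sketch (the helper name is hypothetical, not part of the original listing):

fn memory_demo() {
    let mut mem = Memory::new();
    // Allocate a trainable 2x3 tensor; new_push returns its address.
    let addr = mem.new_push(&[2, 3], true);
    // Fetch it back by address, e.g. inside a layer's forward pass.
    let t = mem.get(&addr);
    println!("{:?}", t.size()); // [2, 3]
}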
fn train<F>(mem: &mut Memory, x: &Tensor, y: &Tensor, model: &dyn Compute, epochs: i64, batch_size: i64, errfunc: F, learning_rate: f32)
where
    F: Fn(&Tensor, &Tensor) -> Tensor,
{
    let mut error = Tensor::from(0.0);
    let mut batch_error = Tensor::from(0.0);
    let mut pred = Tensor::from(0.0);
    for epoch in 0..epochs {
        batch_error = Tensor::from(0.0);
        for (batchx, batchy) in get_batches(x, y, batch_size, true) {
            pred = model.forward(mem, &batchx);
            error = errfunc(&batchy, &pred);
            batch_error += error.detach();
            // Backpropagate, then take one optimizer step per mini-batch.
            error.backward();
            mem.apply_grads_sgd_momentum(learning_rate);
        }
        println!("Epoch: {:?} Error: {:?}", epoch, batch_error / batch_size);
    }
}
fn get_batches(x: &Tensor, y: &Tensor, batch_size: i64, shuffle: bool) -> impl Iterator<Item = (Tensor, Tensor)> {
    let num_rows = x.size()[0];
    let num_batches = (num_rows + batch_size - 1) / batch_size;

    // Optionally shuffle the rows before slicing them into batches.
    let indices = if shuffle {
        Tensor::randperm(num_rows, (Kind::Int64, Device::Cpu))
    } else {
        let rng = (0..num_rows).collect::<Vec<i64>>();
        Tensor::from_slice(&rng)
    };
    let x = x.index_select(0, &indices);
    let y = y.index_select(0, &indices);

    (0..num_batches).map(move |i| {
        let start = i * batch_size;
        let end = (start + batch_size).min(num_rows);
        let batchx: Tensor = x.narrow(0, start, end - start);
        let batchy: Tensor = y.narrow(0, start, end - start);
        (batchx, batchy)
    })
}
fn load_mnist() -> (Tensor, Tensor) {
    // Expects the raw MNIST idx files (train-images-idx3-ubyte, ...) in the "data" directory.
    let m = vision::mnist::load_dir("data").unwrap();
    let x = m.train_images;
    let y = m.train_labels;
    (x, y)
}
fn accuracy(target: &Tensor, pred: &Tensor) -> f64 {
    // Predicted class = argmax over the 10 output values per sample.
    let yhat = pred.argmax(1, true).squeeze();
    let eq = target.eq_tensor(&yhat);
    let accuracy: f64 = (eq.sum(Kind::Int64) / target.size()[0]).double_value(&[]);
    accuracy
}

use std::collections::HashMap;
use tch::{Tensor, Kind, Device, vision, Scalar};