[.NET]快快樂樂學LINQ系列-Aggregate() 簡介
前言
這篇文章打算從實務的一個例子,來介紹 Aggregate()
。
Aggregate 的意思就是合計,那跟 Sum()
有什麼不同呢?實際上 Aggregate()
當然可以做到 Sum()
的功能,而且擁有更大的彈性。
這個方法讓開發人員可以自行選定,在巡覽每一次的 iteration 時,暫存一個特定的值來跟這一次的 current item 進行結合、處理或運算。
而在這個說明的基礎底下, Sum()
只是其中的一種應用,暫存的特定值就是最後 Sum() 的結果,而巡覽的 iteration 時,將 item 透過 selector 投射出來的值,與暫存的結果進行加總,這就是 Sum()
。同樣地, Max()
與 Min()
也可以用 Aggregate()
來設計,暫存的值仍是最後的結果,以 Max()
為例,那只需要把 Sum()
裡面,原本用來加總的動作,改成比大小,比較大的,就放到暫存結果。
那麼,什麼情況用 Aggregate()
會比用 Sum()
來得有效率呢?請見下面的範例。
範例
需求介紹:
- 日結表的資料中,會存放每一天有異動的庫存量。
- 當進行月結時,需要將各區的異動庫存量進行加總,產生一筆月結資料。
Scenario 如下:
可以看到五月份有三筆日結資料,當針對五月份進行月結結轉時,期望月結資料的良品區、不良品區與客退區的數量,會是五月份三筆日結的加總。
測試程式與自動產生的 production code 如下:
using System.Collections.Generic;
using System.Linq;
using Rhino.Mocks;
using TechTalk.SpecFlow;
using TechTalk.SpecFlow.Assist;
namespace AggregateSample
{
[Binding]
[Scope(Feature = "StockManagement")]
public class StockManagementSteps
{
private MonthlyStockSettlementService target;
private IDailyStockDao dailyStockDao;
[BeforeScenario]
public void BeforeScenario()
{
this.dailyStockDao = MockRepository.GenerateStub<IDailyStockDao>();
this.target = new MonthlyStockSettlementService(dailyStockDao);
}
[Given(@"欲結算年月為 (.*)")]
public void Given欲結算年月為(string yearMonth)
{
ScenarioContext.Current.Set<string>(yearMonth, "yearMonth");
}
[Given(@"ProductId為 (.*)")]
public void GivenProductId為(string productId)
{
ScenarioContext.Current.Set<string>(productId, "id");
}
[Given(@"日結資料為")]
public void Given日結資料為(Table table)
{
var dailyStockSettlements = table.CreateSet<DailyStockSettlement>();
var yearmonth = ScenarioContext.Current.Get<string>("yearMonth");
var productId = ScenarioContext.Current.Get<string>("id");
this.dailyStockDao.Stub(x => x.GetDailyStocksByYearMonth(yearmonth, productId)).Return(dailyStockSettlements);
}
[When(@"呼叫月結結轉")]
public void When呼叫月結結轉()
{
var yearMonth = ScenarioContext.Current.Get<string>("yearMonth");
var productId = ScenarioContext.Current.Get<string>("id");
MonthlyStockSettlement actual = this.target.Snapshot(yearMonth, productId);
ScenarioContext.Current.Set<MonthlyStockSettlement>(actual);
}
[Then(@"月結資料應為")]
public void Then月結資料應為(Table table)
{
var actual = ScenarioContext.Current.Get<MonthlyStockSettlement>();
table.CompareToInstance(actual);
}
}
public class MonthlyStockSettlementService
{
private IDailyStockDao dailyStockDao;
public MonthlyStockSettlementService(IDailyStockDao dailyStockDao)
{
// TODO: Complete member initialization
this.dailyStockDao = dailyStockDao;
}
internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
{
throw new System.NotImplementedException();
}
}
public class MonthlyStockSettlement
{
//| ProductId | YearMonth | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
}
public interface IDailyStockDao
{
IEnumerable<DailyStockSettlement> GetDailyStocksByYearMonth(string yearmonth, string productId);
}
public class DailyStockSettlement
{
//| ProductId | Date | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
}
}
上面 production code 的部分,全都是從測試程式撰寫過程自動產生出來的。而接下來要做的事情,只需要將 Snapshot()
完成且通過測試即可。
針對三個欄位做 Sum()
第一個作法,是針對三個欄位做 Sum()
的動作,程式碼如下:
public class MonthlyStockSettlementService
{
private IDailyStockDao dailyStockDao;
public MonthlyStockSettlementService(IDailyStockDao dailyStockDao)
{
this.dailyStockDao = dailyStockDao;
}
internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
{
var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
var result = new MonthlyStockSettlement
{
ProductId = productId,
YearMonth = yearMonth,
QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
};
return result;
}
}
public class MonthlyStockSettlement
{
//| ProductId | YearMonth | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
public string ProductId { get; set; }
public string YearMonth { get; set; }
public int QualifiedProductSection { get; set; }
public int DefectProductSection { get; set; }
public int ReturnGoodSection { get; set; }
}
public interface IDailyStockDao
{
IEnumerable<DailyStockSettlement> GetDailyStocksByYearMonth(string yearmonth, string productId);
}
public class DailyStockSettlement
{
//| ProductId | Date | QualifiedProductSection | DefectProductSection | ReturnGoodSection |
public string ProductId { get; set; }
public DateTime Date { get; set; }
public int QualifiedProductSection { get; set; }
public int DefectProductSection { get; set; }
public int ReturnGoodSection { get; set; }
}
為了取得月結的結果,上述程式碼針對「良品區」、「不良品區」、「退貨區」的三個欄位,用了 3 次 Sum()
進行加總,看起來好像很酷,但其實為了三個欄位的加總,原本「只需要針對日結表資料用一次 loop 針對三個欄位加總」的動作,現在卻用了 3 次 loop ,這是不合理的。
使用 Loop 來做
用 foreach loop 反而只要透過一個暫存結果,跑一次 loop 的動作而已,程式碼如下:
internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
{
var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
//var result = new MonthlyStockSettlement
//{
// ProductId = productId,
// YearMonth = yearMonth,
// QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
// DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
// ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
//};
var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
foreach (var dailyStockSettlement in dailyStockSettlements)
{
result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
result.DefectProductSection += dailyStockSettlement.DefectProductSection;
result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
}
return result;
}
看起來很簡單也不難懂,對吧?的確,以加總搭配這麼簡單的需求,沒有太大的差異,但還是來看一下,針對這樣的結構與需求,可以使用 LINQ 的方式來取代迴圈的作業。更後面的段落,則來看這樣的執行方式,是如何抽象成 Aggregate()
的 function 。
使用 Aggregate() 來做
接著透過 Aggregate()
來取代原本迴圈的程式碼,程式碼如下:
internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
{
var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
//var result = new MonthlyStockSettlement
//{
// ProductId = productId,
// YearMonth = yearMonth,
// QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
// DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
// ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
//};
//var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
//foreach (var dailyStockSettlement in dailyStockSettlements)
//{
// result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
// result.DefectProductSection += dailyStockSettlement.DefectProductSection;
// result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
//}
//return result;
return dailyStockSettlements.Aggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
(result, d) =>
{
result.QualifiedProductSection += d.QualifiedProductSection;
result.DefectProductSection += d.DefectProductSection;
result.ReturnGoodSection += d.ReturnGoodSection;
return result;
});
}
看到跟 foreach loop 的差異如下:
- 把 loop 外面的那一行暫存結果,放到第一個參數
-
第二個參數是一個
Func<T1, T2, T1>
的委派。T1 指的就是第一個參數那個暫存結果,T2 則是 foreach 巡覽的每一個 item 。
就只是一種把一堆實作細節抽象出來,用更有彈性的方式來取代而已。
上面的 Aggregate() 該怎麼自己寫
有了 foreach loop 與 Aggregate()
的對應,相信要自己寫出 LINQ 的方法,應該也不是件難事吧。這邊先貼上這個方法的簽章:
public static TAccumulate Aggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func);
接著,寫一個自己的 MyAggregate() 來取代原本 LINQ 的 Aggregate()
,程式碼如下:
internal MonthlyStockSettlement Snapshot(string yearMonth, string productId)
{
var dailyStockSettlements = this.dailyStockDao.GetDailyStocksByYearMonth(yearMonth, productId);
//var result = new MonthlyStockSettlement
//{
// ProductId = productId,
// YearMonth = yearMonth,
// QualifiedProductSection = dailyStockSettlements.Sum(x => x.QualifiedProductSection),
// DefectProductSection = dailyStockSettlements.Sum(x => x.DefectProductSection),
// ReturnGoodSection = dailyStockSettlements.Sum(x => x.ReturnGoodSection)
//};
//var result = new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth };
//foreach (var dailyStockSettlement in dailyStockSettlements)
//{
// result.QualifiedProductSection += dailyStockSettlement.QualifiedProductSection;
// result.DefectProductSection += dailyStockSettlement.DefectProductSection;
// result.ReturnGoodSection += dailyStockSettlement.ReturnGoodSection;
//}
//return result;
//return dailyStockSettlements.Aggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
// (result, d) =>
// {
// result.QualifiedProductSection += d.QualifiedProductSection;
// result.DefectProductSection += d.DefectProductSection;
// result.ReturnGoodSection += d.ReturnGoodSection;
// return result;
// });
return dailyStockSettlements.MyAggregate(new MonthlyStockSettlement { ProductId = productId, YearMonth = yearMonth },
(result, d) =>
{
result.QualifiedProductSection += d.QualifiedProductSection;
result.DefectProductSection += d.DefectProductSection;
result.ReturnGoodSection += d.ReturnGoodSection;
return result;
});
}
public static class MyLinqExtension
{
public static TAccumulate MyAggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
{
var result = seed;
foreach (var item in source)
{
result = func(result, item);
}
return result;
}
}
抽象的過程只是進行下面幾個步驟:
-
把
DailyStockSettlement
換成泛型TSource
-
把
MonthlyStockSettlement
換成泛型TAccumulate
-
把 foreach 裡面要做的事情,換成
Func<TAccumulate, TSource, TAccumulate>
-
把 foreach 巡覽
IEnumerable<DailyStockSettlement>
封裝到 extension method 中
就可以讓這種 foreach loop + 迴圈外面放一個暫存的結果,透過泛型 + 匿名委派 + Lambda 的方式抽象成各種型別跟各種處理都能使用,是不是很神奇呢?
這邊也順手把三種簽章都寫出來,只是刻意不把重複的程式碼重構。
public static class MyLinqExtension
{
public static TSource MyAggregate<TSource>(this IEnumerable<TSource> source, Func<TSource, TSource, TSource> func)
{
// 因為要把第一個item當初始值,所以用原始的 iterator 寫法比較有效率,用 foreach 看起來很醜
using (var iterator = source.GetEnumerator())
{
if (!iterator.MoveNext())
{
throw new InvalidOperationException("Source seqence was empty");
}
//第一個item
var result = iterator.Current;
while (iterator.MoveNext())
{
var next = iterator.Current;
result = func(result, next);
}
return result;
}
}
public static TAccumulate MyAggregate<TSource, TAccumulate>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func)
{
var result = seed;
foreach (var item in source)
{
result = func(result, item);
}
return result;
}
public static TResult MyAggregate<TSource, TAccumulate, TResult>(this IEnumerable<TSource> source, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func, Func<TAccumulate, TResult> resultSelector)
{
var result = seed;
foreach (var item in source)
{
result = func(result, item);
}
return resultSelector(result);
}
}
很多概念都是互通的,看懂了簽章,知道沒有 LINQ 的時候怎麼用 foreach 寫,接下來怎麼把 foreach 的「使用方式」抽象化,就是這一堆 LINQ to Objects 的方法了。
結論
自己在實務上碰到的例子,比這複雜很多,所以我自己一開始也是寫了 3 個 Sum()
,赫然發現很蠢,LINQ 應該要有對應的方法來幫我解決這樣的需求。但我真的忘了是哪一個了,回推回原始的 foreach loop 可能怎麼寫時,我才想到這樣的 foreach loop 使用方式,就是用 Aggregate()
來取代。
所以,希望大家不要只為了讓程式碼看起來好像有在用 LINQ ,看起來很酷,而忽略了每一個 Sum()
其實都是完整的走完 IEnumerable<TSource>
一輪的動作。
by the way, 讀者可以自己練習,用 Aggregate()
來做出 Max()
, Min()
跟 Sum()
囉。
blog 與課程更新內容,請前往新站位置:http://tdd.best/